Code example #1
def test(test_loader, model, configs):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')

    # switch to evaluate mode
    model.train()  # NOTE: using model.eval() makes the performance worse here, so evaluation runs in train mode
    with torch.no_grad():
        start_time = time.time()
        for batch_idx, (origin_imgs, resized_imgs, org_ball_pos_xy, global_ball_pos_xy, event_class, target_seg) in enumerate(
            tqdm(test_loader)):
            data_time.update(time.time() - start_time)
            batch_size = resized_imgs.size(0)
            target_seg = target_seg.to(configs.device, non_blocking=True)
            resized_imgs = resized_imgs.to(configs.device, non_blocking=True).float()
            # compute output
            if 'local' in configs.tasks:
                origin_imgs = origin_imgs.to(configs.device, non_blocking=True).float()
                pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(origin_imgs,
                    resized_imgs, org_ball_pos_xy, global_ball_pos_xy, event_class, target_seg)
            else:
                pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(None,
                    resized_imgs, org_ball_pos_xy, global_ball_pos_xy, event_class, target_seg)
            print('total_loss: {}'.format(total_loss.item()))
            # Transfer output to cpu
            pred_ball_global = pred_ball_global.cpu().numpy()
            global_ball_pos_xy = global_ball_pos_xy.numpy()
            if pred_ball_local is not None:
                pred_ball_local = pred_ball_local.cpu().numpy()
                local_ball_pos_xy = local_ball_pos_xy.cpu().numpy()  # Ground truth of the local stage
            if pred_events is not None:
                pred_events = pred_events.cpu().numpy()
            if pred_seg is not None:
                pred_seg = pred_seg.cpu().numpy()
                target_seg = target_seg.cpu().numpy()

            org_ball_pos_xy = org_ball_pos_xy.numpy()

            seg_thresh = 0.5
            event_thresh = 0.5
            events_idx_to_names = {
                0: 'bounce',
                1: 'net',
                2: 'empty'
            }
            fig, axes = plt.subplots(nrows=batch_size, ncols=2, figsize=(10, 5))
            plt.tight_layout()
            axes = axes.ravel()  # ravel() returns a flattened copy, so it must be assigned back
            saved_dir = '../../docs/test_output_full'
            if not os.path.isdir(saved_dir):
                os.makedirs(saved_dir)
            for sample_idx in range(batch_size):
                w, h = configs.input_size
                # Get target
                sample_org_ball_pos_xy = org_ball_pos_xy[sample_idx]
                sample_global_ball_pos_xy = global_ball_pos_xy[sample_idx]  # Target
                # Process the global stage
                sample_pred_ball_global = pred_ball_global[sample_idx]
                sample_pred_ball_global[sample_pred_ball_global < configs.thresh_ball_pos_mask] = 0.
                sample_pred_ball_global_x = np.argmax(sample_pred_ball_global[:w])
                sample_pred_ball_global_y = np.argmax(sample_pred_ball_global[w:])
                print('Global stage: (x, y) - org: ({}, {}), gt = ({}, {}), prediction = ({}, {})'.format(
                    sample_org_ball_pos_xy[0], sample_org_ball_pos_xy[1],
                    sample_global_ball_pos_xy[0], sample_global_ball_pos_xy[1], sample_pred_ball_global_x,
                    sample_pred_ball_global_y))

                # Process event stage
                if pred_events is not None:
                    sample_target_event = event_class[sample_idx].item()
                    sample_pred_event = (pred_events[sample_idx] > event_thresh).astype(int)  # np.int was removed in NumPy 1.20+
                    print('Event stage: gt = {}, prediction: {}'.format(sample_target_event, pred_events[sample_idx]))

                if pred_seg is not None:
                    sample_target_seg = target_seg[sample_idx].transpose(1, 2, 0)
                    sample_pred_seg = pred_seg[sample_idx].transpose(1, 2, 0)
                    print('Segmentation: Shape sample_target_seg: {}, sample_pred_seg: {}'.format(
                        sample_target_seg.shape, sample_pred_seg.shape))
                    print('Segmentation: Max values sample_target_seg: {}, sample_pred_seg: {}'.format(
                        sample_target_seg.max(), sample_pred_seg.max()))

                    print('Before cast Segmentation sample_target_seg R: {}, G: {}, B: {}'.format(sample_target_seg[:, :, 0].sum(),
                                                                    sample_target_seg[:, :, 1].sum(),
                                                                    sample_target_seg[:, :, 2].sum()))
                    print('Before cast Segmentation sample_pred_seg R: {}, G: {}, B: {}'.format(
                        sample_pred_seg[:, :, 0].sum(),
                        sample_pred_seg[:, :, 1].sum(),
                        sample_pred_seg[:, :, 2].sum()))
                    sample_target_seg = sample_target_seg.astype(int)  # np.int was removed in NumPy 1.20+
                    sample_pred_seg = (sample_pred_seg > seg_thresh).astype(int)
                    print('After Segmentation sample_target_seg R: {}, G: {}, B: {}'.format(sample_target_seg[:, :, 0].sum(),
                                                                    sample_target_seg[:, :, 1].sum(),
                                                                    sample_target_seg[:, :, 2].sum()))
                    print('After Segmentation sample_pred_seg R: {}, G: {}, B: {}'.format(
                        sample_pred_seg[:, :, 0].sum(),
                        sample_pred_seg[:, :, 1].sum(),
                        sample_pred_seg[:, :, 2].sum()))
                    axes[2 * sample_idx].imshow(sample_target_seg * 255)
                    axes[2 * sample_idx + 1].imshow(sample_pred_seg * 255)
                    # title
                    target_title = 'target seg'
                    pred_title = 'pred seg'
                    if pred_events is not None:
                        target_title += ', event: {}'.format(events_idx_to_names[sample_target_event])
                        pred_title += ', is bounce: {}, is net: {}'.format(sample_pred_event[0], sample_pred_event[1])

                    axes[2 * sample_idx].set_title(target_title)
                    axes[2 * sample_idx + 1].set_title(pred_title)


                    plt.savefig(
                        os.path.join(saved_dir, 'batch_idx_{}_sample_idx_{}.jpg'.format(batch_idx, sample_idx)))

            plt.close(fig)  # release the figure so open figures do not accumulate across batches
            batch_time.update(time.time() - start_time)

            start_time = time.time()
    print('Done testing')
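
Every snippet on this page relies on an AverageMeter helper that is not shown. Below is a minimal sketch consistent with how it is called above (AverageMeter('Time', ':6.3f'), .update(val, n), .val, .avg) and modeled on the PyTorch ImageNet example; the actual class in each project may differ in detail.

class AverageMeter:
    """Tracks the most recent value and a running average."""

    def __init__(self, name='', fmt=':f'):
        self.name = name  # optional label, e.g. 'Time'
        self.fmt = fmt    # format spec used when rendering, e.g. ':6.3f'
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)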
Code example #2
def eval(_run, _log):
    cfg = edict(_run.config)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    # Two fixed GPUs: the angle and seg nets run on cuda:0, the contact net on cuda:1
    # (the usual cuda-availability fallback was dead code here, as it was overwritten immediately)
    device = torch.device("cuda:0")
    device1 = torch.device("cuda:1")

    checkpoint_dir = os.path.join('experiments/predict', str(_run._id),
                                  'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    angle_net = AngleNet(cfg.model)
    contact_net = ContactNet(cfg.model)
    seg_net = SegNet(cfg.model)

    if cfg.resume_angle != 'None':
        model_dict = torch.load(cfg.resume_angle)
        angle_net.load_state_dict(model_dict)
    if cfg.resume_contact != 'None':
        model_dict = torch.load(cfg.resume_contact)
        contact_net.load_state_dict(model_dict)
    if cfg.resume_seg != 'None':
        model_dict = torch.load(cfg.resume_seg)
        seg_net.load_state_dict(model_dict)
    # load nets into gpu
    if cfg.num_gpus > 1 and torch.cuda.is_available():
        angle_net = torch.nn.DataParallel(angle_net)
        contact_net = torch.nn.DataParallel(contact_net)
        seg_net = torch.nn.DataParallel(seg_net)
    angle_net.to(device)
    contact_net.to(device1)
    seg_net.to(device)
    if cfg.input_method == "planercnn":
        val_dataset = PlaneDataset(cfg.dataset,
                                   split='test',
                                   random=False,
                                   evaluation=True)
    elif cfg.input_method == "planeae":
        val_dataset = PlaneDatasetAE(cfg.dataset,
                                     split='test',
                                     random=False,
                                     evaluation=True)
    else:
        print('input method ' + cfg.input_method + ' not supported!')
        exit()
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=cfg.dataset.num_workers)

    use_gt_relation = False
    write_relation = False
    assess_seg = False
    depth_only = True

    angle_net.eval()
    contact_net.eval()
    seg_net.eval()
    angle_accuracies = AverageMeter()
    contact_accuracies = AverageMeter()
    contact_ious = AverageMeter()
    stat_parallel = np.zeros(4)  # classification precision recall
    stat_ortho = np.zeros(4)
    stat_contact = np.zeros(4)
    normal_errors = [AverageMeter(), AverageMeter(), AverageMeter()]
    normal_diff_errors = [AverageMeter(), AverageMeter(), AverageMeter()]
    depth_errors = [AverageMeter(), AverageMeter(), AverageMeter()]
    offset_errors = [AverageMeter(), AverageMeter(), AverageMeter()]
    contact_depth_errors = [AverageMeter(), AverageMeter(), AverageMeter()]

    with torch.no_grad():
        for iter, sample in enumerate(val_loader):
            if iter == 100:  # evaluate only the first 100 samples
                break
            sceneIndex = sample["sceneIndex"].item()
            imageIndex = sample["imageIndex"].item()

            save_path = os.path.join('experiments/predict', str(
                _run._id)) + f'/results/{iter}/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            image = sample["image"][0]  #3x224x224
            #pd_points = sample["pd_points"][0] #3x224x224
            tg_planes = sample["tg_planes"][0]
            pd_planes = sample["pd_planes"][0]  #planenumx4
            matched_single = sample["matched_single"][0]
            pd_masks_small = sample["pd_masks_small"][0]  #planenumx224x224
            pd_masks_small_c = sample["pd_masks_small_c"][0]
            segment_depths_small = sample["segment_depths_small"][0]
            gt_angle = sample["gt_angle"][0]
            gt_contact = sample["gt_contact"][0]
            gt_contactline = sample["gt_contactline"][0]
            matched_pair = sample["matched_pair"][0]
            matched_contact = sample["matched_contact"][0]
            planepair_index = sample["planepair_index"][0]
            ori_image = sample["ori_image"][0]
            camera = sample["camera"][0]
            ori_pd_points = sample["ori_pd_points"][0]
            pd_masks = sample["pd_masks"][0]  # planenumx480x640
            ransac_masks = sample["ransac_masks"][0]
            ransac_planes = sample["ransac_planes"][0]
            sensor_depth = sample["sensor_depth"][0]
            tg_masks = sample["tg_masks"][0]

            image = image.to(device)
            #pd_points = pd_points.to(device)
            pd_masks_small = pd_masks_small.to(device)
            pd_masks_small_c = pd_masks_small_c.to(device)
            segment_depths_small = segment_depths_small.to(device)
            gt_angle = gt_angle.to(device)
            gt_contact = gt_contact.to(device1)

            # build the input to network
            input_tensor, seg_tensor = [], []
            planepair_index = planepair_index.numpy()
            planepair_num = planepair_index.shape[0]
            pd_planes = pd_planes.numpy()
            tg_planes = tg_planes.numpy()
            ori_image = ori_image.numpy().astype(np.uint8)
            ori_magnitude = get_magnitude(ori_image)
            matched_single = matched_single.numpy().astype(bool)  # np.bool was removed in NumPy 1.20+
            matched_pair = matched_pair.numpy().astype(bool)
            matched_contact = matched_contact.numpy().astype(bool)
            for ppindex in planepair_index:
                p, q = ppindex
                pm = pd_masks_small[p:p + 1]
                qm = pd_masks_small[q:q + 1]
                pdepth = segment_depths_small[p:p + 1]
                qdepth = segment_depths_small[q:q + 1]
                dot_map = torch.full_like(pm,
                                          np.abs(
                                              np.dot(pd_planes[p, :3],
                                                     pd_planes[q, :3])),
                                          device=device)
                if cfg.model.input_channel == 8:
                    input_tensor.append(
                        torch.cat((image, pm, qm, pdepth, qdepth, dot_map),
                                  dim=0))
                elif cfg.model.input_channel == 7:
                    input_tensor.append(
                        torch.cat((image, pm, qm, pdepth, qdepth), dim=0))
                elif cfg.model.input_channel == 5:
                    input_tensor.append(torch.cat((image, pm, qm), dim=0))
                else:
                    input_tensor.append(torch.cat((pm, qm), dim=0))
            input_tensor = torch.stack(input_tensor)
            print(iter, input_tensor.size())

            # inference
            try:
                angle_prob = angle_net(input_tensor)
            except RuntimeError:  # e.g. CUDA out of memory; retry in two halves
                half_num = int(planepair_num / 2.0 + 0.5)
                angle_prob0 = angle_net(input_tensor[0:half_num, :, :, :])
                angle_prob1 = angle_net(
                    input_tensor[half_num:planepair_num, :, :, :])
                angle_prob = torch.cat((angle_prob0, angle_prob1), dim=0)
            input_tensor = input_tensor.to(device1)
            try:
                contact_prob, contactline_prob = contact_net(input_tensor)
            except RuntimeError:  # e.g. CUDA out of memory; retry in two halves
                half_num = int(planepair_num / 2.0 + 0.5)
                contact_prob0, contactline_prob0 = contact_net(
                    input_tensor[0:half_num, :, :, :])
                contact_prob1, contactline_prob1 = contact_net(
                    input_tensor[half_num:planepair_num, :, :, :])
                contact_prob = torch.cat((contact_prob0, contact_prob1), dim=0)
                contactline_prob = torch.cat(
                    (contactline_prob0, contactline_prob1), dim=0)
            del input_tensor

            # relation assessment
            acc_angle, pred_angle = accuracy(angle_prob, gt_angle, angle=True)
            acc_contact, pred_contact = accuracy(contact_prob,
                                                 gt_contact,
                                                 angle=False)
            matched_pair_num = matched_pair.sum()
            if matched_pair_num > 0:
                acc_angle, acc_contact = np.sum(
                    acc_angle[matched_pair]) * 100 / matched_pair_num, np.sum(
                        acc_contact[matched_pair]) * 100 / matched_pair_num
                angle_accuracies.update(acc_angle, matched_pair_num)
                contact_accuracies.update(acc_contact, matched_pair_num)

            camera = camera.numpy()
            ranges2d = get_ranges2d(camera)
            contactline_prob = contactline_prob.cpu().numpy().squeeze()
            pd_masks = pd_masks.numpy()
            gt_contactline = gt_contactline.numpy()
            gt_angle = gt_angle.cpu().numpy()
            gt_contact = gt_contact.cpu().numpy()

            ## precision and recall
            if matched_pair_num > 0:
                #pred_angle, pred_contact = eval_relation_baseline(planepair_index, pd_planes, pd_masks)
                stat_parallel += comp_precision_recall(
                    pred_angle[matched_pair] == 1, gt_angle[matched_pair] == 1)
                stat_ortho += comp_precision_recall(
                    pred_angle[matched_pair] == 0, gt_angle[matched_pair] == 0)
                stat_contact += comp_precision_recall(
                    pred_contact[matched_contact] == 1,
                    gt_contact[matched_contact] == 1)
                iou_flag = (gt_contact == 1) & matched_contact
                if np.sum(iou_flag) > 0:
                    pred_iou = comp_conrel_iou(gt_contact, gt_contactline,
                                               contactline_prob)
                    contact_ious.update(np.mean(pred_iou[iou_flag]),
                                        np.sum(iou_flag))

            contact_list = []
            pair_areas = np.zeros((planepair_num, 1))
            contact_line_probs = []
            for i, ppindex in enumerate(planepair_index):
                p, q = ppindex
                pm = pd_masks[p]
                qm = pd_masks[q]
                #pair_areas[i] = pm.sum() + qm.sum()
                pair_areas[i] = pm.sum() / 640 * qm.sum() / 480
                if write_relation:
                    tmp_img = ori_image.copy()
                    tmp_img[pm > 0.5, 0] = 255
                    tmp_img[qm > 0.5, 2] = 255
                    if (gt_angle[i]
                            if use_gt_relation else pred_angle[i]) == 1:
                        cv2.imwrite(f'{save_path}para_{i}.png', tmp_img)
                        if (gt_contact[i]
                                if use_gt_relation else pred_contact[i]) == 1:
                            cv2.imwrite(f'{save_path}coplane_{i}.png', tmp_img)
                    elif (gt_angle[i]
                          if use_gt_relation else pred_angle[i]) == 0:
                        cv2.imwrite(f'{save_path}ortho_{i}.png', tmp_img)

                if (gt_contact[i]
                        if use_gt_relation else pred_contact[i]) == 0:
                    continue
                gt_mask = gt_contactline[i]
                re_mask = cv2.resize(contactline_prob[i], dsize=(640, 480))
                if use_gt_relation:
                    re_mask = gt_mask

                contact_line_probs.append(re_mask)
                mask_thres = 0.5
                if pred_angle[i] == 1:
                    mask_thres = 0.25
                ylist, xlist = extract_line2d(re_mask, mask_thres)
                raydirs = ranges2d[ylist, xlist, :]

                contact_list.append([p, q, raydirs, i])
                if write_relation:
                    black_img = np.zeros((480, 640, 3), dtype=np.uint8)
                    black_img[:, :, 0] = re_mask * 255
                    black_img[:, :, 1] = gt_mask * 255
                    black_img[re_mask > mask_thres, 2] = 255
                    tmp_img[re_mask > 0.25, 1] = 255
                    cv2.imwrite(f'{save_path}contact_{i}.png',
                                np.concatenate([tmp_img, black_img], 1))

            contact_line_probs = np.asarray(contact_line_probs)
            # optimization
            if use_gt_relation:
                flag_para = (matched_pair & (gt_angle == 1))
                flag_ortho = (matched_pair & (gt_angle == 0))
                flag_contact = (matched_contact & (gt_contact == 1))
                para_list = planepair_index[flag_para, :]
                ortho_list = planepair_index[flag_ortho, :]
                para_weight = pair_areas[flag_para, :]
                ortho_weight = pair_areas[flag_ortho, :]
                contact_weight = pair_areas[flag_contact, :]
                coplane_list = planepair_index[flag_para & flag_contact, :]
                coplane_weight = pair_areas[flag_para & flag_contact, :]
            else:
                flag_para, flag_ortho, flag_contact = pred_angle == 1, pred_angle == 0, pred_contact == 1
                para_list = planepair_index[flag_para, :]
                para_weight = pair_areas[flag_para, :]
                ortho_list = planepair_index[flag_ortho, :]
                ortho_weight = pair_areas[flag_ortho, :]
                contact_weight = pair_areas[flag_contact, :]
                coplane_list = planepair_index[flag_para & flag_contact, :]
                coplane_weight = pair_areas[flag_para & flag_contact, :]
            para_weight /= np.sum(para_weight)
            ortho_weight /= np.sum(ortho_weight)
            contact_weight /= np.sum(contact_weight)
            coplane_weight /= np.sum(coplane_weight)

            ori_pd_points = ori_pd_points.numpy()
            point_list = []
            for i in range(pd_masks.shape[0]):
                point_list.append(ori_pd_points[pd_masks[i] > 0.5, :])
            # ---------------------------------
            # -- solve the plane parameters by optimization
            cv2.imwrite(f"{save_path}image.jpg", ori_image)
            sensor_depth = sensor_depth.numpy()
            visualize(ori_image,
                      sensor_depth, [],
                      camera,
                      'point_sensor',
                      save_path,
                      depthonly=depth_only,
                      sensor=True)
            ransac_masks = ransac_masks.numpy()
            ransac_planes = ransac_planes.numpy()

            p_depth_gt = visualize(ori_image,
                                   ransac_masks,
                                   ransac_planes,
                                   camera,
                                   'point_gt',
                                   save_path,
                                   depthonly=depth_only)

            seg_gt = blend_image_mask(ori_image, ransac_masks, thres=0.5)
            cv2.imwrite(f"{save_path}seg_gt.png", seg_gt)

            p_depth_planercnn = visualize(ori_image,
                                          pd_masks,
                                          pd_planes,
                                          camera,
                                          'point_planercnn',
                                          save_path,
                                          depthonly=depth_only)

            seg_planercnn = blend_image_mask(ori_image, pd_masks, thres=0.5)
            cv2.imwrite(f"{save_path}seg_planercnn.png", seg_planercnn)
            alpha = np.array([1., 0., 10., 1., 1., 0., 0.])
            re_planes_angle = plane_minimize(pd_planes, point_list, para_list,
                                             para_weight, ortho_list,
                                             ortho_weight, contact_list,
                                             contact_weight, coplane_list,
                                             coplane_weight, alpha)
            # p_depth_angle = visualize(ori_image, pd_masks, re_planes_angle, camera, 'point_result_angle', save_path, depthonly=depth_only)

            #alpha = np.array([1.,0.,10.,1.,1.,10.,0.])# ae
            alpha = np.array([1., 10., 10., 1., 1., 1., 0.])
            re_planes = plane_minimize(pd_planes, point_list, para_list,
                                       para_weight, ortho_list, ortho_weight,
                                       contact_list, contact_weight,
                                       coplane_list, coplane_weight, alpha)
            #p_depth_contact = visualize(ori_image, pd_masks, re_planes, camera, 'point_result', save_path, depthonly=depth_only)

            # ---------------------------------
            # -- split the contact using optimized 3d parameters
            contact_split, line_equs, line_flags, along_line_mask = expand_masks1(
                re_planes, contact_list, contact_line_probs, ranges2d,
                pd_masks)
            seg_contact = blend_image_mask(ori_image, contact_split)
            cv2.imwrite(f"{save_path}seg_contact.png", seg_contact)
            cv2.imwrite(f"{save_path}seg_contact_line.png",
                        along_line_mask.astype(np.uint8) * 255)
            #cv2.imwrite(f"{save_path}seg_expanded_line.png", expanded_seg*0.7+line_img*0.3)

            # refine segmentation by network and contact
            image = image.repeat(pd_masks_small.size(0), 1, 1, 1)
            contact_split_small = np.zeros((contact_split.shape[0], 224, 224),
                                           dtype=np.float32)
            for i in range(contact_split.shape[0]):
                contact_split_small[i] = cv2.resize(contact_split[i],
                                                    dsize=(224, 224))
            contact_split_small = torch.from_numpy(contact_split_small).to(device)  # torch.cuda.FloatTensor is deprecated
            input_tensor = torch.cat([
                image,
                pd_masks_small.unsqueeze(1),
                contact_split_small.unsqueeze(1)
            ],
                                     dim=1)
            seg_prob_small = seg_net(input_tensor)
            del input_tensor, pd_masks_small, contact_split_small
            seg_prob_small = seg_prob_small.cpu().numpy().squeeze()
            seg_prob = np.zeros((seg_prob_small.shape[0], 480, 640))
            for i, m in enumerate(seg_prob_small):
                seg_prob[i] = cv2.resize(m, dsize=(640, 480))
            seg_prob = clean_prob_mask(seg_prob)
            seg_refined = blend_image_mask(ori_image, seg_prob, thres=0.5)
            cv2.imwrite(f"{save_path}seg_refined.png", seg_refined)
            p_depth_all = visualize(ori_image,
                                    seg_prob,
                                    re_planes,
                                    camera,
                                    'point_result_ex',
                                    save_path,
                                    depthonly=depth_only)

            # --------------------------------
            # -- do the evaluation

            # 1. evaluate depth
            tg_masks = tg_masks.numpy()
            comp_depth_error(ori_image, camera, tg_masks[matched_single],
                             tg_planes[matched_single],
                             pd_planes[matched_single],
                             re_planes_angle[matched_single],
                             re_planes[matched_single], depth_errors,
                             save_path)
            # 2. evaluate normals
            comp_parameter_error(tg_planes[matched_single],
                                 pd_planes[matched_single],
                                 re_planes_angle[matched_single],
                                 re_planes[matched_single],
                                 tg_masks[matched_single], normal_errors,
                                 offset_errors)
            # 3. contact depth consistency
            flag_contact = (gt_contact == 1)
            comp_contact_error(gt_contactline[flag_contact],
                               planepair_index[flag_contact], ranges2d,
                               pd_planes, re_planes_angle, re_planes,
                               contact_depth_errors)
            #comp_contact_error(gt_contactline, planepair_index, ranges2d, pd_planes, re_planes_angle, re_planes, contact_depth_errors, gt_contact, gt_angle, tg_planes, pd_masks)

            ## sensor depth
            semantic_gt = ransac_masks.max(0)
            semantic_pd = pd_masks.max(0)
            semantic_re = seg_prob.max(0)

            p_depth_gt[semantic_gt < 0.5] = 0.
            p_depth_planercnn[semantic_pd < 0.5] = 0.
            p_depth_all[semantic_re < 0.5] = 0.
            cv2.imwrite(f"{save_path}depth_sensor.png",
                        drawDepthImage(sensor_depth))
            cv2.imwrite(f"{save_path}depth_gt.png", drawDepthImage(p_depth_gt))
            cv2.imwrite(f"{save_path}depth_prcnn.png",
                        drawDepthImage(p_depth_planercnn))
            cv2.imwrite(f"{save_path}depth_all.png",
                        drawDepthImage(p_depth_all))

            # evaluate pairwise angular difference
            diff_flag = matched_pair  #&(gt_angle!=2)
            if np.sum(diff_flag) > 0:
                diff_areas = pair_areas[diff_flag, :].reshape(-1)
                normal_diff_errors[0].update(
                    eval_planepair_diff(pd_planes, tg_planes,
                                        planepair_index[diff_flag, :],
                                        diff_areas), np.sum(diff_areas))
                normal_diff_errors[1].update(
                    eval_planepair_diff(re_planes_angle, tg_planes,
                                        planepair_index[diff_flag, :],
                                        diff_areas), np.sum(diff_areas))
                normal_diff_errors[2].update(
                    eval_planepair_diff(re_planes, tg_planes,
                                        planepair_index[diff_flag, :],
                                        diff_areas), np.sum(diff_areas))

        print("-----------geometry accuracy--------------")
        print(
            f'normal error: planercnn {normal_errors[0].avg}, angle {normal_errors[1].avg}, all {normal_errors[2].avg}'
        )
        print(
            f'offset error: planercnn {offset_errors[0].avg}, angle {offset_errors[1].avg}, all {offset_errors[2].avg}'
        )
        print(
            f'depth error: planercnn {depth_errors[0].avg}, angle {depth_errors[1].avg}, all {depth_errors[2].avg}'
        )
        print(
            f'contact depth error: planercnn {contact_depth_errors[0].avg}, angle {contact_depth_errors[1].avg}, all {contact_depth_errors[2].avg}'
        )
        print(
            f'angular diff error: planercnn {normal_diff_errors[0].avg}, angle {normal_diff_errors[1].avg}, all {normal_diff_errors[2].avg}'
        )

        print('\n---------relation classification----------')
        print(angle_accuracies.avg, contact_accuracies.avg)
        precision, recall = stat_parallel[0] / stat_parallel[1], stat_parallel[
            2] / stat_parallel[3]
        print(
            f'parallel precision: {precision} recall: {recall} f1score: {2*(recall*precision)/(recall+precision)}'
        )
        precision, recall = stat_ortho[0] / stat_ortho[1], stat_ortho[
            2] / stat_ortho[3]
        print(
            f'ortho precision: {precision} recall: {recall} f1score: {2*(recall*precision)/(recall+precision)}'
        )
        precision, recall = stat_contact[0] / stat_contact[1], stat_contact[
            2] / stat_contact[3]
        print(
            f'contact precision: {precision} recall: {recall} f1score: {2*(recall*precision)/(recall+precision)}'
        )
        print(f'contact iou: {contact_ious.avg}')
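
The four-element stat arrays above are consumed as precision = stat[0] / stat[1] and recall = stat[2] / stat[3]. The comp_precision_recall helper is not shown; the following is a sketch of what it plausibly returns under that convention (an assumption inferred purely from how the counts are used):

import numpy as np

def comp_precision_recall(pred, gt):
    # pred, gt: boolean arrays for one relation class over the matched pairs.
    # Returns [TP, #predicted positives, TP, #ground-truth positives], so that
    # after accumulating over images, precision = stat[0] / stat[1] and
    # recall = stat[2] / stat[3].
    tp = float(np.sum(pred & gt))
    return np.array([tp, np.sum(pred), tp, np.sum(gt)], dtype=np.float64)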
Code example #3
def validate(val_loader, net, criterion, optim, epoch,
             calc_metrics=True,
             dump_assets=False,
             dump_all_images=False):
    """
    Run validation for one epoch

    :val_loader: data loader for validation
    :net: the network
    :criterion: loss fn
    :optim: optimizer
    :epoch: current epoch
    :calc_metrics: calculate validation score
    :dump_assets: dump asset images (attention maps, predictions)
    :dump_all_images: dump all images, not just N
    """
    dumper = ImageDumper(val_len=len(val_loader),
                         dump_all_images=dump_all_images,
                         dump_assets=dump_assets,
                         dump_for_auto_labelling=args.dump_for_auto_labelling,
                         dump_for_submission=args.dump_for_submission)

    net.eval()
    val_loss = AverageMeter()
    iou_acc = 0

    for val_idx, data in enumerate(val_loader):
        input_images, labels, img_names, _ = data 
        if args.dump_for_auto_labelling or args.dump_for_submission:
            submit_fn = '{}.png'.format(img_names[0])
            if val_idx % 20 == 0:
                logx.msg(f'validating[Iter: {val_idx + 1} / {len(val_loader)}]')
            if os.path.exists(os.path.join(dumper.save_dir, submit_fn)):
                continue

        # Run network
        assets, _iou_acc = \
            eval_minibatch(data, net, criterion, val_loss, calc_metrics,
                          args, val_idx)

        iou_acc += _iou_acc


        dumper.dump({'gt_images': labels,
                     'input_images': input_images,
                     'img_names': img_names,
                     'assets': assets}, val_idx)

        if val_idx > 5 and args.test_mode:
            break

        if val_idx % 20 == 0:
            logx.msg(f'validating[Iter: {val_idx + 1} / {len(val_loader)}]')

    was_best = False
    if calc_metrics:
        was_best = eval_metrics(iou_acc, args, net, optim, val_loss, epoch)

    # Write out a summary html page and tensorboard image table
    if not args.dump_for_auto_labelling and not args.dump_for_submission:
        dumper.write_summaries(was_best)
Code example #4
def train_iae(trainloader, model, class_name, testloader, y_train, device,
              args):
    """
    model train function.
    :param trainloader:
    :param model:
    :param class_name:
    :param testloader:
    :param y_train: numpy array, sample normal/abnormal labels, [1 1 1 1 0 0] like, original sample size.
    :param device: cpu or gpu:0/1/...
    :param args:
    :return:
    """
    global_step = 0
    losses = AverageMeter()
    l2_losses = AverageMeter()
    svdd_losses = AverageMeter()

    start_time = time.time()
    epoch_time = AverageMeter()

    svdd_loss = torch.tensor(0.0, device=device)
    R = torch.tensor(0.0, device=device)  # hypersphere radius
    c = torch.randn(256, device=device)  # center; re-estimated once pretraining ends

    for epoch in range(1, args.epochs + 1):
        model.train()

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print('{:3d}/{:3d} ----- {:s} {:s}'.format(epoch, args.epochs,
                                                   time_string(), need_time))

        mse = nn.MSELoss(reduction='mean')  # default

        lr = 0.1 / pow(2, np.floor(epoch / args.lr_schedule))
        logger.add_scalar(class_name + "/lr", lr, epoch)

        if args.optimizer == 'sgd':
            optimizer = optim.SGD(model.parameters(),
                                  lr=lr,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(model.parameters(),
                                   eps=1e-7,
                                   weight_decay=args.weight_decay)
        else:
            raise NotImplementedError('optimizer {} is not implemented'.format(args.optimizer))  # avoid silently reusing a stale optimizer

        for batch_idx, (input, _, _) in enumerate(trainloader):
            optimizer.zero_grad()
            input = input.to(device)

            reps, output = model(input)

            if epoch > args.pretrain_epochs:
                dist = torch.sum((reps - c)**2, dim=1)
                scores = dist - R**2
                svdd_loss = args.para_lambda * (
                    R**2 + (1 / args.para_nu) *
                    torch.mean(torch.max(torch.zeros_like(scores), scores)))

            l2_loss = mse(input, output)

            loss = l2_loss + svdd_loss

            l2_losses.update(l2_loss.item(), 1)
            svdd_losses.update(svdd_loss.item(), 1)
            losses.update(loss.item(), 1)

            logger.add_scalar(class_name + '/l2_loss', l2_losses.avg,
                              global_step)
            logger.add_scalar(class_name + '/svdd_loss', svdd_losses.avg,
                              global_step)
            logger.add_scalar(class_name + '/loss', losses.avg, global_step)

            logger.add_scalar(class_name + '/R', R.data, global_step)

            global_step = global_step + 1
            loss.backward()
            optimizer.step()

            # Update hypersphere radius R on mini-batch distances
            if epoch > args.pretrain_epochs:
                R.data = torch.tensor(get_radius(dist, args.para_nu),
                                      device=device)

        # print losses
        print('Epoch: [{} | {}], loss: {:.4f}'.format(epoch, args.epochs,
                                                      losses.avg))

        # log images
        if epoch % args.log_img_steps == 0:
            os.makedirs(os.path.join(RESULTS_DIR, class_name), exist_ok=True)
            fpath = os.path.join(RESULTS_DIR, class_name,
                                 'pretrain_epoch_' + str(epoch) + '.png')
            visualize(input, output, fpath, num=32)

        # test while training
        if epoch % args.log_auc_steps == 0:
            rep, losses_result = test(testloader, model, class_name, args,
                                      device, epoch)

            centroid = torch.mean(rep, dim=0, keepdim=True)

            losses_result = losses_result - losses_result.min()
            losses_result = losses_result / (1e-8 + losses_result.max())
            scores = 1 - losses_result
            auroc_rec = roc_auc_score(y_train, scores)

            _, p = dec_loss_fun(rep, centroid)
            score_p = p[:, 0]
            auroc_dec = roc_auc_score(y_train, score_p)

            print("Epoch: [{} | {}], auroc_rec: {:.4f}; auroc_dec: {:.4f}".
                  format(epoch, args.epochs, auroc_rec, auroc_dec))

            logger.add_scalar(class_name + '/auroc_rec', auroc_rec, epoch)
            logger.add_scalar(class_name + '/auroc_dec', auroc_dec, epoch)

        # initial centroid c before pretrain finished
        if epoch == args.pretrain_epochs:
            rep, losses_result = test(testloader, model, class_name, args,
                                      device, epoch)
            c = update_center_c(rep)

        # time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
コード例 #5
0
def train(train_loader, model, criterion, optimizer):
    '''
    Train the model
    :param train_loader:
    :param model:
    :param criterion:
    :param optimizer:
    :return:
    '''
    # Define the meters that hold the running statistics
    data_time = AverageMeter()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    #################
    # train the model
    #################
    model.train()

    # Iterate over the batches and train the model
    ## define the progress bar
    bar = Bar('Processing', max=len(train_loader))
    for batch_index, (inputs, targets) in enumerate(train_loader):
        data_time.update(time.time() - end)
        # move tensors to GPU if cuda is_available
        inputs, targets = inputs.to(device), targets.to(device)
        # Clear the gradients with zero_grad before running the backward pass
        optimizer.zero_grad()
        # Model prediction (forward pass)
        outputs = model(inputs)
        # Compute the loss
        loss = criterion(outputs, targets)
        # backward pass:
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()

        # Compute accuracy and update the meters
        prec1, _ = accuracy(outputs.data, targets.data, topk=(1, 1))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # Plot progress: pack the key statistics into the bar suffix
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f}'.format(
            batch=batch_index + 1,
            size=len(train_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            top1=top1.avg)
        bar.next()
    bar.finish()
    return (losses.avg, top1.avg)
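
The accuracy helper used here (and again in code example #11) is not defined on this page; the following is a sketch matching the call sites accuracy(outputs, targets, topk=(1, 1)) and accuracy(output, targets, topk=(1, 5)), modeled on the PyTorch ImageNet example:

import torch

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the top-k predictions per sample, shape (maxk, batch)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res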
Code example #6
def validate(val_loader, net, criterion, optim, curr_epoch, writer):
    """
    Runs the validation loop after each training epoch
    val_loader: Data loader for validation
    net: the network
    criterion: loss fn
    optim: optimizer
    curr_epoch: current epoch
    writer: tensorboard writer
    return: val_avg for step function if required
    """

    net.eval()
    val_loss = AverageMeter()
    iou_acc = 0
    dump_images = []

    for val_idx, data in enumerate(val_loader):
        # input        = torch.Size([1, 3, 713, 713])
        # gt_image           = torch.Size([1, 713, 713])
        inputs, gt_image, img_names = data
        assert len(inputs.size()) == 4 and len(gt_image.size()) == 3
        assert inputs.size()[2:] == gt_image.size()[1:]

        batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)
        inputs, gt_cuda = inputs.cuda(), gt_image.cuda()

        with torch.no_grad():
            output = net(inputs)  # output = (1, 19, 713, 713)

        assert output.size()[2:] == gt_image.size()[1:]
        assert output.size()[1] == args.dataset_cls.num_classes

        val_loss.update(criterion(output, gt_cuda).item(), batch_pixel_size)

        # Collect data from different GPU to a single GPU since
        # encoding.parallel.criterionparallel function calculates distributed loss
        # functions
        predictions = output.data.max(1)[1].cpu()

        # Logging
        if val_idx % 20 == 0:
            if args.local_rank == 0:
                logging.info("validating: %d / %d", val_idx + 1,
                             len(val_loader))
        if val_idx > 10 and args.test_mode:
            break

        # Image Dumps
        if val_idx < 10:
            dump_images.append([gt_image, predictions, img_names])

        iou_acc += fast_hist(predictions.numpy().flatten(),
                             gt_image.numpy().flatten(),
                             args.dataset_cls.num_classes)
        del output, val_idx, data

    if args.apex:
        iou_acc_tensor = torch.cuda.FloatTensor(iou_acc)
        torch.distributed.all_reduce(iou_acc_tensor,
                                     op=torch.distributed.ReduceOp.SUM)
        iou_acc = iou_acc_tensor.cpu().numpy()

    if args.local_rank == 0:
        evaluate_eval(args, net, optim, val_loss, iou_acc, dump_images, writer,
                      curr_epoch, args.dataset_cls)

    return val_loss.avg
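
Both this snippet and code example #10 accumulate iou_acc with a fast_hist helper. Below is a sketch of the standard confusion-matrix implementation used in segmentation codebases like these (hedged, since the helper itself is not shown here):

import numpy as np

def fast_hist(pred, gtruth, num_classes):
    # Build a num_classes x num_classes confusion matrix with one bincount;
    # rows index ground truth, columns index predictions. Pixels whose label
    # falls outside [0, num_classes), e.g. an ignore label, are dropped.
    mask = (gtruth >= 0) & (gtruth < num_classes)
    hist = np.bincount(num_classes * gtruth[mask].astype(int) + pred[mask],
                       minlength=num_classes ** 2)
    return hist.reshape(num_classes, num_classes)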
Code example #7
def validate_one_epoch(val_loader, model, epoch, configs, logger):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')

    progress = ProgressMeter(len(val_loader), [batch_time, data_time, losses],
                             prefix="Validation - Epoch: [{}]".format(epoch))
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        for batch_idx, (origin_imgs, resized_imgs, org_ball_pos_xy,
                        global_ball_pos_xy, event_class,
                        target_seg) in enumerate(tqdm(val_loader)):
            data_time.update(time.time() - start_time)
            batch_size = resized_imgs.size(0)
            target_seg = target_seg.to(configs.device, non_blocking=True)
            resized_imgs = resized_imgs.to(configs.device,
                                           non_blocking=True).float()
            # Only move origin_imgs to cuda if the model has local stage for ball detection
            if not configs.no_local:
                origin_imgs = origin_imgs.to(configs.device,
                                             non_blocking=True).float()
                # compute output
                pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
                    origin_imgs, resized_imgs, org_ball_pos_xy,
                    global_ball_pos_xy, event_class, target_seg)
            else:
                pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
                    None, resized_imgs, org_ball_pos_xy, global_ball_pos_xy,
                    event_class, target_seg)
            # For torch.nn.DataParallel case
            if (not configs.distributed) and (configs.gpu_idx is None):
                total_loss = torch.mean(total_loss)

            losses.update(total_loss.item(), batch_size)
            # measure elapsed time
            batch_time.update(time.time() - start_time)

            # Log message
            if logger is not None:
                if ((batch_idx + 1) % configs.print_freq) == 0:
                    logger.info(progress.get_message(batch_idx))

            start_time = time.time()

    return losses.avg
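
The ProgressMeter above follows the PyTorch ImageNet example, except that get_message returns the formatted line for the logger instead of printing it. A minimal sketch under that assumption (it renders the AverageMeter objects via their __str__):

class ProgressMeter:
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def get_message(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        return '\t'.join(entries)

    @staticmethod
    def _get_batch_fmtstr(num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'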
Code example #8
def validate_topn(val_loader, net, criterion, optim, epoch, args):
    """
    Find worst-case failures ...

    Only single GPU for now

    First pass = calculate TP, FP, FN pixels per image per class
      Take these stats and determine the top 20 images to dump per class
    Second pass = dump all those selected images
    """
    assert args.bs_val == 1

    ######################################################################
    # First pass
    ######################################################################
    logx.msg('First pass')
    image_metrics = {}

    net.eval()
    val_loss = AverageMeter()
    iou_acc = 0

    for val_idx, data in enumerate(val_loader):

        # Run network
        assets, _iou_acc = \
            run_minibatch(data, net, criterion, val_loss, True, args, val_idx)

        # per-class metrics
        input_images, labels, img_names, _ = data

        fp, fn = metrics_per_image(_iou_acc)
        img_name = img_names[0]
        image_metrics[img_name] = (fp, fn)

        iou_acc += _iou_acc

        if val_idx % 20 == 0:
            logx.msg(f'validating[Iter: {val_idx + 1} / {len(val_loader)}]')

        if val_idx > 5 and args.test_mode:
            break

    eval_metrics(iou_acc, args, net, optim, val_loss, epoch)

    ######################################################################
    # Find top 20 worst failures from a pixel count perspective
    ######################################################################
    from collections import defaultdict
    worst_images = defaultdict(dict)
    class_to_images = defaultdict(dict)
    for classid in range(cfg.DATASET.NUM_CLASSES):
        tbl = {}
        for img_name in image_metrics.keys():
            fp, fn = image_metrics[img_name]
            fp = fp[classid]
            fn = fn[classid]
            tbl[img_name] = fp + fn
        worst = sorted(tbl, key=tbl.get, reverse=True)
        for img_name in worst[:args.dump_topn]:
            fail_pixels = tbl[img_name]
            worst_images[img_name][classid] = fail_pixels
            class_to_images[classid][img_name] = fail_pixels
    msg = str(worst_images)
    logx.msg(msg)

    # write out per-gpu jsons
    # barrier
    # make single table

    ######################################################################
    # 2nd pass
    ######################################################################
    logx.msg('Second pass')
    attn_map = None

    for val_idx, data in enumerate(val_loader):
        in_image, gt_image, img_names, _ = data

        # Only process images that were identified in first pass
        if not args.dump_topn_all and img_names[0] not in worst_images:
            continue

        with torch.no_grad():
            inputs = in_image.cuda()
            inputs = {'images': inputs, 'gts': gt_image}

            if cfg.MODEL.MSCALE:
                output, attn_map = net(inputs)
            else:
                output = net(inputs)

        output = torch.nn.functional.softmax(output, dim=1)
        prob_mask, predictions = output.data.max(1)
        predictions = predictions.cpu()

        # this has shape [bs, h, w]
        img_name = img_names[0]
        for classid in worst_images[img_name].keys():

            err_mask = calc_err_mask(predictions.numpy(), gt_image.numpy(),
                                     cfg.DATASET.NUM_CLASSES, classid)

            class_name = cfg.DATASET_INST.trainid_to_name[classid]
            error_pixels = worst_images[img_name][classid]
            logx.msg(f'{img_name} {class_name}: {error_pixels}')
            img_names = [img_name + f'_{class_name}']

            to_dump = {
                'gt_images': gt_image,
                'input_images': in_image,
                'predictions': predictions.numpy(),
                'err_mask': err_mask,
                'prob_mask': prob_mask,
                'img_names': img_names
            }

            if attn_map is not None:
                to_dump['attn_maps'] = attn_map

            # FIXME!
            # do_dump_images([to_dump])

    html_fn = os.path.join(args.result_dir, 'best_images',
                           'topn_failures.html')
    from utils.results_page import ResultsPage
    ip = ResultsPage('topn failures', html_fn)
    for classid in class_to_images:
        class_name = cfg.DATASET_INST.trainid_to_name[classid]
        img_dict = class_to_images[classid]
        for img_name in sorted(img_dict, key=img_dict.get, reverse=True):
            fail_pixels = class_to_images[classid][img_name]
            img_cls = f'{img_name}_{class_name}'
            pred_fn = f'{img_cls}_prediction.png'
            gt_fn = f'{img_cls}_gt.png'
            inp_fn = f'{img_cls}_input.png'
            err_fn = f'{img_cls}_err_mask.png'
            prob_fn = f'{img_cls}_prob_mask.png'
            img_label_pairs = [(pred_fn, 'pred'), (gt_fn, 'gt'),
                               (inp_fn, 'input'), (err_fn, 'errors'),
                               (prob_fn, 'prob')]
            ip.add_table(img_label_pairs,
                         table_heading=f'{class_name}-{fail_pixels}')
    ip.write_page()

    return val_loss.avg
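
The metrics_per_image helper above turns one image's confusion matrix into per-class false-positive and false-negative pixel counts. A sketch assuming the rows = ground truth, columns = prediction convention of the fast_hist sketch under code example #6 (the orientation is an assumption, since the helper is not shown):

import numpy as np

def metrics_per_image(hist):
    # hist: per-image confusion matrix (rows = ground truth, cols = prediction)
    fp = hist.sum(axis=0) - np.diag(hist)  # predicted as class c, labeled otherwise
    fn = hist.sum(axis=1) - np.diag(hist)  # labeled class c, predicted otherwise
    return fp, fn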
Code example #9
    def train_epoch(self, epoch_num):
        batch_time = AverageMeter()
        losses_edge = AverageMeter()
        losses_corner = AverageMeter()

        self.model.train()

        end = time.time()
        for iter_i, batch_data in enumerate(self.train_loader):
            image_inputs = batch_data['image']
            if self.mode == 'corner':
                corner_target_maps = batch_data['corner_gt_map']
                edge_target_maps = batch_data['edge_gt_map']
                room_masks_map = batch_data['room_masks_map']
            else:
                raise ValueError('Invalid mode {}'.format(self.mode))

            mean_normal = batch_data['mean_normal']

            # contour_image = batch_data['contour_image']

            if self.configs.use_cuda:
                image_inputs = image_inputs.cuda()
                mean_normal = mean_normal.cuda()
                corner_target_maps = corner_target_maps.cuda()
                edge_target_maps = edge_target_maps.cuda()
                room_masks_map = room_masks_map.cuda()  # keep this inside the guard so CPU-only runs work
            inputs = torch.cat([
                image_inputs.unsqueeze(1), mean_normal,
                room_masks_map.unsqueeze(1)
            ],
                               dim=1)

            corner_preds_logits, edge_preds_logits, edge_preds, corner_preds = self.model(
                inputs)

            # Mask the binning part: only predict directions at places with corners
            loss_mask_c = corner_target_maps[:, 0, :, :].clone().unsqueeze(
                1) * 4 + 1
            loss_c = self.criterion(corner_preds_logits, corner_target_maps)
            loss_c = loss_c * loss_mask_c
            loss_c = loss_c.mean(2).mean(2).mean(
                0).sum()  # take mean over batch, H, W, sum over C

            loss_mask_e = edge_target_maps[:, 0, :, :].clone().unsqueeze(
                1) * 4 + 1
            loss_e = self.criterion(edge_preds_logits, edge_target_maps)
            loss_e = loss_e * loss_mask_e
            loss_e = loss_e.mean(2).mean(2).mean(
                0).sum()  # take mean over batch, H, W, sum over C

            loss = loss_e + loss_c

            losses_edge.update(loss_e.item(), image_inputs.size(0))
            losses_corner.update(loss_c.item(), image_inputs.size(0))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Edge pred Loss {loss1.val:.4f} ({loss1.avg:.4f})\t'
                  'Corner pred Loss {loss2.val:.4f} ({loss2.avg:.4f})'.format(
                      epoch_num,
                      iter_i,
                      len(self.train_loader),
                      batch_time=batch_time,
                      loss1=losses_edge,
                      loss2=losses_corner))

            if iter_i % self.configs.visualize_iter == 0:
                viz_dir = os.path.join(self.configs.exp_dir, 'training_viz')
                gt_file_path = os.path.join(
                    viz_dir,
                    'epoch_{}_iter_{}_gt.png'.format(epoch_num, iter_i))
                gt_edge_file_path = os.path.join(
                    viz_dir,
                    'epoch_{}_iter_{}_gt_edge.png'.format(epoch_num, iter_i))
                heatmap_path = os.path.join(
                    viz_dir,
                    'epoch_{}_iter_{}_preds.png'.format(epoch_num, iter_i))
                heatmap_edge_path = os.path.join(
                    viz_dir, 'epoch_{}_iter_{}_preds_edge.png'.format(
                        epoch_num, iter_i))
                # corner_edge_path = os.path.join(viz_dir, 'epoch_{}_iter_{}_corner_edge.png'.format(epoch_num, iter_i))
                gt_map_path = os.path.join(
                    viz_dir, 'epoch_{}_iter_{}_gt_corner_edge.png'.format(
                        epoch_num, iter_i))
                # simply use the first element in the batch
                corner_preds_np = corner_preds[0].detach().cpu().numpy()
                edge_preds_np = edge_preds[0].detach().cpu().numpy()
                edge_gt_np = edge_target_maps[0].cpu().numpy()
                corner_gt_np = corner_target_maps[0].cpu().numpy()
                if self.mode == 'corner':
                    _, gt_corner_edge_map = get_corner_dir_map(
                        corner_gt_np, 256)
                    imsave(gt_map_path, gt_corner_edge_map)
                imsave(gt_file_path, corner_gt_np[0])
                imsave(heatmap_path, corner_preds_np[0])
                imsave(gt_edge_file_path, edge_gt_np[0])
                imsave(heatmap_edge_path, edge_preds_np[0])
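
The `* 4 + 1` factor above converts the first ground-truth channel into a per-pixel loss weight: pixels that contain a corner or edge receive weight 5, everything else weight 1, which counters the heavy foreground/background imbalance in the heatmaps. A minimal illustration (the factor 4 and resulting 5x weight are this project's choice, not a general rule):

import torch

gt_first_channel = torch.tensor([[0., 1.],
                                 [0., 0.]])
loss_mask = gt_first_channel * 4 + 1  # -> [[1., 5.], [1., 1.]]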
Code example #10
File: train.py Project: shachoi/RobustNet
def validate(val_loader,
             dataset,
             net,
             criterion,
             optim,
             scheduler,
             curr_epoch,
             writer,
             curr_iter,
             save_pth=True):
    """
    Runs the validation loop after each training epoch
    val_loader: Data loader for validation
    dataset: dataset name (str)
    net: the network
    criterion: loss fn
    optim: optimizer
    curr_epoch: current epoch
    writer: tensorboard writer
    return: val_avg for step function if required
    """

    net.eval()
    val_loss = AverageMeter()
    iou_acc = 0
    error_acc = 0
    dump_images = []

    for val_idx, data in enumerate(val_loader):
        # input        = torch.Size([1, 3, 713, 713])
        # gt_image           = torch.Size([1, 713, 713])
        inputs, gt_image, img_names, _ = data

        if len(inputs.shape) == 5:
            B, D, C, H, W = inputs.shape
            inputs = inputs.view(-1, C, H, W)
            gt_image = gt_image.view(-1, H, W)  # keep gt 3-D; the asserts below expect (N, H, W)

        assert len(inputs.size()) == 4 and len(gt_image.size()) == 3
        assert inputs.size()[2:] == gt_image.size()[1:]

        batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)
        inputs, gt_cuda = inputs.cuda(), gt_image.cuda()

        with torch.no_grad():
            if args.use_wtloss:
                output, f_cor_arr = net(inputs, visualize=True)
            else:
                output = net(inputs)

        del inputs

        assert output.size()[2:] == gt_image.size()[1:]
        assert output.size()[1] == datasets.num_classes

        val_loss.update(criterion(output, gt_cuda).item(), batch_pixel_size)

        del gt_cuda

        # Collect data from different GPU to a single GPU since
        # encoding.parallel.criterionparallel function calculates distributed loss
        # functions
        predictions = output.data.max(1)[1].cpu()

        # Logging
        if val_idx % 20 == 0:
            if args.local_rank == 0:
                logging.info("validating: %d / %d", val_idx + 1,
                             len(val_loader))
        if val_idx > 10 and args.test_mode:
            break

        # Image Dumps
        if val_idx < 10:
            dump_images.append([gt_image, predictions, img_names])

        iou_acc += fast_hist(predictions.numpy().flatten(),
                             gt_image.numpy().flatten(), datasets.num_classes)
        del output, val_idx, data

    iou_acc_tensor = torch.cuda.FloatTensor(iou_acc)
    torch.distributed.all_reduce(iou_acc_tensor,
                                 op=torch.distributed.ReduceOp.SUM)
    iou_acc = iou_acc_tensor.cpu().numpy()

    if args.local_rank == 0:
        evaluate_eval(args,
                      net,
                      optim,
                      scheduler,
                      val_loss,
                      iou_acc,
                      dump_images,
                      writer,
                      curr_epoch,
                      dataset,
                      None,
                      curr_iter,
                      save_pth=save_pth)

        if args.use_wtloss:
            visualize_matrix(writer, f_cor_arr, curr_iter,
                             '/Covariance/Feature-')

    return val_loss.avg
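Every example in this listing leans on an AverageMeter helper that is never shown. A minimal sketch follows, assuming the canonical implementation from the PyTorch ImageNet example; the name/format arguments (as in AverageMeter('Time', ':6.3f')) are only used for pretty-printing and are otherwise inert.

class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self, name='', fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # n is the number of samples that `val` was averaged over.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count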
Code Example #11
File: main.py Project: ojasgr1706/PyTorch-MNIST
def train():
    try:
        os.makedirs(opt.checkpoints_dir)
    except OSError:
        pass

    CNN.to(device)
    CNN.train()
    torchsummary.summary(CNN, (1, 28, 28))

    ################################################
    # Set loss function and Adam optimizer
    ################################################
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(CNN.parameters(), lr=opt.lr)

    for epoch in range(opt.epochs):
        # train for one epoch
        print(f"\nBegin Training Epoch {epoch + 1}")
        # Calculate and return the top-k accuracy of the model
        # so that we can track the learning process.
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        end = time.time()
        for i, data in enumerate(train_dataloader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, targets = data
            inputs = inputs.to(device)
            targets = targets.to(device)

            # compute output
            output = CNN(inputs)
            loss = criterion(output, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1, inputs.size(0))
            top5.update(prec5, inputs.size(0))

            # compute gradients in a backward pass
            optimizer.zero_grad()
            loss.backward()

            # Call step of optimizer to update model params
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 15 == 0:
                print(
                    f"Epoch [{epoch + 1}] [{i}/{len(train_dataloader)}]\t"
                    f"Loss {loss.item():.4f}\t"
                    f"Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
                    f"Prec@5 {top5.val:.3f} ({top5.avg:.3f})",
                    end="\r")

        # save model file
        torch.save(CNN.state_dict(), MODEL_PATH)
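The accuracy helper used above is likewise not defined in this file. A sketch of the usual top-k accuracy computation from the PyTorch ImageNet example, which is presumably what is intended:

import torch

def accuracy(output, target, topk=(1,)):
    # Returns the top-k accuracy (in percent) for each requested k.
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res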
Code Example #12
File: train.py Project: shachoi/RobustNet
def train(train_loader, net, optim, curr_epoch, writer, scheduler, max_iter):
    """
    Runs the training loop per epoch
    train_loader: Data loader for train
    net: thet network
    optimizer: optimizer
    curr_epoch: current epoch
    writer: tensorboard writer
    return:
    """
    net.train()

    train_total_loss = AverageMeter()
    time_meter = AverageMeter()

    curr_iter = curr_epoch * len(train_loader)

    for i, data in enumerate(train_loader):
        if curr_iter >= max_iter:
            break

        inputs, gts, _, aux_gts = data

        # Multi source and AGG case
        if len(inputs.shape) == 5:
            B, D, C, H, W = inputs.shape
            num_domains = D
            inputs = inputs.transpose(0, 1)
            gts = gts.transpose(0, 1).squeeze(2)
            aux_gts = aux_gts.transpose(0, 1).squeeze(2)

            inputs = [
                input.squeeze(0)
                for input in torch.chunk(inputs, num_domains, 0)
            ]
            gts = [gt.squeeze(0) for gt in torch.chunk(gts, num_domains, 0)]
            aux_gts = [
                aux_gt.squeeze(0)
                for aux_gt in torch.chunk(aux_gts, num_domains, 0)
            ]
        else:
            B, C, H, W = inputs.shape
            num_domains = 1
            inputs = [inputs]
            gts = [gts]
            aux_gts = [aux_gts]

        batch_pixel_size = C * H * W

        for di, ingredients in enumerate(zip(inputs, gts, aux_gts)):
            input, gt, aux_gt = ingredients

            start_ts = time.time()

            img_gt = None
            input, gt = input.cuda(), gt.cuda()

            optim.zero_grad()
            if args.use_isw:
                outputs = net(input,
                              gts=gt,
                              aux_gts=aux_gt,
                              img_gt=img_gt,
                              visualize=args.visualize_feature,
                              apply_wtloss=False
                              if curr_epoch <= args.cov_stat_epoch else True)
            else:
                outputs = net(input,
                              gts=gt,
                              aux_gts=aux_gt,
                              img_gt=img_gt,
                              visualize=args.visualize_feature)
            outputs_index = 0
            main_loss = outputs[outputs_index]
            outputs_index += 1
            aux_loss = outputs[outputs_index]
            outputs_index += 1
            total_loss = main_loss + (0.4 * aux_loss)

            if args.use_wtloss and (not args.use_isw or
                                    (args.use_isw
                                     and curr_epoch > args.cov_stat_epoch)):
                wt_loss = outputs[outputs_index]
                outputs_index += 1
                total_loss = total_loss + (args.wt_reg_weight * wt_loss)
            else:
                wt_loss = 0

            if args.visualize_feature:
                f_cor_arr = outputs[outputs_index]
                outputs_index += 1

            log_total_loss = total_loss.clone().detach_()
            torch.distributed.all_reduce(log_total_loss,
                                         torch.distributed.ReduceOp.SUM)
            log_total_loss = log_total_loss / args.world_size
            train_total_loss.update(log_total_loss.item(), batch_pixel_size)

            total_loss.backward()
            optim.step()

            time_meter.update(time.time() - start_ts)

            del total_loss, log_total_loss

            if args.local_rank == 0:
                if i % 50 == 49:
                    if args.visualize_feature:
                        visualize_matrix(writer, f_cor_arr, curr_iter,
                                         '/Covariance/Feature-')

                    msg = '[epoch {}], [iter {} / {} : {}], [loss {:0.6f}], [lr {:0.6f}], [time {:0.4f}]'.format(
                        curr_epoch, i + 1, len(train_loader), curr_iter,
                        train_total_loss.avg, optim.param_groups[-1]['lr'],
                        time_meter.avg / args.train_batch_size)

                    logging.info(msg)
                    if args.use_wtloss:
                        print("Whitening Loss", wt_loss)

                    # Log tensorboard metrics for each iteration of the training phase
                    writer.add_scalar('loss/train_loss',
                                      (train_total_loss.avg), curr_iter)
                    train_total_loss.reset()
                    time_meter.reset()

        curr_iter += 1
        scheduler.step()

        if i > 5 and args.test_mode:
            return curr_iter

    return curr_iter
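The log_total_loss block above shows the standard pattern for logging under DistributedDataParallel: each rank's loss is summed with all_reduce and divided by the world size so every process reports the same number. A standalone sketch of that pattern (passing world_size explicitly is an assumption):

import torch.distributed as dist

def reduce_mean(tensor, world_size):
    # Average a scalar tensor across all processes, for logging only;
    # clone().detach() keeps the reduction out of the autograd graph.
    reduced = tensor.clone().detach()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    return reduced / world_size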
Code Example #13
def test(args, model, video_val=None):
    reward_avg = AverageMeter()
    loss_avg = AverageMeter()
    value_loss_avg = AverageMeter()
    policy_loss_avg = AverageMeter()

    root_dir = '/home/youngfly/DL_project/RL_Tracking/dataset/VOT'
    data_type = 'VOT'

    model.eval()
    env = Env(seqs_path=root_dir,
              data_set_type=data_type,
              save_path='/dataset/Result/VOT')

    for video_name in video_val:

        actions = []
        rewards = []
        values = []
        entropies = []
        logprobs = []

        # reset for new video
        observation1, observation2 = env.reset(video_name)
        img1 = ReadSingleImage(observation2)
        img1 = Variable(img1).cuda()

        hidden_prev = model.init_hidden_state(
            batch_size=1)  # variable cuda tensor
        _, _, _, _, hidden_pres = model(imgs=img1, hidden_prev=hidden_prev)

        # for loop init parameter
        hidden_prev = hidden_pres
        observation = observation2
        FLAG = 1
        i = 2
        while FLAG:
            img = ReadSingleImage(observation)
            img = Variable(img).cuda()

            action_prob, action_logprob, action_sample, value, hidden_pres = model(
                imgs=img, hidden_prev=hidden_prev)

            entropy = -(action_logprob * action_prob).sum(1, keepdim=True)
            entropies.append(entropy)

            actions.append(action_sample.long())  # list, Variable cuda inner
            action_np = action_sample.data.cpu().numpy()

            sample = Variable(torch.LongTensor(action_np).cuda()).unsqueeze(0)

            hidden_prev = hidden_pres
            logprob = action_logprob.gather(1, sample)
            logprobs.append(logprob)

            reward, new_observation, done = env.step(action=action_np)
            env.show_all()
            # env.show_tracking_result()

            print(
                'test:', 'frame:%d' % (i), 'Action:%d' % action_np[0],
                'rewards:%.6f' % reward, 'probability:%.6f, %.6f' %
                (action_prob.data.cpu().numpy()[0, 0],
                 action_prob.data.cpu().numpy()[0, 1]))
            i = i + 1
            rewards.append(reward)  # just list
            values.append(value)  # list, Variable cuda inner
            observation = new_observation

            if done:
                FLAG = 0

        num_seqs = len(rewards)
        running_add = Variable(torch.FloatTensor([0])).cuda()
        value_loss = 0
        policy_loss = 0
        gae = torch.FloatTensor([0]).cuda()
        values.append(running_add)
        for i in reversed(range(len(rewards))):
            # if rewards[i] < 0.2:
            #     rewards[i] = rewards[i] ** 2
            running_add = args.gamma * running_add + rewards[i]
            advantage = running_add - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            delta_t = rewards[i] + args.gamma * values[i +
                                                       1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = policy_loss - logprobs[i] * Variable(
                gae) - args.entropy_coef * entropies[i]

        # value_loss = value_loss / num_seqs
        # policy_loss = policy_loss / num_seqs
        #
        # values.append(running_add)
        # for i in reversed(range(len(rewards))):
        #     running_add = args.gamma * running_add + rewards[i]
        #     advantage = running_add - values[i]
        #     value_loss = value_loss + 0.5 * advantage.pow(2)
        #     policy_loss = policy_loss - logprobs[i] * advantage - args.entropy_coef * entropies[i]
        #
        value_loss = value_loss / num_seqs
        policy_loss = policy_loss / num_seqs

        loss = args.value_loss_coef * value_loss + policy_loss

        print(video_name, 'rewards:%.6f' % np.mean(rewards),
              'loss:%.6f' % loss.item(),
              'value_loss:%.6f' % value_loss.item(),
              'policy_loss:%.6f' % policy_loss.item())

        # update the loss
        loss_avg.update(loss.data.cpu().numpy())
        value_loss_avg.update(value_loss.data.cpu().numpy())
        policy_loss_avg.update(policy_loss.data.cpu().numpy())
        reward_avg.update(np.mean(rewards))

    return reward_avg.avg, loss_avg.avg, value_loss_avg.avg, policy_loss_avg.avg
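The backward loop over rewards above combines discounted returns with Generalized Advantage Estimation: delta_t = r_t + gamma * V(t+1) - V(t), and gae_t = gamma * tau * gae_{t+1} + delta_t. A compact sketch of just that recursion, using plain Python floats for clarity:

def compute_gae(rewards, values, gamma=0.99, tau=1.0):
    # `values` must hold one more entry than `rewards`
    # (the bootstrap value appended at the end, as above).
    gae = 0.0
    advantages = []
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = gamma * tau * gae + delta
        advantages.append(gae)
    advantages.reverse()
    return advantages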
Code Example #14
def test(test_loader, model, configs):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    iou_seg = AverageMeter('IoU_Seg', ':6.4f')
    mse_global = AverageMeter('MSE_Global', ':6.4f')
    mse_local = AverageMeter('MSE_Local', ':6.4f')
    mse_overall = AverageMeter('MSE_Overall', ':6.4f')
    pce = AverageMeter('PCE', ':6.4f')
    spce = AverageMeter('Smooth_PCE', ':6.4f')
    w_original = 1920.
    h_original = 1080.
    w, h = configs.input_size

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        for batch_idx, (resized_imgs, org_ball_pos_xy, global_ball_pos_xy,
                        target_events,
                        target_seg) in enumerate(tqdm(test_loader)):

            print(
                '\n===================== batch_idx: {} ================================'
                .format(batch_idx))

            data_time.update(time.time() - start_time)
            batch_size = resized_imgs.size(0)
            target_seg = target_seg.to(configs.device, non_blocking=True)
            resized_imgs = resized_imgs.to(configs.device,
                                           non_blocking=True).float()
            # compute output

            pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
                resized_imgs, org_ball_pos_xy, global_ball_pos_xy,
                target_events, target_seg)

            org_ball_pos_xy = org_ball_pos_xy.numpy()
            global_ball_pos_xy = global_ball_pos_xy.numpy()
            # Transfer output to cpu
            target_seg = target_seg.cpu().numpy()

            for sample_idx in range(batch_size):
                # Get target
                sample_org_ball_pos_xy = org_ball_pos_xy[sample_idx]
                sample_global_ball_pos_xy = global_ball_pos_xy[
                    sample_idx]  # Target
                # Process the global stage
                sample_pred_ball_global = pred_ball_global[sample_idx]
                sample_prediction_ball_global_xy = get_prediction_ball_pos(
                    sample_pred_ball_global, w, configs.thresh_ball_pos_mask)

                # Calculate the MSE
                if (sample_global_ball_pos_xy[0] >
                        0) and (sample_global_ball_pos_xy[1] > 0) and (
                            sample_prediction_ball_global_xy[0] >
                            0) and (sample_prediction_ball_global_xy[1] > 0):
                    mse = (sample_prediction_ball_global_xy[0] - sample_global_ball_pos_xy[0]) ** 2 + \
                          (sample_prediction_ball_global_xy[1] - sample_global_ball_pos_xy[1]) ** 2
                    mse_global.update(mse)

                print(
                    '\nBall Detection - \t Global stage: \t (x, y) - gt = ({}, {}), prediction = ({}, {})'
                    .format(sample_global_ball_pos_xy[0],
                            sample_global_ball_pos_xy[1],
                            sample_prediction_ball_global_xy[0],
                            sample_prediction_ball_global_xy[1]))

                sample_pred_org_x = sample_prediction_ball_global_xy[0] * (
                    w_original / w)
                sample_pred_org_y = sample_prediction_ball_global_xy[1] * (
                    h_original / h)

                # Process local ball stage
                if pred_ball_local is not None:
                    # Get target: convert the local-stage ground truth to numpy
                    # once (on later samples it is already an ndarray)
                    if not isinstance(local_ball_pos_xy, np.ndarray):
                        local_ball_pos_xy = local_ball_pos_xy.cpu().numpy()
                    sample_local_ball_pos_xy = local_ball_pos_xy[
                        sample_idx]  # Target
                    # Process the local stage
                    sample_pred_ball_local = pred_ball_local[sample_idx]
                    sample_prediction_ball_local_xy = get_prediction_ball_pos(
                        sample_pred_ball_local, w,
                        configs.thresh_ball_pos_mask)

                    # Calculate the MSE
                    if (sample_local_ball_pos_xy[0] >
                            0) and (sample_local_ball_pos_xy[1] > 0):
                        mse = (sample_prediction_ball_local_xy[0] -
                               sample_local_ball_pos_xy[0])**2 + (
                                   sample_prediction_ball_local_xy[1] -
                                   sample_local_ball_pos_xy[1])**2
                        mse_local.update(mse)
                        sample_pred_org_x += sample_prediction_ball_local_xy[
                            0] - w / 2
                        sample_pred_org_y += sample_prediction_ball_local_xy[
                            1] - h / 2

                    print(
                        'Ball Detection - \t Local stage: \t (x, y) - gt = ({}, {}), prediction = ({}, {})'
                        .format(sample_local_ball_pos_xy[0],
                                sample_local_ball_pos_xy[1],
                                sample_prediction_ball_local_xy[0],
                                sample_prediction_ball_local_xy[1]))

                print(
                    'Ball Detection - \t Overall: \t (x, y) - org: ({}, {}), prediction = ({}, {})'
                    .format(sample_org_ball_pos_xy[0],
                            sample_org_ball_pos_xy[1], int(sample_pred_org_x),
                            int(sample_pred_org_y)))
                mse = (sample_org_ball_pos_xy[0] - sample_pred_org_x)**2 + (
                    sample_org_ball_pos_xy[1] - sample_pred_org_y)**2
                mse_overall.update(mse)

                # Process event stage
                if pred_events is not None:
                    sample_target_events = target_events[sample_idx].numpy()
                    sample_prediction_events = prediction_get_events(
                        pred_events[sample_idx], configs.event_thresh)
                    print(
                        'Event Spotting - \t gt = (is bounce: {}, is net: {}), prediction: (is bounce: {:.4f}, is net: {:.4f})'
                        .format(sample_target_events[0],
                                sample_target_events[1],
                                pred_events[sample_idx][0],
                                pred_events[sample_idx][1]))
                    # Compute metrics
                    spce.update(
                        SPCE(sample_prediction_events,
                             sample_target_events,
                             thresh=0.5))
                    pce.update(
                        PCE(sample_prediction_events, sample_target_events))

                # Process segmentation stage
                if pred_seg is not None:
                    sample_target_seg = target_seg[sample_idx].transpose(
                        1, 2, 0).astype(np.int32)  # np.int was removed in NumPy 1.20+
                    sample_prediction_seg = get_prediction_seg(
                        pred_seg[sample_idx], configs.seg_thresh)

                    # Calculate the IoU
                    iou = 2 * np.sum(
                        sample_target_seg * sample_prediction_seg) / (
                            np.sum(sample_target_seg) +
                            np.sum(sample_prediction_seg) + 1e-9)
                    iou_seg.update(iou)

                    print('Segmentation - \t \t IoU = {:.4f}'.format(iou))

                    if configs.save_test_output:
                        fig, axes = plt.subplots(nrows=batch_size,
                                                 ncols=2,
                                                 figsize=(10, 5))
                        plt.tight_layout()
                        axes = axes.ravel()  # assign the result: ravel() is not in-place
                        axes[2 * sample_idx].imshow(sample_target_seg * 255)
                        axes[2 * sample_idx + 1].imshow(sample_prediction_seg *
                                                        255)
                        # title
                        target_title = 'target seg'
                        pred_title = 'pred seg'
                        if pred_events is not None:
                            target_title += ', is bounce: {}, is net: {}'.format(
                                sample_target_events[0],
                                sample_target_events[1])
                            pred_title += ', is bounce: {}, is net: {}'.format(
                                sample_prediction_events[0],
                                sample_prediction_events[1])

                        axes[2 * sample_idx].set_title(target_title)
                        axes[2 * sample_idx + 1].set_title(pred_title)

                        plt.savefig(
                            os.path.join(
                                configs.saved_dir,
                                'batch_idx_{}_sample_idx_{}.jpg'.format(
                                    batch_idx, sample_idx)))
                        plt.close(fig)  # one figure is created per sample; free it

            if ((batch_idx + 1) % configs.print_freq) == 0:
                print(
                    'batch_idx: {} - Average iou_seg: {:.4f}, mse_global: {:.1f}, mse_local: {:.1f}, mse_overall: {:.1f}, pce: {:.4f} spce: {:.4f}'
                    .format(batch_idx, iou_seg.avg, mse_global.avg,
                            mse_local.avg, mse_overall.avg, pce.avg, spce.avg))

            batch_time.update(time.time() - start_time)
            start_time = time.time()

    print(
        'Average iou_seg: {:.4f}, mse_global: {:.1f}, mse_local: {:.1f}, mse_overall: {:.1f}, pce: {:.4f} spce: {:.4f}'
        .format(iou_seg.avg, mse_global.avg, mse_local.avg, mse_overall.avg,
                pce.avg, spce.avg))
    print('Done testing')
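Note that the segmentation score computed above, 2 * intersection / (|target| + |prediction|), is the Dice coefficient rather than a strict IoU, even though the meter is named IoU_Seg. A sketch of both metrics, assuming binary numpy masks:

import numpy as np

def dice_score(pred, target, eps=1e-9):
    inter = np.sum(pred * target)
    return 2.0 * inter / (np.sum(pred) + np.sum(target) + eps)

def iou_score(pred, target, eps=1e-9):
    inter = np.sum(pred * target)
    union = np.sum(pred) + np.sum(target) - inter
    return inter / (union + eps)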
Code Example #15
File: evaluate.py Project: wangx1996/CenterPillarNet
def evaluate_mAP(val_loader, model, configs, logger):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')

    progress = ProgressMeter(len(val_loader), [batch_time, data_time],
                             prefix="Evaluation phase...")
    labels = []
    sample_metrics = []  # List of tuples (TP, confs, pred)
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        for batch_idx, batch_data in enumerate(tqdm(val_loader)):
            metadatas, targets = batch_data

            batch_size = len(metadatas['img_path'])

            voxelinput = metadatas['voxels']
            coorinput = metadatas['coors']
            numinput = metadatas['num_points']

            voxelinputr = torch.tensor(voxelinput,
                                       dtype=torch.float32,
                                       device=configs.device)

            coorinputr = torch.tensor(coorinput,
                                      dtype=torch.int32,
                                      device=configs.device)

            numinputr = torch.tensor(numinput,
                                     dtype=torch.int32,
                                     device=configs.device)
            t1 = time_synchronized()
            outputs = model(voxelinputr, coorinputr, numinputr)
            outputs = outputs._asdict()

            outputs['hm_cen'] = _sigmoid(outputs['hm_cen'])
            outputs['cen_offset'] = _sigmoid(outputs['cen_offset'])
            # detections size (batch_size, K, 10)

            detections = decode(outputs['hm_cen'],
                                outputs['cen_offset'],
                                outputs['direction'],
                                outputs['z_coor'],
                                outputs['dim'],
                                K=configs.K)
            detections = detections.cpu().numpy().astype(np.float32)
            detections = post_processingv2(detections, configs.num_classes,
                                           configs.down_ratio,
                                           configs.peak_thresh)

            for sample_i in range(len(detections)):
                # print(output.shape)
                num = targets['count'][sample_i]
                # print(targets['batch'][sample_i][:num].shape)
                target = targets['batch'][sample_i][:num]
                #print(target[:, 8].tolist())
                labels += target[:, 8].tolist()

            sample_metrics += get_batch_statistics_rotated_bbox(
                detections, targets, iou_threshold=configs.iou_thresh)

            t2 = time_synchronized()

            # measure elapsed time
            # torch.cuda.synchronize()
            batch_time.update(time.time() - start_time)

            # Log message
            if logger is not None:
                if ((batch_idx + 1) % configs.print_freq) == 0:
                    logger.info(progress.get_message(batch_idx))

            start_time = time.time()

        # Concatenate sample statistics
        true_positives, pred_scores, pred_labels = [
            np.concatenate(x, 0) for x in list(zip(*sample_metrics))
        ]
        precision, recall, AP, f1, ap_class = ap_per_class(
            true_positives, pred_scores, pred_labels, labels)

    return precision, recall, AP, f1, ap_class
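The _sigmoid applied to the heatmap outputs above is typically the clamped sigmoid used by CenterNet-style detectors, which keeps the logarithms in the focal loss finite; a sketch under that assumption:

import torch

def _sigmoid(x):
    # Clamp away from exactly 0 and 1 so log() in the focal loss stays finite.
    return torch.clamp(torch.sigmoid(x), min=1e-4, max=1.0 - 1e-4)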
Code Example #16
def evaluate(val_loader, net):
    '''
    Runs the evaluation loop and prints the F score
    val_loader: data loader for validation
    net: the network
    return: dict with mean IoU, pixel accuracy, class accuracy, and fwavacc
    '''
    net.eval()
    # 0.0005   13.0 it/sec
    # 0.001875 4.80 it/sec
    # 0.00375  1.70 it/sec
    # 0.005    1.03 it/sec
    thresh = 0.0001

    mf_score1 = AverageMeter()
    mf_pc_score1 = AverageMeter()
    ap_score1 = AverageMeter()
    ap_pc_score1 = AverageMeter()
    IOU_acc = 0
    Fpc = np.zeros((args.dataset_cls.num_classes))
    Fc = np.zeros((args.dataset_cls.num_classes))
    for vi, data in enumerate(val_loader):
        input, mask, edge, img_names = data
        assert len(input.size()) == 4 and len(mask.size()) == 3
        assert input.size()[2:] == mask.size()[1:]
        h, w = mask.size()[1:]

        batch_pixel_size = input.size(0) * input.size(2) * input.size(3)
        input, mask_cuda, edge_cuda = input.cuda(), mask.cuda(), edge.cuda()

        with torch.no_grad():
            seg_out, edge_out = net(input)

        seg_predictions = seg_out.data.max(1)[1].cpu()
        edge_predictions = edge_out.max(1)[0].cpu()

        logging.info('evaluating: %d / %d' % (vi + 1, len(val_loader)))
        '''
        _Fpc, _Fc = eval_mask_boundary(seg_predictions.numpy(), mask.numpy(), args.dataset_cls.num_classes, bound_th=float(thresh))
        Fc += _Fc
        Fpc += _Fpc
        logging.info('F_Score: ' + str(np.sum(Fpc/Fc)/args.dataset_cls.num_classes))
        '''

        IOU_acc += fast_hist(seg_predictions.numpy().flatten(),
                             mask.numpy().flatten(),
                             args.dataset_cls.num_classes)

        del seg_out, edge_out, vi, data

    acc = np.diag(IOU_acc).sum() / IOU_acc.sum()
    acc_cls = np.diag(IOU_acc) / IOU_acc.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(IOU_acc) / (IOU_acc.sum(axis=1) + IOU_acc.sum(axis=0) -
                             np.diag(IOU_acc))
    freq = IOU_acc.sum(axis=1) / IOU_acc.sum()
    mean_iu = np.nanmean(iu)
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()

    #logging.info('F_Score: ' + str(np.sum(Fpc/Fc)/args.dataset_cls.num_classes))
    #logging.info('F_Score (Classwise): ' + str(Fpc/Fc))
    results = {
        "mean_iu": mean_iu,
        "acc": acc,
        "acc_cls": acc_cls,
        "fwavacc": fwavacc
    }

    return results
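Both this evaluate loop and the validate loop in Code Example #10 accumulate a confusion matrix via fast_hist, from which all of the metrics above (pixel accuracy, per-class accuracy, IoU, fwavacc) are derived. The standard implementation bins (gt, prediction) pairs with np.bincount; a sketch, assuming labels outside [0, n) mark ignored pixels:

import numpy as np

def fast_hist(pred, gt, n):
    # n x n confusion matrix: rows are ground truth, columns are predictions.
    k = (gt >= 0) & (gt < n)
    return np.bincount(n * gt[k].astype(int) + pred[k],
                       minlength=n ** 2).reshape(n, n)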
Code Example #17
def train(train_loader, net, optim, curr_epoch, writer):
    """
    Runs the training loop per epoch
    train_loader: Data loader for train
    net: thet network
    optimizer: optimizer
    curr_epoch: current epoch
    writer: tensorboard writer
    return:
    """
    net.train()

    train_main_loss = AverageMeter()
    curr_iter = curr_epoch * len(train_loader)

    for i, data in enumerate(train_loader):
        # inputs = (2,3,713,713)
        # gts    = (2,713,713)
        inputs, gts, _img_name = data

        batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)

        inputs, gts = inputs.cuda(), gts.cuda()

        optim.zero_grad()

        main_loss = net(inputs, gts=gts)

        if args.apex and not args.local_computer:
            log_main_loss = main_loss.clone().detach_()
            torch.distributed.all_reduce(log_main_loss,
                                         torch.distributed.ReduceOp.SUM)
            log_main_loss = log_main_loss / args.world_size
        else:
            main_loss = main_loss.mean()
            log_main_loss = main_loss.clone().detach_()

        train_main_loss.update(log_main_loss.item(), batch_pixel_size)
        if args.fp16:
            with amp.scale_loss(main_loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            main_loss.backward()

        optim.step()

        curr_iter += 1

        if args.local_rank == 0:
            msg = '[epoch {}], [iter {} / {}], [train main loss {:0.6f}], [lr {:0.6f}]'.format(
                curr_epoch, i + 1, len(train_loader), train_main_loss.avg,
                optim.param_groups[-1]['lr'])

            logging.info(msg)

            # Log tensorboard metrics for each iteration of the training phase
            writer.add_scalar('training/loss', (train_main_loss.val),
                              curr_iter)
            writer.add_scalar('training/lr', optim.param_groups[-1]['lr'],
                              curr_iter)

        if i > 5 and args.test_mode:
            return
Code Example #18
def main():
    print(args)

    os.makedirs(args.out, exist_ok=True)
    args.writer = SummaryWriter(args.out)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # create network
    #model = Res_Deeplab(num_classes=args.num_classes)
    model = DeepV3PlusW38(num_classes=args.num_classes)

    # load pretrained parameters
    saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    # init D
    model_D = s4GAN_discriminator(num_classes=args.num_classes,
                                  dataset=args.dataset)

    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))

    model_D = torch.nn.DataParallel(model_D).cuda()
    cudnn.benchmark = True

    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    if args.dataset == 'pascal_voc':
        train_dataset = VOCDataSet(args.data_dir,
                                   args.data_list,
                                   crop_size=input_size,
                                   scale=args.random_scale,
                                   mirror=args.random_mirror,
                                   mean=IMG_MEAN)
        #train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
        #scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    elif args.dataset == 'pascal_context':
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.406, .456, .485], [.229, .224, .225])
        ])
        data_kwargs = {
            'transform': input_transform,
            'base_size': 505,
            'crop_size': 321
        }
        #train_dataset = get_segmentation_dataset('pcontext', split='train', mode='train', **data_kwargs)
        data_loader = get_loader('pascal_context')
        data_path = get_data_path('pascal_context')
        train_dataset = data_loader(data_path,
                                    split='train',
                                    mode='train',
                                    **data_kwargs)
        #train_gt_dataset = data_loader(data_path, split='train', mode='train', **data_kwargs)

    elif args.dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        data_aug = Compose([
            RandomCrop_city((input_size[0], input_size[1])),
            RandomHorizontallyFlip()
        ])
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    img_size=(input_size[0], input_size[1]),
                                    augmentations=data_aug)
        #train_gt_dataset = data_loader( data_path, is_transform=True, augmentations=data_aug)

    elif args.dataset == 'ade20k':
        train_dataset = ADE20K(mode='train', crop_size=input_size)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    if args.labeled_ratio is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=4,
                                      pin_memory=True,
                                      drop_last=True)

        trainloader_gt = data.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)

        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size_unlab,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True,
                                             drop_last=True)
        trainloader_remain_iter = iter(trainloader_remain)

    else:
        partial_size = int(args.labeled_ratio * train_dataset_size)
        print('labeled data: ', partial_size)
        print('unlabeled data: ', train_dataset_size - partial_size)

        if args.split_id is not None:
            train_ids = pickle.load(open(args.split_id, 'rb'))
            print('loading train ids from {}'.format(args.split_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        pickle.dump(
            train_ids,
            open(os.path.join(args.checkpoint_dir, 'train_voc_split.pkl'),
                 'wb'))

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=4,
                                      pin_memory=True,
                                      drop_last=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size_unlab,
                                             sampler=train_remain_sampler,
                                             num_workers=4,
                                             pin_memory=True,
                                             drop_last=True)
        trainloader_gt = data.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)

        trainloader_remain_iter = iter(trainloader_remain)

        print('train dataloader created!')

    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    if args.dataset == 'pascal_voc':
        valloader = data.DataLoader(VOCDataSet(args.data_dir,
                                               args.data_list,
                                               crop_size=(505, 505),
                                               mean=IMG_MEAN,
                                               scale=False,
                                               mirror=False),
                                    batch_size=1,
                                    shuffle=False,
                                    pin_memory=True)
        interp_val = nn.Upsample(size=(505, 505),
                                 mode='bilinear',
                                 align_corners=True)
    elif args.dataset == 'cityscapes':
        val_dataset = data_loader(data_path,
                                  img_size=(512, 1024),
                                  is_transform=True,
                                  split='val')
        valloader = data.DataLoader(val_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    pin_memory=True)
        interp_val = nn.Upsample(size=(512, 1024),
                                 mode='bilinear',
                                 align_corners=True)
    elif args.dataset == 'ade20k':
        val_dataset = ADE20K(mode='val', crop_size=(505, 505))
        valloader = data.DataLoader(val_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=4,
                                    pin_memory=True,
                                    drop_last=True)
        interp_val = nn.Upsample(size=(505, 505),
                                 mode='bilinear',
                                 align_corners=True)
    print('val dataloader created!')

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer,
                                  T_max=args.num_steps,
                                  eta_min=args.learning_rate /
                                  args.eta_min_factor)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    scheduler_D = CosineAnnealingLR(optimizer_D,
                                    T_max=args.num_steps,
                                    eta_min=args.learning_rate_D /
                                    args.eta_min_factor)
    optimizer_D.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    y_real_, y_fake_ = Variable(torch.ones(args.batch_size,
                                           1).cuda()), Variable(
                                               torch.zeros(args.batch_size,
                                                           1).cuda())

    losses_ce = AverageMeter()
    losses_st = AverageMeter()
    losses_S = AverageMeter()
    losses_D = AverageMeter()
    losses_fm = AverageMeter()
    counts = AverageMeter()

    for i_iter in range(args.num_steps):

        model.train()

        loss_ce_value = 0
        loss_D_value = 0
        loss_fm_value = 0
        loss_S_value = 0

        #args.threshold_st = adjust_threshold_st(i_iter)

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train Segmentation Network
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # training loss for labeled data only
        try:
            batch = next(trainloader_iter)
        except StopIteration:
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        images, labels, _, _, _ = batch
        images = images.cuda()
        pred = interp(model(images))
        loss_ce = loss_calc(pred, labels,
                            args.gpu)  # Cross entropy loss for labeled data

        #training loss for remaining unlabeled data
        try:
            batch_remain = next(trainloader_remain_iter)
        except StopIteration:
            trainloader_remain_iter = iter(trainloader_remain)
            batch_remain = next(trainloader_remain_iter)

        images_remain, _, _, _, _ = batch_remain
        images_remain = Variable(images_remain).cuda(args.gpu)
        pred_remain = interp(model(images_remain))

        # concatenate the prediction with the input images
        images_remain = (images_remain - torch.min(images_remain)) / (
            torch.max(images_remain) - torch.min(images_remain))
        #print (pred_remain.size(), images_remain.size())
        pred_cat = torch.cat((F.softmax(pred_remain, dim=1), images_remain),
                             dim=1)

        D_out_z, D_out_y_pred = model_D(
            pred_cat)  # predicts the D output (0-1) and the feature map for the FM loss

        # find predicted segmentation maps above threshold
        pred_sel, labels_sel, count = find_good_maps(D_out_z, pred_remain)

        # training loss on above threshold segmentation predictions (Cross Entropy Loss)
        if count > 0 and i_iter > 0:
            loss_st = loss_calc(pred_sel, labels_sel, args.gpu)
            losses_st.update(loss_st.item())
        else:
            loss_st = 0.0

        # Concatenate the input images and ground-truth maps for the discriminator's 'real' input
        try:
            batch_gt = next(trainloader_gt_iter)
        except StopIteration:
            trainloader_gt_iter = iter(trainloader_gt)
            batch_gt = next(trainloader_gt_iter)

        images_gt, labels_gt, _, _, _ = batch_gt
        # Convert the ground-truth segmentation into 'num_classes' one-hot maps.
        D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)

        images_gt = images_gt.cuda()
        images_gt = (images_gt - torch.min(images_gt)) / (
            torch.max(images_gt) - torch.min(images_gt))  # normalize by images_gt, not images

        D_gt_v_cat = torch.cat((D_gt_v, images_gt), dim=1)
        D_out_z_gt, D_out_y_gt = model_D(D_gt_v_cat)

        # L1 loss for Feature Matching Loss
        loss_fm = torch.mean(
            torch.abs(torch.mean(D_out_y_gt, 0) - torch.mean(D_out_y_pred, 0)))

        if count > 0 and i_iter > 0:  # if any good predictions found for self-training loss
            loss_S = loss_ce + args.lambda_fm * loss_fm + args.lambda_st * loss_st
        else:
            loss_S = loss_ce + args.lambda_fm * loss_fm

        loss_S.backward()
        loss_fm_value += (args.lambda_fm * loss_fm).item()

        loss_ce_value += loss_ce.item()
        loss_S_value += loss_S.item()

        # train D
        for param in model_D.parameters():
            param.requires_grad = True

        # train with pred
        pred_cat = pred_cat.detach(
        )  # detach stops gradients from flowing back into the generator

        D_out_z, _ = model_D(pred_cat)
        y_fake_ = Variable(torch.zeros(D_out_z.size(0), 1).cuda())
        loss_D_fake = criterion(D_out_z, y_fake_)

        # train with gt
        D_out_z_gt, _ = model_D(D_gt_v_cat)
        y_real_ = Variable(torch.ones(D_out_z_gt.size(0), 1).cuda())
        loss_D_real = criterion(D_out_z_gt, y_real_)

        loss_D = (loss_D_fake + loss_D_real) / 2.0
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        #scheduler.step()
        optimizer_D.step()
        #scheduler_D.step()

        losses_ce.update(loss_ce.item())
        losses_S.update(loss_S.item())
        losses_D.update(loss_D.item())
        losses_fm.update(loss_fm.item())
        counts.update(count)

        if i_iter % 10 == 0:
            log_idx = i_iter // 10

            args.writer.add_scalar('train/1.train_loss_ce', losses_ce.avg,
                                   log_idx)
            args.writer.add_scalar('train/2.train_loss_st', losses_st.avg,
                                   log_idx)
            args.writer.add_scalar('train/3.train_loss_fm', losses_fm.avg,
                                   log_idx)
            args.writer.add_scalar('train/4.train_loss_S', losses_S.avg,
                                   log_idx)
            args.writer.add_scalar('train/5.train_loss_D', losses_D.avg,
                                   log_idx)
            args.writer.add_scalar('train/6.count', counts.avg, log_idx)
            args.writer.add_scalar('train/7.lr',
                                   optimizer.param_groups[0]['lr'], log_idx)

            losses_ce = AverageMeter()
            losses_st = AverageMeter()
            losses_S = AverageMeter()
            losses_D = AverageMeter()
            losses_fm = AverageMeter()
            counts = AverageMeter()

            print(
                'iter = {0:8d}/{1:8d}, loss_ce = {2:.3f}, loss_fm = {3:.3f}, loss_S = {4:.3f}, loss_D = {5:.3f}'
                .format(i_iter, args.num_steps, loss_ce_value, loss_fm_value,
                        loss_S_value, loss_D_value))

        if i_iter % 200 == 0:
            miou_val, loss_val = validate(valloader, interp_val, model)
            print('miou_val: ', miou_val, ' loss_val: ', loss_val)
            #mious.update(miou_val)
            #losses_val.update(loss_val)
            args.writer.add_scalar('val/1.val_miou', miou_val, i_iter / 1000)
            args.writer.add_scalar('val/2.val_loss', loss_val, i_iter / 1000)
            #mious = AverageMeter()
            #losses_val = AverageMeter()

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('saving checkpoint  ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(i_iter) + '_D.pth'))
    end = timeit.default_timer()
    print(end - start, 'seconds')
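The discriminator's 'real' input above concatenates the image with a one-hot encoding of the ground-truth map, but one_hot itself is not shown. A minimal sketch, where the num_classes default and the ignore_index value of 255 are assumptions:

import torch
import torch.nn.functional as F

def one_hot(labels, num_classes=21, ignore_index=255):
    # labels: (B, H, W) integer map -> (B, num_classes, H, W) float tensor.
    labels = labels.clone().long()
    labels[labels == ignore_index] = num_classes  # park ignored pixels in an extra slot
    onehot = F.one_hot(labels, num_classes + 1)   # (B, H, W, num_classes + 1)
    return onehot[..., :num_classes].permute(0, 3, 1, 2).float()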
Code Example #19
File: mnist.py Project: Lornatang/ClassifierGAN
def train(train_dataloader, model, criterion, optimizer, epoch):
  batch_time = AverageMeter()
  data_time = AverageMeter()
  losses = AverageMeter()
  top1 = AverageMeter()
  top5 = AverageMeter()

  # switch to train mode
  model.train()

  end = time.time()
  for i, data in enumerate(train_dataloader):

    # measure data loading time
    data_time.update(time.time() - end)

    # get the inputs; data is a list of [inputs, labels]
    inputs, targets = data
    inputs = inputs.to(device)
    targets = targets.to(device)

    # compute output
    output = model(inputs)
    loss = criterion(output, targets)

    # measure accuracy and record loss
    prec1, prec5 = accuracy(output, targets, topk=(1, 5))
    losses.update(loss.item(), inputs.size(0))
    top1.update(prec1, inputs.size(0))
    top5.update(prec5, inputs.size(0))

    # compute gradients in a backward pass
    optimizer.zero_grad()
    loss.backward()

    # Call step of optimizer to update model params
    optimizer.step()

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if i % 5 == 0:
      print(f"Epoch [{epoch + 1}] [{i}/{len(train_dataloader)}]\t"
            f"Time {data_time.val:.3f} ({data_time.avg:.3f})\t"
            f"Loss {loss.item():.4f}\t"
            f"Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
            f"Prec@5 {top5.val:.3f} ({top5.avg:.3f})", end="\r")
  torch.save(model.state_dict(), f"./checkpoints/{opt.datasets}_epoch_{epoch + 1}.pth")
Code Example #20
def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    print(save_folder.split('/')[-1])
    skip_clustering = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    state_dict = torch.load(startnet)
    if 'resnet101' in startnet:
        load_resnet101_weights(net, state_dict)
    else:
        # needed since we slightly changed the structure of the network in pspnet
        state_dict = rename_keys_to_match(state_dict)
        # different amount of classes
        init_last_layers(state_dict, args['n_clusters'])

        net.load_state_dict(state_dict)  # load original weights

    start_iter = 0
    args['best_record'] = {
        'iter': 0,
        'val_loss_feat': 1e10,
        'val_loss_out': 1e10,
        'val_loss_cluster': 1e10
    }

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'pola':
        corr_set_config = data_configs.PolaConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    ref_image_lists = corr_set_config.reference_image_list

    # ref_image_lists = glob.glob("/media/HDD1/datasets/Creusot_Jan15/Creusot_3/*.jpg", recursive=True)
    # print(f'here we print the reference image list ---------------------------------------------------- {ref_image_lists}')
    # print(corr_set_config)
    # corr_im_paths = [corr_set_config.correspondence_im_path]
    # ref_featurs_pos = [corr_set_config.reference_feature_poitions]

    input_transform = model_config.input_transform

    #corr_set_train = correspondences.Correspondences(corr_set_config.correspondence_path,
    #                                                 corr_set_config.correspondence_im_path,
    #                                                 input_size=(713, 713),
    #                                                 input_transform=input_transform,
    #                                                 joint_transform=train_joint_transform_corr,
    #                                                 listfile=corr_set_config.correspondence_train_list_file)
    scales = [0, 1, 2, 3]

    # corr_set_train = Poladata.MonoDataset(corr_set_config,
    #                                       seg_folder = "media/HDD1/NsemSEG/Result_fold/" ,
    #                                       im_file_ending = ".jpg" )

    train_joint_transform = joint_transforms.Compose([
        # train_joint_transform_corr = corr_transforms.Compose([
        # corr_transforms.CorrResize(1024),
        # corr_transforms.CorrRandomCrop(713)
        joint_transforms.Resize(1024),
        joint_transforms.RandomCrop(713)
    ])

    sliding_crop = joint_transforms.SlidingCrop(713, 2 / 3., 255)

    # corr_set_train = correspondences.Correspondences(corr_set_config.train_im_folder,
    #                                                 corr_set_config.train_im_folder,
    #                                                 input_size=(713, 713),
    #                                                 input_transform=input_transform,
    #                                                 joint_transform=train_joint_transform,
    #                                                 listfile=None)

    corr_set_train = Poladata.MonoDataset(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        im_file_ending=".jpg",
        id_to_trainid=None,
        joint_transform=train_joint_transform,
        sliding_crop=sliding_crop,
        transform=input_transform,
        target_transform=None,  #train_joint_transform,
        transform_before_sliding=None  #sliding_crop
    )
    #print (corr_set_train)
    # print(corr_set_train.mask)
    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)
    # corr_loader_train = input_transform(corr_loader_train)

    # print(corr_loader_train)
    seg_loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')  # 'elementwise_mean' was renamed to 'mean'

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        deepcluster.set_index(cluster_centroids)

    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)

    # clean_log_before_continuing(os.path.join(save_folder, 'log.log'), start_iter)
    # f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:

        net.eval()
        net.output_features = True
        # max_num_features_per_image = args['max_features_per_image']
        # print('-----------------------------------------------------------------')
        # print(f'ref_image_lists is: {ref_image_lists}, model_config is: {model_config}, net is: {net}, max features per image is: {max_num_features_per_image}')
        # print('-----------------------------------------------------------------')

        # print('the next element of the loader is: ---------------')
        # print(next(iter(corr_loader_train)))

        # features, _ = extract_features_for_reference(net, model_config, ref_image_lists,
        #                                              corr_im_paths, ref_featurs_pos,
        #                                              max_num_features_per_image=args['max_features_per_image'],
        #                                              fraction_correspondeces=0.5)
        print(
            'length of the reference image list --------------------------------------------------------'
        )
        print(len(ref_image_lists))
        features = extract_features_for_reference_nocorr(
            net,
            model_config,
            corr_set_train,
            10,
            max_num_features_per_image=args['max_features_per_image'])

        cluster_features = np.vstack(features)
        del features

        # cluster the features
        cluster_indices, clustering_loss, cluster_centroids, pca_info = deepcluster.cluster_imfeatures(
            cluster_features, verbose=True, use_gpu=False)

        # save cluster centroids
        h5f = h5py.File(
            os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
        h5f.create_dataset('cluster_centroids', data=cluster_centroids)
        h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
        h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
        h5f.close()

        # Print distribution of clusters
        cluster_distribution, _ = np.histogram(
            cluster_indices,
            bins=np.arange(args['n_clusters'] + 1),
            density=True)
        str2write = 'cluster distribution ' + \
            np.array2string(cluster_distribution, formatter={
                            'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
        print(str2write)
        f_handle.write(str2write + "\n")

        # set last layer weight to a normal distribution
        reinit_last_layers(net)

        # make a copy of current network state to do cluster assignment
        net_for_clustering = copy.deepcopy(net)

        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        corr_train_loss = AverageMeter()
        seg_train_loss = AverageMeter()
        feature_train_loss = AverageMeter()

        while cluster_training_count < args[
                'cluster_interval'] and curr_iter <= args['max_iter']:

            # First extract cluster labels using saved network checkpoint
            print('Extracting cluster labels with the frozen network copy')
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True

            data_samples = []
            extract_label_count = 0
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']
            ) and (val_iter + extract_label_count < args['val_interval']) and (
                    extract_label_count + curr_iter <= args['max_iter']):
                # img_ref, img_other, pts_ref, pts_other, _ = next(iter(corr_set_train))
                # (input_transform is already applied inside the dataset, so the
                # loader is iterated directly)
                img_ref, img_other, pts_ref, pts_other, _ = next(
                    iter(corr_loader_train))

                # Transfer data to device
                img_ref = img_ref.to(device)

                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign feature to clusters for entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                                    pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))

                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign cluster to correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))

                # Transfer data to cpu
                img_ref = img_ref.cpu()
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()
                data_samples.append((img_ref, cluster_labels, cluster_image))
                extract_label_count += 1

            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()

                outputs_ref, aux_ref = net(img_ref)

                seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                loss = args['seg_loss_weight'] * \
                    (seg_main_loss + 0.4 * seg_aux_loss)

                loss.backward()
                optimizer.step()
                cluster_training_count += 1

                if isinstance(seg_main_loss, torch.Tensor):
                    seg_train_loss.update(seg_main_loss.item(), 1)

                ####################################################################################################
                #       LOGGING ETC
                ####################################################################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                  curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    str2write = '[iter %d / %d], [train seg loss %.5f], [train corr loss %.5f], [train feature loss %.5f], [lr %.10f]' % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        corr_train_loss.avg, feature_train_loss.avg,
                        optimizer.param_groups[1]['lr'])

                    print(str2write)
                    f_handle.write(str2write + "\n")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()
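
All of the snippets in this collection lean on an AverageMeter helper that is never defined here. A minimal sketch, assuming the common PyTorch-examples implementation (the optional name/fmt arguments match calls such as AverageMeter('Time', ':6.3f') elsewhere in this collection):

class AverageMeter(object):
    """Tracks the latest value and a running average."""

    def __init__(self, name='meter', fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count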
コード例 #21
0
def train(train_loader, L, D, T, optim_L, optim_D, optim_T, epoch, device, args):

    L.train()
    D.train()
    T.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()

    loss_2ds = AverageMeter()
    loss_3ds = AverageMeter()
    loss_advs = AverageMeter()
    loss_ts = AverageMeter()

    end = time.time()

    L2_loss = nn.MSELoss(reduction='mean')
    BCE_loss = nn.BCELoss(reduction='mean')

    for batch_idx, (xy, X, scale) in enumerate(train_loader):

        data_time.update(time.time() - end)
        # train D
        D.zero_grad()
        batch_sz = xy.size(0)
        pose_2d = xy[:, 0].to(device)  # (bs, 17*2)
        xy_real = xy[:, 1:].to(device) # (bs, length-1,17*2)

        z_pred = L(pose_2d) # (bs,17)

        # random Rotation
        theta = np.random.uniform(-np.pi, np.pi, batch_sz).astype(np.float32)
        cos_theta = np.cos(theta)[:, None]
        sin_theta = np.sin(theta)[:, None]

        cos_theta = torch.from_numpy(cos_theta).to(device)
        sin_theta = torch.from_numpy(sin_theta).to(device)

        x = pose_2d[:, 0::2]
        y = pose_2d[:, 1::2]
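        # Rotate the lifted pose (x, y, z_pred) about the y-axis by theta and
        # reproject: x' = x*cos(theta) + z*sin(theta), z' = -x*sin(theta) + z*cos(theta)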
        new_x = x*cos_theta + z_pred*sin_theta # create projection

        trans_3d_z = -x*sin_theta + z_pred*cos_theta

        xy_fake = torch.cat((new_x[:,:,None], y[:,:,None]), dim=2)
        xy_fake = xy_fake.view(batch_sz, -1) # (bs, 17*2)

        trans_3d_1 = torch.cat((new_x[:,:,None], y[:,:,None], trans_3d_z[:,:,None]), dim=2)
        trans_3d_1 = trans_3d_1.view(batch_sz, -1) # (bs,17*3)

        
        D_real_score = D( xy_real.view(batch_sz*(args.length-1), -1) ) # (bs*(length-1),17*2)
        y_real_ = torch.ones(batch_sz*(args.length-1), 1).to(device)
        D_real_loss = BCE_loss(D_real_score, y_real_)

        fake_pose_2d_t_repeat = xy_fake.repeat(args.length-1, 1)
        D_fake_score = D(fake_pose_2d_t_repeat)
        y_fake_ = torch.zeros(batch_sz*(args.length-1), 1).to(device)  # must match D_fake_score's shape
        D_fake_loss = BCE_loss(D_fake_score, y_fake_)

        D_train_loss = D_real_loss + D_fake_loss

        D_train_loss.backward(retain_graph=True)
        optim_D.step()

        # Train T
        T.zero_grad()
        pose_next = xy[:, 1].to(device) # (bs, 17*2)

        # Predict next pose
        z_pred_next = L(pose_next) # (bs,17)
        x_next = pose_next[:, 0::2]
        y_next = pose_next[:, 1::2]
        new_x_next = x_next*cos_theta + z_pred_next*sin_theta
        xy_fake_next = torch.cat((new_x_next[:,:, None], y_next[:,:, None]), dim=2)
        xy_fake_next = xy_fake_next.view(batch_sz,-1) # (bs, 17*2)

        T_real_score = T(pose_2d - pose_next)
        y_real_ = torch.ones(batch_sz, 1).to(device)
        T_real_loss = BCE_loss(T_real_score, y_real_)

        T_fake_score = T(xy_fake - xy_fake_next)
        y_fake_ = torch.zeros(batch_sz, 1).to(device)
        T_fake_loss = BCE_loss(T_fake_score, y_fake_)

        T_train_loss_gp = T_fake_loss + T_real_loss
        T_train_loss_gp.backward(retain_graph=True)
        optim_T.step()
        
        # Train L
        L.zero_grad()
        z_pred_fake_3d = L(xy_fake) # (bs,17)
        
        # Inverse 3D Transformation
        cos_theta_inv = np.cos(-theta)[:, None]
        sin_theta_inv = np.sin(-theta)[:, None]

        cos_theta_inv = torch.from_numpy(cos_theta_inv).to(device)
        sin_theta_inv = torch.from_numpy(sin_theta_inv).to(device)

        x_fake = xy_fake[:, 0::2]
        y_fake = xy_fake[:, 1::2]
        recover_new_x = x_fake*cos_theta_inv + z_pred_fake_3d*sin_theta_inv
        recover_xy = torch.cat((recover_new_x[:,:,None], y_fake[:,:,None]), dim=2)
        recover_xy = recover_xy.view(batch_sz, -1) # (bs, 17*2)

        trans_3d_2 = torch.cat((x_fake[:,:,None], y_fake[:,:,None], z_pred_fake_3d[:,:,None]), dim=2)
        trans_3d_2 = trans_3d_2.view(batch_sz, -1)
        
        loss_2d = L2_loss(recover_xy, pose_2d)
        loss_3d = L2_loss(trans_3d_1, trans_3d_2)

        D_result = D(xy_fake)
        y_ = torch.ones(batch_sz, 1).to(device)
        loss_adv = BCE_loss(D_result, y_)

        T_result = T(xy_fake - xy_fake_next)
        loss_t = BCE_loss(T_result, y_)

        L_train_loss = loss_adv + args.weight_2d*loss_2d + args.weight_3d*loss_3d + args.weight_wt*loss_t

        L_train_loss.backward(retain_graph=True)
        optim_L.step()
        
        loss_2ds.update(loss_2d.item(), batch_sz)
        loss_3ds.update(loss_3d.item(), batch_sz)
        loss_advs.update(loss_adv.item(), batch_sz)
        loss_ts.update(loss_t.item(), batch_sz)

        batch_time.update(time.time() - end)
        end = time.time()

        if args.verbose:
            if batch_idx % 5 == 0:
                outstr = '[{batch}/{size}], Data: {data:.3f}s | Batch: {bt:.3f}s | loss_2d: {l2d:.6f} | loss_3d: {l3d:.6f} | loss_adv: {ladv:.6f} | loss_t: {lt:.6f}'.format(
                    batch = batch_idx + 1,
                    size = len(train_loader),
                    data = data_time.val,
                    bt = batch_time.val,
                    l2d = loss_2ds.val,
                    l3d = loss_3ds.val,
                    ladv = loss_advs.val,
                    lt = loss_ts.val,
                )
                print(outstr)

    return loss_2ds.avg, loss_3ds.avg, loss_advs.avg, loss_ts.avg
コード例 #22
0
def train(train_loader, L, D, T, optim_L, optim_D, optim_T, epoch, device,
          args):

    L.train()
    D.train()
    T.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()

    loss_2ds = AverageMeter()
    loss_3ds = AverageMeter()
    loss_advs = AverageMeter()
    loss_ts = AverageMeter()

    end = time.time()

    L2_loss = nn.MSELoss(reduction='mean')
    BCE_loss = nn.BCELoss(reduction='mean')

    for batch_idx, (xy, X, ls) in enumerate(train_loader):

        data_time.update(time.time() - end)

        # train D
        D.zero_grad()
        batch_sz = xy.size(0)
        pose_2d = xy[:, 0].to(device)  # (bs, 17*2)
        xy_real = xy[:, 1:].to(device)  # (bs, length-1,17*2)

        fake_pose_2d_t, _, trans_3d_1, rot_mat = lift_proj(
            L, pose_2d, batch_sz, device, None, True)

        D_real_loss = D(xy_real.view(batch_sz * (args.length - 1),
                                     -1)).mean()  # (bs*(length-1),17*2)
        #y_real_ = torch.ones(batch_sz*(args.length-1), 1).to(device)
        #D_real_loss = BCE_loss(D_real_score, y_real_)

        fake_pose_2d_t_repeat = fake_pose_2d_t.repeat(args.length - 1, 1)
        D_fake_loss = D(fake_pose_2d_t_repeat).mean()
        #y_fake_ = torch.zeros(batch_sz, 1).to(device)
        #D_fake_loss = BCE_loss(D_fake_score, y_fake_)

        gradient_penalty_D = compute_gradient_penalty(
            D, xy_real.view(batch_sz * (args.length - 1), -1),
            fake_pose_2d_t_repeat, args.lamda_gp, device)

        D_train_loss = D_fake_loss - D_real_loss + gradient_penalty_D
        D_train_loss.backward(retain_graph=True)
        optim_D.step()

        # Train T
        T.zero_grad()
        xy_pose_next = xy[:, 1].to(device)

        fake_pose_2d_next_t, _, _, _ = lift_proj(L, xy_pose_next, batch_sz,
                                                 device, rot_mat, True)

        T_real_loss = T(pose_2d - xy_pose_next).mean()
        #y_real_ = torch.ones(batch_sz, 1).to(device)
        #T_real_loss = BCE_loss(T_real_score, y_real_)

        T_fake_loss = T(fake_pose_2d_t - fake_pose_2d_next_t).mean()
        #y_fake_ = torch.zeros(batch_sz, 1).to(device)
        #T_fake_loss = BCE_loss(T_fake_score, y_fake_)

        gradient_penalty_T = compute_gradient_penalty(
            T, pose_2d - xy_pose_next, fake_pose_2d_t - fake_pose_2d_next_t,
            args.lamda_gp, device)

        T_train_loss = T_fake_loss - T_real_loss + gradient_penalty_T
        T_train_loss.backward(retain_graph=True)
        optim_T.step()

        # Train L
        L.zero_grad()
        rot_mat_inv = rot_mat.inverse()
        recon_pose_2d_t, trans_3d_2, _, _ = lift_proj(L, fake_pose_2d_t,
                                                      batch_sz, device,
                                                      rot_mat_inv, False)

        # print(trans_3d_1[0], trans_3d_2[0])
        loss_2d = L2_loss(recon_pose_2d_t, pose_2d)
        loss_3d = L2_loss(trans_3d_1, trans_3d_2)

        #D_result = D(fake_pose_2d_t)
        #y_ = torch.ones(batch_sz, 1).to(device)
        #loss_adv = BCE_loss(D_result, y_)

        loss_adv = -D(fake_pose_2d_t).mean()

        #T_result = T(fake_pose_2d_t - fake_pose_2d_next_t)
        #loss_t = BCE_loss(T_result, y_)

        loss_t = -T(fake_pose_2d_t - fake_pose_2d_next_t).mean()

        L_train_loss = loss_adv + args.weight_2d * loss_2d + args.weight_3d * loss_3d + args.weight_wt * loss_t
        L_train_loss.backward(retain_graph=True)
        optim_L.step()

        loss_2ds.update(loss_2d.item(), batch_sz)
        loss_3ds.update(loss_3d.item(), batch_sz)
        loss_advs.update(loss_adv.item(), batch_sz)
        loss_ts.update(loss_t.item(), batch_sz)

        batch_time.update(time.time() - end)
        end = time.time()

        if args.verbose:
            if batch_idx % 5 == 0:
                outstr = '[{batch}/{size}], Data: {data:.3f}s | Batch: {bt:.3f}s | loss_2d: {l2d:.6f} | loss_3d: {l3d:.6f} | loss_adv: {ladv:.6f} | loss_t: {lt:.6f}'.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    l2d=loss_2ds.val,
                    l3d=loss_3ds.val,
                    ladv=loss_advs.val,
                    lt=loss_ts.val,
                )
                print(outstr)

    return loss_2ds.avg, loss_3ds.avg, loss_advs.avg, loss_ts.avg
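
コード例 #22 is the WGAN-GP variant of #21: the BCE adversarial losses become critic score means plus a gradient penalty. The compute_gradient_penalty helper is not included; below is a minimal sketch of the standard penalty matching the call signature used above (the interpolation form is an assumption, not the project's confirmed code):

import torch

def compute_gradient_penalty(critic, real, fake, lamda_gp, device):
    """WGAN-GP term: lamda_gp * E[(||grad_x critic(x_hat)||_2 - 1)^2]."""
    alpha = torch.rand(real.size(0), 1, device=device)  # per-sample mixing factor
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(outputs=scores, inputs=x_hat,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True, retain_graph=True)[0]
    return lamda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()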
コード例 #23
0
def train_cae(trainloader, model, class_name, testloader, y_train, device,
              args):
    """
    model train function.
    :param trainloader:
    :param model:
    :param class_name:
    :param testloader:
    :param y_train: numpy array, sample normal/abnormal labels, [1 1 1 1 0 0] like, original sample size.
    :param device: cpu or gpu:0/1/...
    :param args:
    :return:
    """
    global_step = 0
    losses = AverageMeter()
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(1, args.epochs + 1):
        model.train()

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print('{:3d}/{:3d} ----- {:s} {:s}'.format(epoch, args.epochs,
                                                   time_string(), need_time))

        mse = nn.MSELoss(reduction='mean')  # default

        lr = 0.1 / pow(2, np.floor(epoch / args.lr_schedule))
        logger.add_scalar(class_name + "/lr", lr, epoch)

        if args.optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(),
                                  lr=lr,
                                  weight_decay=args.weight_decay)
        else:
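            # note: the lr schedule above is applied only in the SGD branch;
            # Adam runs with its default learning rate here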
            optimizer = optim.Adam(model.parameters(),
                                   eps=1e-7,
                                   weight_decay=0.0005)
        for batch_idx, (input, _, _) in enumerate(trainloader):
            optimizer.zero_grad()
            input = input.to(device)

            _, output = model(input)
            loss = mse(input, output)
            losses.update(loss.item(), 1)

            logger.add_scalar(class_name + '/loss', losses.avg, global_step)

            global_step = global_step + 1
            loss.backward()
            optimizer.step()

        # print losses
        print('Epoch: [{} | {}], loss: {:.4f}'.format(epoch, args.epochs,
                                                      losses.avg))

        # log images
        if epoch % args.log_img_steps == 0:
            os.makedirs(os.path.join(RESULTS_DIR, class_name), exist_ok=True)
            fpath = os.path.join(RESULTS_DIR, class_name,
                                 'pretrain_epoch_' + str(epoch) + '.png')
            visualize(input, output, fpath, num=32)

        # test while training
        if epoch % args.log_auc_steps == 0:
            rep, losses_result = test(testloader, model, class_name, args,
                                      device, epoch)

            centroid = torch.mean(rep, dim=0, keepdim=True)

            losses_result = losses_result - losses_result.min()
            losses_result = losses_result / (1e-8 + losses_result.max())
            scores = 1 - losses_result
            auroc_rec = roc_auc_score(y_train, scores)

            _, p = dec_loss_fun(rep, centroid)
            score_p = p[:, 0]
            auroc_dec = roc_auc_score(y_train, score_p)

            print("Epoch: [{} | {}], auroc_rec: {:.4f}; auroc_dec: {:.4f}".
                  format(epoch, args.epochs, auroc_rec, auroc_dec))

            logger.add_scalar(class_name + '/auroc_rec', auroc_rec, epoch)
            logger.add_scalar(class_name + '/auroc_dec', auroc_dec, epoch)

        # time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
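
convert_secs2time and time_string above come from a helper module that is not shown. A minimal sketch of convert_secs2time, assuming it splits a duration in seconds into hours, minutes, and seconds:

def convert_secs2time(epoch_time):
    # Split a duration in seconds into (hours, minutes, seconds).
    need_hour = int(epoch_time / 3600)
    need_mins = int((epoch_time - 3600 * need_hour) / 60)
    need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
    return need_hour, need_mins, need_secs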
コード例 #24
0
ファイル: train.py プロジェクト: alwc/pse-lite.pytorch
def train(train_loader, model, criterion, optimizer, args):
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    running_metric_text = runningScore(2)
    running_metric_kernel = runningScore(2)

    end = time.time()
    for batch_idx, (imgs, gt_texts, gt_kernels,
                    training_masks) in enumerate(train_loader):
        data_time.update(time.time() - end)

        imgs = imgs.cuda()
        gt_texts = gt_texts.cuda()
        gt_kernels = gt_kernels.cuda()
        training_masks = training_masks.cuda()

        outputs = model(imgs)
        texts = outputs[:, 0, :, :]
        kernels = outputs[:, 1:, :, :]

        loss = criterion(texts, gt_texts, kernels, gt_kernels, training_masks)
        losses.update(loss.item(), imgs.size(0))

        optimizer.zero_grad()
        loss.backward()

        if (args.sr_lr is not None):
            updateBN(model, args)

        optimizer.step()

        score_text = cal_text_score(texts, gt_texts, training_masks,
                                    running_metric_text)
        score_kernel = cal_kernel_score(kernels, gt_kernels, gt_texts,
                                        training_masks, running_metric_kernel)

        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % 20 == 0:
            output_log = '({batch}/{size}) Batch: {bt:.3f}s | TOTAL: {total:.0f}min | ETA: {eta:.0f}min | Loss: {loss:.4f} | Acc_t: {acc: .4f} | IOU_t: {iou_t: .4f} | IOU_k: {iou_k: .4f}'.format(
                batch=batch_idx + 1,
                size=len(train_loader),
                bt=batch_time.avg,
                total=batch_time.avg * batch_idx / 60.0,
                eta=batch_time.avg * (len(train_loader) - batch_idx) / 60.0,
                loss=losses.avg,
                acc=score_text['Mean Acc'],
                iou_t=score_text['Mean IoU'],
                iou_k=score_kernel['Mean IoU'])
            print(output_log)
            sys.stdout.flush()

    return (losses.avg, score_text['Mean Acc'], score_kernel['Mean Acc'],
            score_text['Mean IoU'], score_kernel['Mean IoU'])
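
updateBN is not defined in this snippet. In network-slimming style pruning it typically adds an L1 subgradient on the BatchNorm scale factors after the backward pass; a minimal sketch under that assumption (args.sr_lr is the sparsity rate taken from the call above):

import torch

def updateBN(model, args):
    # Sparsity-inducing L1 subgradient on BatchNorm gammas (network slimming).
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.weight.grad.data.add_(args.sr_lr * torch.sign(m.weight.data))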
コード例 #25
0
def evaluate(val_loader, model, criterion, test=None):
    '''
    Model evaluation
    :param val_loader:
    :param model:
    :param criterion:
    :param test:
    :return:
    '''

    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)

    #################
    # val the model
    #################
    model.eval()
    end = time.time()

    # Evaluate each batch of data
    ## define the progress bar
    bar = Bar('Processing', max=len(val_loader))
    for batch_index, (inputs, targets) in enumerate(val_loader):
        data_time.update(time.time() - end)
        # move tensors to GPU if cuda is_available
        inputs, targets = inputs.to(device), targets.to(device)
        # model prediction
        outputs = model(inputs)
        # compute the loss
        loss = criterion(outputs, targets)

        # compute accuracy and update the meters
        prec1, _ = accuracy(outputs.data, targets.data, topk=(1, 1))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # collect data for the confusion matrix
        targets = targets.data.cpu().numpy()  # ground-truth labels
        predic = torch.max(outputs.data, 1)[1].cpu().numpy()  # predicted labels
        labels_all = np.append(labels_all, targets)  # accumulate
        predict_all = np.append(predict_all, predic)

        ## pack the key statistics into the progress bar
        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f}'.format(
            batch=batch_index + 1,
            size=len(val_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            top1=top1.avg)
        bar.next()
    bar.finish()

    if test:
        return (losses.avg, top1.avg, predict_all, labels_all)
    else:
        return (losses.avg, top1.avg)
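
The accuracy helper is presumably the standard top-k precision function from the official PyTorch examples; a sketch for reference (note that topk=(1, 1) above simply computes top-1 twice):

def accuracy(output, target, topk=(1,)):
    """Compute precision@k for each k in topk."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res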
コード例 #26
0
def train_with_correspondences(save_folder, startnet, args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    freeze_bn(net)

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs,
                               batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'],
                               shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs,
                               batch_size=1,
                               num_workers=args['n_workers'],
                               shuffle=False)
    validator_cs = Validator(val_loader_cs,
                             n_classes=c_config.n_classes,
                             save_snapshot=False,
                             extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)

        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis,
                                    batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'],
                                    shuffle=True)
        seg_loaders.append(seg_loader_vis)

        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis,
                                    batch_size=1,
                                    num_workers=args['n_workers'],
                                    shuffle=False)
        validator_vis = Validator(val_loader_vis,
                                  n_classes=v_config.n_classes,
                                  save_snapshot=False,
                                  extra_name_str='Vistas')
        validators.append(validator_vis)
    else:
        seg_loader_vis = None
        map_validator = None

    # Extra Training
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set,
                                  batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'],
                                  shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set,
                                  batch_size=1,
                                  num_workers=args['n_workers'],
                                  shuffle=False)
    extra_validator = Validator(extra_val_loader,
                                n_classes=corr_set_config.n_classes,
                                save_snapshot=True,
                                extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='mean',  # 'elementwise_mean' is the removed pre-1.0 spelling
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    open(os.path.join(save_folder,
                      str(datetime.datetime.now()) + '.txt'),
         'w').write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    #       MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    curr_iter = start_iter

    for i in range(args['max_iter']):
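        # Polynomial ("poly") LR decay: lr * (1 - curr_iter / max_iter) ** lr_decay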
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        #######################################################################
        #       SEGMENTATION UPDATE STEP
        #######################################################################
        #
        for si, seg_loader in enumerate(seg_loaders):
            # get segmentation training sample
            inputs, gts = next(iter(seg_loader))

            slice_batch_pixel_size = inputs.size(0) * inputs.size(
                2) * inputs.size(3)

            inputs = inputs.to(device)
            gts = gts.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)

            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            loss = main_loss + 0.4 * aux_loss

            loss.backward()
            optimizer.step()

            seg_loss_meters[si].update(main_loss.item(),
                                       slice_batch_pixel_size)

        #######################################################################
        #       CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args[
                'n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(
                iter(corr_loader))

            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(
                        img_ref)
                output_feat_other, aux_feat_other, output_other, aux_other = net(
                    img_other
                )  # output1 must be last to backpropagate derivative correctly
                net.output_all = False

            else:  # Works on class probs
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                    if args['corr_loss_type'] != 'hingeF' and args[
                            'corr_loss_type'] != 'hingeC':
                        output_ref = softm(output_ref)
                        aux_ref = softm(aux_ref)

                # output1 must be last to backpropagate derivative correctly
                output_other, aux_other = net(img_other)
                if args['corr_loss_type'] != 'hingeF' and args[
                        'corr_loss_type'] != 'hingeC':
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = correspondences.refine_correspondence_sample(
                output_ref,
                output_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_ref_orig)
                if b.item() > 0
            ]
            pts_other_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_other_orig)
                if b.item() > 0
            ]
            weights_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, weights_orig)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = correspondences.refine_correspondence_sample(
                aux_ref,
                aux_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_ref_aux)
                if b.item() > 0
            ]
            pts_other_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_other_aux)
                if b.item() > 0
            ]
            weights_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, weights_aux)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_orig]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()

            # correspondence loss
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()

            if aux_vals_ref.size(0) > 0:
                loss_corr_aux = corr_loss_fct(
                    aux_vals_ref, aux_vals_other, pts_ref_aux, pts_other_aux,
                    weights_aux)  # use output from img1 as "reference"
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()

            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)
            loss_corr.backward()

            optimizer.step()
            train_corr_loss.update(loss_corr.item())

        #######################################################################
        #       LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1

        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg,
                          curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

        if (i + 1) % args['print_freq'] == 0:
            str2write = '[iter %d / %d], [train corr loss %.5f], [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f], [lr %.10f]' % (
                curr_iter, args['max_iter'], train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")

        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()
コード例 #27
0
def train(train_loader, net, optim, curr_epoch):
    """
    Runs the training loop per epoch
    train_loader: Data loader for train
    net: thet network
    optimizer: optimizer
    curr_epoch: current epoch
    return:
    """
    net.train()

    train_main_loss = AverageMeter()
    start_time = None
    warmup_iter = 10

    for i, data in enumerate(train_loader):
        if i <= warmup_iter:
            start_time = time.time()
        # inputs = (bs,3,713,713)
        # gts    = (bs,713,713)
        images, gts, _img_name, scale_float = data
        batch_pixel_size = images.size(0) * images.size(2) * images.size(3)
        images, gts, scale_float = images.cuda(), gts.cuda(), scale_float.cuda()
        inputs = {'images': images, 'gts': gts}

        optim.zero_grad()
        main_loss = net(inputs)

        if args.apex:
            log_main_loss = main_loss.clone().detach_()
            torch.distributed.all_reduce(log_main_loss,
                                         torch.distributed.ReduceOp.SUM)
            log_main_loss = log_main_loss / args.world_size
        else:
            main_loss = main_loss.mean()
            log_main_loss = main_loss.clone().detach_()

        train_main_loss.update(log_main_loss.item(), batch_pixel_size)
        if args.fp16:
            with amp.scale_loss(main_loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            main_loss.backward()

        optim.step()

        if i >= warmup_iter:
            curr_time = time.time()
            batches = i - warmup_iter + 1
            batchtime = (curr_time - start_time) / batches
        else:
            batchtime = 0

        msg = ('[epoch {}], [iter {} / {}], [train main loss {:0.6f}],'
               ' [lr {:0.6f}] [batchtime {:0.3g}]')
        msg = msg.format(
            curr_epoch, i + 1, len(train_loader), train_main_loss.avg,
            optim.param_groups[-1]['lr'], batchtime)
        logx.msg(msg)

        metrics = {'loss': train_main_loss.avg,
                   'lr': optim.param_groups[-1]['lr']}
        curr_iter = curr_epoch * len(train_loader) + i
        logx.metric('train', metrics, curr_iter)

        if i >= 10 and args.test_mode:
            del data, inputs, gts
            return
        del data
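
The args.fp16 branch relies on NVIDIA's apex.amp and assumes the model and optimizer were wrapped once before the training loop; a minimal sketch of that setup (apex is a separate install, and 'O1' is one of several opt levels):

from apex import amp

net, optim = amp.initialize(net, optim, opt_level='O1')  # patch model/optimizer for mixed precision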
コード例 #28
0
def test(test_loader, model, configs):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    acc_event = AverageMeter('Acc_Event', ':6.4f')
    iou_seg = AverageMeter('IoU_Seg', ':6.4f')
    mse_global = AverageMeter('MSE_Global', ':6.4f')
    mse_local = AverageMeter('MSE_Local', ':6.4f')

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        for batch_idx, (origin_imgs, resized_imgs, org_ball_pos_xy,
                        global_ball_pos_xy, event_class,
                        target_seg) in enumerate(tqdm(test_loader)):
            data_time.update(time.time() - start_time)
            batch_size = resized_imgs.size(0)
            target_seg = target_seg.to(configs.device, non_blocking=True)
            resized_imgs = resized_imgs.to(configs.device,
                                           non_blocking=True).float()
            # compute output
            if 'local' in configs.tasks:
                origin_imgs = origin_imgs.to(configs.device,
                                             non_blocking=True).float()
                pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
                    origin_imgs, resized_imgs, org_ball_pos_xy,
                    global_ball_pos_xy, event_class, target_seg)
            else:
                pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
                    None, resized_imgs, org_ball_pos_xy, global_ball_pos_xy,
                    event_class, target_seg)
            # Transfer output to cpu
            pred_ball_global = pred_ball_global.cpu().numpy()
            global_ball_pos_xy = global_ball_pos_xy.numpy()
            if pred_ball_local is not None:
                pred_ball_local = pred_ball_local.cpu().numpy()
                local_ball_pos_xy = local_ball_pos_xy.cpu().numpy(
                )  # Ground truth of the local stage
            if pred_events is not None:
                pred_events = pred_events.cpu().numpy()
            if pred_seg is not None:
                pred_seg = pred_seg.cpu().numpy()
                target_seg = target_seg.cpu().numpy()

            org_ball_pos_xy = org_ball_pos_xy.numpy()

            for sample_idx in range(batch_size):
                w, h = configs.input_size
                # Get target
                sample_org_ball_pos_xy = org_ball_pos_xy[sample_idx]
                sample_global_ball_pos_xy = global_ball_pos_xy[
                    sample_idx]  # Target
                # Process the global stage
                sample_pred_ball_global = pred_ball_global[sample_idx]
                sample_pred_ball_global[sample_pred_ball_global <
                                        configs.thresh_ball_pos_mask] = 0.
                sample_pred_ball_global_x = np.argmax(
                    sample_pred_ball_global[:w])
                sample_pred_ball_global_y = np.argmax(
                    sample_pred_ball_global[w:])

                # Calculate the MSE
                if (sample_global_ball_pos_xy[0] >
                        0) and (sample_global_ball_pos_xy[1] > 0):
                    mse = (sample_pred_ball_global_x -
                           sample_global_ball_pos_xy[0])**2 + (
                               sample_pred_ball_global_y -
                               sample_global_ball_pos_xy[1])**2
                    mse_global.update(mse)

                print(
                    'Global stage: (x, y) - org: ({}, {}), gt = ({}, {}), prediction = ({}, {})'
                    .format(sample_org_ball_pos_xy[0],
                            sample_org_ball_pos_xy[1],
                            sample_global_ball_pos_xy[0],
                            sample_global_ball_pos_xy[1],
                            sample_pred_ball_global_x,
                            sample_pred_ball_global_y))

                # Process local ball stage
                if pred_ball_local is not None:
                    # Get target
                    sample_local_ball_pos_xy = local_ball_pos_xy[
                        sample_idx]  # Target
                    # Process the local stage
                    sample_pred_ball_local = pred_ball_local[sample_idx]
                    sample_pred_ball_local[sample_pred_ball_local <
                                           configs.thresh_ball_pos_mask] = 0.
                    sample_pred_ball_local_x = np.argmax(
                        sample_pred_ball_local[:w])
                    sample_pred_ball_local_y = np.argmax(
                        sample_pred_ball_local[w:])

                    # Calculate the MSE
                    if (sample_local_ball_pos_xy[0] >
                            0) and (sample_local_ball_pos_xy[1] > 0):
                        mse = (sample_pred_ball_local_x -
                               sample_local_ball_pos_xy[0])**2 + (
                                   sample_pred_ball_local_y -
                                   sample_local_ball_pos_xy[1])**2
                        mse_local.update(mse)

                    print(
                        'Local stage: (x, y) - gt = ({}, {}), prediction = ({}, {})'
                        .format(sample_local_ball_pos_xy[0],
                                sample_local_ball_pos_xy[1],
                                sample_pred_ball_local_x,
                                sample_pred_ball_local_y))

                # Process event stage
                if pred_events is not None:
                    sample_target_event = event_class[sample_idx].item()
                    vec_sample_target_event = np.zeros((2, ), dtype=int)  # np.int was removed from NumPy
                    if sample_target_event < 2:
                        vec_sample_target_event[sample_target_event] = 1
                    sample_pred_event = (pred_events[sample_idx] >
                                         configs.event_thresh).astype(int)
                    print('Event stage: gt = {}, prediction: {}'.format(
                        sample_target_event, pred_events[sample_idx]))
                    diff = sample_pred_event - vec_sample_target_event
                    # Check correct or not
                    if np.sum(diff) != 0:
                        # Incorrect
                        acc_event.update(0)
                    else:
                        # Correct
                        acc_event.update(1)

                # Process segmentation stage
                if pred_seg is not None:
                    sample_target_seg = target_seg[sample_idx].transpose(
                        1, 2, 0)
                    sample_pred_seg = pred_seg[sample_idx].transpose(1, 2, 0)
                    sample_target_seg = sample_target_seg.astype(int)
                    sample_pred_seg = (sample_pred_seg >
                                       configs.seg_thresh).astype(int)

                    # Calculate the overlap (this is the Dice coefficient, 2|A∩B| / (|A| + |B|))
                    iou = 2 * np.sum(sample_target_seg * sample_pred_seg) / (
                        np.sum(sample_target_seg) + np.sum(sample_pred_seg) +
                        1e-9)
                    iou_seg.update(iou)
                    if configs.save_test_output:
                        fig, axes = plt.subplots(nrows=batch_size,
                                                 ncols=2,
                                                 figsize=(10, 5))
                        plt.tight_layout()
                        axes = axes.ravel()
                        axes[2 * sample_idx].imshow(sample_target_seg * 255)
                        axes[2 * sample_idx + 1].imshow(sample_pred_seg * 255)
                        # title
                        target_title = 'target seg'
                        pred_title = 'pred seg'
                        if pred_events is not None:
                            target_title += ', is bounce: {}, is net: {}'.format(
                                vec_sample_target_event[0],
                                vec_sample_target_event[1])
                            pred_title += ', is bounce: {}, is net: {}'.format(
                                sample_pred_event[0], sample_pred_event[1])

                        axes[2 * sample_idx].set_title(target_title)
                        axes[2 * sample_idx + 1].set_title(pred_title)

                        plt.savefig(
                            os.path.join(
                                configs.saved_dir,
                                'batch_idx_{}_sample_idx_{}.jpg'.format(
                                    batch_idx, sample_idx)))

            if ((batch_idx + 1) % configs.print_freq) == 0:
                print(
                    'batch_idx: {} - Average acc_event: {}, iou_seg: {}, mse_global: {}, mse_local: {}'
                    .format(batch_idx, acc_event.avg, iou_seg.avg,
                            mse_global.avg, mse_local.avg))

            batch_time.update(time.time() - start_time)

            start_time = time.time()

    print('Average acc_event: {}, iou_seg: {}, mse_global: {}, mse_local: {}'.
          format(acc_event.avg, iou_seg.avg, mse_global.avg, mse_local.avg))
    print('Done testing')
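
Note that the quantity accumulated in iou_seg above is, strictly speaking, the Dice coefficient rather than IoU. A small helper showing both metrics and their relation for binary masks (an illustrative addition, not part of the original project):

import numpy as np

def dice_and_iou(target, pred, eps=1e-9):
    # Dice = 2|A∩B| / (|A| + |B|);  IoU = |A∩B| / |A∪B| = Dice / (2 - Dice)
    inter = np.sum(target * pred)
    dice = 2 * inter / (np.sum(target) + np.sum(pred) + eps)
    iou = inter / (np.sum(target) + np.sum(pred) - inter + eps)
    return dice, iou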
コード例 #29
0
def train_one_epoch(train_dataloader, model, optimizer, lr_scheduler, epoch, configs, logger, tb_writer):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')

    progress = ProgressMeter(len(train_dataloader), [batch_time, data_time, losses],
                             prefix="Train - Epoch: [{}/{}]".format(epoch, configs.num_epochs))

    num_iters_per_epoch = len(train_dataloader)

    # switch to train mode
    model.train()
    start_time = time.time()
    for batch_idx, batch_data in enumerate(tqdm(train_dataloader)):
        data_time.update(time.time() - start_time)
        _, imgs, targets = batch_data
        global_step = num_iters_per_epoch * (epoch - 1) + batch_idx + 1

        batch_size = imgs.size(0)

        targets = targets.to(configs.device, non_blocking=True)
        imgs = imgs.to(configs.device, non_blocking=True)
        total_loss, outputs = model(imgs, targets)

        # For torch.nn.DataParallel case
        if (not configs.distributed) and (configs.gpu_idx is None):
            total_loss = torch.mean(total_loss)

        # compute gradient and perform backpropagation
        total_loss.backward()
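        # Gradient accumulation: parameters are updated only every
        # configs.subdivisions batches, so the effective batch size is
        # batch_size * subdivisions.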
        if global_step % configs.subdivisions == 0:
            optimizer.step()
            # Adjust learning rate
            lr_scheduler.step()
            # zero the parameter gradients
            optimizer.zero_grad()

        if configs.distributed:
            reduced_loss = reduce_tensor(total_loss.data, configs.world_size)
        else:
            reduced_loss = total_loss.data
        losses.update(to_python_float(reduced_loss), batch_size)
        # measure elapsed time
        # torch.cuda.synchronize()
        batch_time.update(time.time() - start_time)

        if tb_writer is not None:
            if (global_step % configs.tensorboard_freq) == 0:
                tensorboard_log = get_tensorboard_log(model)
                tensorboard_log['lr'] = lr_scheduler.get_lr()[0] * configs.batch_size * configs.subdivisions
                tensorboard_log['avg_loss'] = losses.avg
                tb_writer.add_scalars('Train', tensorboard_log, global_step)

        # Log message
        if logger is not None:
            if (global_step % configs.print_freq) == 0:
                logger.info(progress.get_message(batch_idx))

        start_time = time.time()
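The helpers reduce_tensor and to_python_float used above are not defined in this snippet. The following is a plausible minimal sketch matching how they are called here (average a loss tensor over world_size ranks, then convert a 0-dim tensor to a Python float); treat the exact behavior as an assumption:

import torch
import torch.distributed as dist

def reduce_tensor(tensor, world_size):
    # Sum the tensor across all ranks, then divide to get the mean
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt

def to_python_float(t):
    # Convert a 0-dim tensor (or anything float-like) to a Python float
    return t.item() if hasattr(t, 'item') else float(t)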
Code example #30
File: train.py  Project: lxtGH/PFSegNets
def train(train_loader, net, optim, curr_epoch, writer):
    """
    Runs the training loop per epoch
    train_loader: Data loader for train
    net: thet network
    optimizer: optimizer
    curr_epoch: current epoch
    writer: tensorboard writer
    return:
    """
    net.train()

    train_main_loss = AverageMeter()
    curr_iter = curr_epoch * len(train_loader)

    for i, data in enumerate(train_loader):
        edges = None
        if args.joint_edge_loss_pfnet:
            inputs, gts, bodys, edges, _img_name = data
        else:
            inputs, gts, _img_name = data

        batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)

        inputs, gts = inputs.cuda(), gts.cuda()

        optim.zero_grad()
        if args.joint_edge_loss_pfnet:
            main_loss_dic = net(inputs, gts=(gts, edges))
            main_loss = 0.0
            for v in main_loss_dic.values():
                main_loss = main_loss + v
        else:
            main_loss = net(inputs, gts=gts)

        if args.apex:
            log_main_loss = main_loss.clone().detach_()
            torch.distributed.all_reduce(log_main_loss,
                                         torch.distributed.ReduceOp.SUM)
            log_main_loss = log_main_loss / args.world_size
        else:
            main_loss = main_loss.mean()
            log_main_loss = main_loss.clone().detach_()

        train_main_loss.update(log_main_loss.item(), batch_pixel_size)
        if args.fp16:
            with amp.scale_loss(main_loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            if not torch.isfinite(main_loss).all():
                raise FloatingPointError(
                    "Loss became infinite or NaN at iteration={}!\nloss_dict = {}"
                    .format(curr_iter, main_loss))
            main_loss.backward()

        optim.step()

        curr_iter += 1

        if args.local_rank == 0 and i % args.print_freq == 0:
            if args.joint_edge_loss_pfnet:
                msg = f'[epoch {curr_epoch}], [iter {i + 1} / {len(train_loader)}], '
                msg += '[seg_main_loss:{:0.5f}]'.format(
                    main_loss_dic['seg_loss'])
                for j in range(3):
                    temp_msg = '[layer{}: edge loss {:0.5f}] '.format(
                        (3 - j), main_loss_dic[f'edge_loss_layer{3-j}'])
                    msg += temp_msg
                msg += ', [lr {:0.5f}]'.format(optim.param_groups[-1]['lr'])
            else:
                msg = '[epoch {}], [iter {} / {}], [train main loss {:0.6f}], [lr {:0.6f}]'.format(
                    curr_epoch, i + 1, len(train_loader), train_main_loss.avg,
                    optim.param_groups[-1]['lr'])

            logging.info(msg)

            # Log tensorboard metrics for each iteration of the training phase
            writer.add_scalar('training/loss', (train_main_loss.val),
                              curr_iter)
            writer.add_scalar('training/lr', optim.param_groups[-1]['lr'],
                              curr_iter)

        if i > 5 and args.test_mode:
            return
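For context, a driver loop for this per-epoch train function might look like the following; args.max_epoch and the commented-out validation call are illustrative assumptions, not part of the project shown above:

# Hypothetical wiring; `args.max_epoch` and `validate` are assumed names.
for epoch in range(args.max_epoch):
    train(train_loader, net, optim, epoch, writer)
    # validate(val_loader, net, writer)  # evaluation/checkpointing typically follows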