def collision_validate(grasp_pres, mask_height, rgb, img_name):
    """Run the collision check on predicted grasps until one passes.

    Grasps are visited in order. For each one, its rectangle is drawn onto
    ``rgb`` and saved, and ``collision_check`` is evaluated against
    ``mask_height``. Iteration stops at the first grasp whose check result is
    falsy.

    Args:
        grasp_pres: iterable of predicted grasps exposing ``.as_gr``.
        mask_height: height map handed to ``collision_check``.
        rgb: image the grasps are drawn on for the saved debug images.
        img_name: output path template. Every ``'f'`` character in it is
            replaced by the check result before saving.

    Returns:
        The last ``collision_check`` result (1 means "collision"), or 1 when
        ``grasp_pres`` is empty.
    """
    result = 1  # 1 == collision; also what an empty grasp_pres yields
    for grasp in grasp_pres:
        rect = grasp.as_gr
        overlay = show_grasp(rgb, rect, 255)
        # Author's note: if the height map reads 0 at this spot the grasp is
        # effectively skipped, which is definitely not right.
        box = (rect.center, rect.angle, rect.width, rect.length)
        result = collision_check(mask_height, box)
        imsave(img_name.replace('f', str(result)), overlay)
        if not result:
            return result
    return result
Exemplo n.º 2
0
    def get_patch(self, prediction, mask_height, *, debug=True):
        """Build 5x5 prediction patches and matching target patches per image.

        For every image in the batch this:
          1. Gaussian-blurs the quality/angle/width maps and takes the location
             of the highest blurred quality.
          2. Lays a 7x7 grid of candidate centres around it (offsets 0, +-2,
             +-4, +-6 px, centre first) and reads the blurred angle and width
             there.
          3. Lets ``get_indexes`` choose one candidate per image from
             ``mask_height`` and those candidate boxes.
          4. Lets ``get_patch_params`` move that centre and pick a width scale.
          5. Crops a 5x5 window of the raw prob/cos/sin/width maps around the
             moved centre. It pairs this with a constant target patch: quality
             1, cos/sin read at the selected centre before the move, and width
             multiplied by ``scale``.

        Args:
            prediction: ``(prob, cos_img, sin_img, width_img)`` tensors. Only
                channel 0 is used; presumably shaped (B, 1, H, W) -- confirm.
            mask_height: per-image height masks, iterated alongside the batch.
            debug: when true (default, matching the previous behaviour), show
                blocking matplotlib figures of the selections and patches.

        Returns:
            ``(pre_patches, gt_patches)``, each a stack of per-image (4, 5, 5)
            tensors (assuming the 5x5 window lies inside the map).
        """
        gt_patches = []
        pre_patches = []
        prob, cos_img, sin_img, width_img = prediction
        batch_size = prob.size()[0]
        height, width_px = prob.shape[-2], prob.shape[-1]
        device = prob.device

        # atan2(sin, cos) / 2: the maps presumably encode the doubled angle.
        ang_out = torch.atan2(sin_img, cos_img) / 2.0
        # NOTE(review): 150 presumably de-normalises width to pixels -- confirm.
        width_out = width_img * 150.0
        prob_g = T.GaussianBlur(kernel_size=(3, 3), sigma=2)(prob)
        ang_g = T.GaussianBlur(kernel_size=(3, 3), sigma=2)(ang_out)
        width_g = T.GaussianBlur(kernel_size=(3, 3), sigma=1)(width_out)

        # Flat index of the brightest blurred pixel in each image's channel 0.
        position = torch.max(prob_g[:, 0].reshape(batch_size, -1), dim=1)[1]
        x = position % width_px
        y = position // width_px

        # Candidate grid around each peak. Offsets are ordered centre-first so
        # the most confident location precedes the surrounding ones.
        offsets = torch.tensor([0, 2, -2, 4, -4, 6, -6], device=device)
        n_off = len(offsets)
        steps = n_off * n_off  # candidates per image
        offset_x = offsets.repeat(batch_size * n_off)
        offset_y = offsets.repeat_interleave(n_off).repeat(batch_size)
        # Keep every candidate at least max|offset| px away from the border.
        margin = int(offsets.abs().max())
        expand_x = torch.clip(x.repeat_interleave(steps) + offset_x,
                              margin, width_px - 1 - margin)
        expand_y = torch.clip(y.repeat_interleave(steps) + offset_y,
                              margin, height - 1 - margin)

        # (image, channel 0, y, x) index for every candidate.
        indice0 = torch.arange(batch_size, device=device).repeat_interleave(steps)
        indice1 = torch.zeros_like(indice0)
        ang = ang_g[(indice0, indice1, expand_y, expand_x)]
        width = width_g[(indice0, indice1, expand_y, expand_x)]
        length = width / 2

        # One (y, x, angle, width, length) row per candidate, grouped per image.
        boxes_params = torch.stack(
            (expand_y, expand_x, ang, width, length)).t().reshape(batch_size, -1, 5)
        # NOTE(review): per the original author's notes, the collision check
        # rotates/crops the height mask per candidate in a Python loop -- confirm.
        (indexes, batch_edges, batch_edges_left, batch_edges_right,
         batch_edges_top, batch_edges_bottom, directions) = get_indexes(
             mask_height, boxes_params, batch_size, steps)

        if debug:
            # Show the grasp finally selected for every image in the batch.
            ncols = (batch_size + 1) // 2
            for i, image in enumerate(mask_height):
                box = boxes_params[i][indexes[i]]
                box_y = box[0].cpu().data.numpy()
                box_x = box[1].cpu().data.numpy()
                box_angle = box[2].cpu().data.numpy()
                box_width = box[3].cpu().data.numpy()
                gr = Grasp_cpaw((box_y, box_x), box_angle, box_width / 2, box_width)
                gr_img = show_grasp(image.cpu().data.numpy()[0], gr.as_gr, 50)
                plt.subplot(2, ncols, i + 1)
                plt.imshow(gr_img)
            plt.show()

        for i in range(batch_size):
            index = indexes[i]
            # Selected candidate centre, before get_patch_params moves it.
            old_y = int(boxes_params[i][index][0].cpu().data.numpy())
            old_x = int(boxes_params[i][index][1].cpu().data.numpy())

            # Angle/width targets come from this pre-move position: whether the
            # angle at the moved position is appropriate has not been verified.
            cos_patch = cos_img[i][0][old_y, old_x].cpu().data.numpy()
            sin_patch = sin_img[i][0][old_y, old_x].cpu().data.numpy()
            angle = ang_g[i][0][old_y, old_x].cpu().data.numpy()
            width_patch = width_img[i][0][old_y, old_x].cpu().data.numpy()

            selected_edge = batch_edges[i][index]
            left_right_diff = abs(batch_edges_left[i][index] - batch_edges_right[i][index])
            top_bottom_diff = abs(batch_edges_top[i][index] - batch_edges_bottom[i][index])

            y, x, scale = get_patch_params(i, old_y, old_x, selected_edge,
                                           width_patch, angle, directions,
                                           left_right_diff, top_bottom_diff)

            # Raw (unblurred) prediction maps in a 5x5 window at the new centre.
            pre_patch = torch.stack([m[i][0][y - 2:y + 3, x - 2:x + 3]
                                     for m in (prob, cos_img, sin_img, width_img)])
            gt_patch = torch.stack([
                pre_patch[0].new_full((5, 5), 1),
                pre_patch[1].new_full((5, 5), float(cos_patch)),
                pre_patch[2].new_full((5, 5), float(sin_patch)),
                pre_patch[3].new_full((5, 5), float(width_patch * scale)),
            ])

            if debug:
                # Left: selected grasp; right: grasp after get_patch_params.
                image = mask_height[i]
                box_angle = boxes_params[i][index][2].cpu().data.numpy()
                box_width = boxes_params[i][index][3].cpu().data.numpy()
                gr = Grasp_cpaw((old_y, old_x), box_angle, box_width / 2, box_width)
                gr_img = show_grasp(image.cpu().data.numpy()[0], gr.as_gr, 50)
                plt.subplot(121)
                plt.imshow(gr_img)
                plt.subplot(122)
                gr = Grasp_cpaw((y, x), box_angle, box_width * scale / 2,
                                box_width * scale)
                gr_img = show_grasp(image.cpu().data.numpy()[0], gr.as_gr, 50)
                plt.imshow(gr_img)
                plt.show()
                # Top row: prediction patches; bottom row: target patches.
                for j in range(4):
                    plt.subplot(2, 4, j + 1)
                    plt.imshow(pre_patch[j].cpu().data.numpy())
                for j in range(4):
                    plt.subplot(2, 4, j + 5)
                    plt.imshow(gt_patch[j].cpu().data.numpy())
                plt.show()

            pre_patches.append(pre_patch)
            gt_patches.append(gt_patch)

        if debug:
            # Rows: height mask, blurred quality, blurred angle, blurred width.
            maps = torch.cat((mask_height, prob_g, ang_g, width_g), dim=0)
            for k in range(maps.shape[0]):
                plt.subplot(4, batch_size, k + 1)
                plt.imshow(maps[k][0].cpu().data.numpy())
            plt.show()

        return torch.stack(pre_patches), torch.stack(gt_patches)
def check_false_positive(net, device, val_data, batches_per_epoch):
    """Evaluate ``net`` and split IoU-correct predictions into true/false positives.

    A batch counts as ``correct`` when any detected grasp has IoU > 0.25 with
    a ground-truth grasp. The first such grasp is then checked with
    ``detect_dep`` on the depth mask. A reported collision is a false positive
    (its RGB overlay is saved under ``8.jacquard_code_origin/false_positive/``);
    otherwise it is a true positive. If the collision check raises, the batch
    stays counted as correct but is neither a true nor a false positive.

    Args:
        net: grasp network exposing ``eval()`` and ``compute_loss(x, y)``.
        device: device the batches are moved to.
        val_data: loader yielding ``(x, y, idx, rot, zoom_factor)``. Its
            ``dataset`` provides ``get_raw_grasps``, ``get_mask_d`` and
            ``get_rgb``.
        batches_per_epoch: batch budget. The counter is checked right after
            being incremented, so at most ``batches_per_epoch - 1`` batches
            are evaluated.

    Returns:
        The result dict, including ``acc`` (correct / evaluated) and
        ``real_acc`` (true_positive / evaluated). Both are 0.0 when nothing
        was evaluated. A summary is also appended to
        ``8.jacquard_code_origin/result.txt``.
    """
    val_result = {
        'correct': 0,
        'failed': 0,
        'loss': 0,
        'losses': {},
        'acc': 0.0,
        'false_positive': 0,
        'true_positive': 0,
        'real_acc': 0.0
    }
    # Put the network in evaluation mode.
    net.eval()

    with torch.no_grad():
        batch_idx = 0
        while batch_idx < batches_per_epoch:
            got_batch = False
            for x, y, idx, rot, zoom_factor in val_data:
                got_batch = True
                batch_idx += 1
                if batch_idx >= batches_per_epoch:
                    break
                xc = x.to(device)
                yc = [yy.to(device) for yy in y]

                lossdict = net.compute_loss(xc, yc)

                q_out, ang_out, width_out = post_process(
                    lossdict['pred']['pos'], lossdict['pred']['cos'],
                    lossdict['pred']['sin'], lossdict['pred']['width'])
                grasp_pres = detect_grasps(q_out, ang_out, width_out)
                grasps_true = val_data.dataset.get_raw_grasps(
                    idx, rot, zoom_factor)

                # First predicted grasp overlapping a ground truth by IoU > 0.25.
                grasp_pre = next(
                    (g for g in grasp_pres if max_iou(g, grasps_true) > 0.25),
                    None)
                if grasp_pre is None:
                    val_result['failed'] += 1
                    continue

                val_result['correct'] += 1
                # Check whether this IoU-correct grasp is a false positive.
                # Depth-mask based obstacle check (an edge-based variant using
                # get_edges + correct_grasp existed as commented-out code).
                depth_img = val_data.dataset.get_mask_d(idx, rot, zoom_factor)

                gr = grasp_pre.as_gr
                try:
                    collision = detect_dep(depth_img=depth_img,
                                           gr0=gr,
                                           edge_width=5)
                except Exception:
                    print(idx.cpu().data.numpy()[0], '碰撞检测报错了')
                    continue

                if collision:
                    val_result['false_positive'] += 1
                    # Save a visualisation to inspect whether the verdict is right.
                    rgb = val_data.dataset.get_rgb(idx,
                                                   rot,
                                                   zoom_factor,
                                                   normalize=False)
                    img = show_grasp(rgb, gr, 255)
                    img_name = '8.jacquard_code_origin/false_positive/{0}_{1}_{2}.png'.format(
                        idx.cpu().data.numpy()[0],
                        rot.cpu().data.numpy()[0],
                        zoom_factor.cpu().data.numpy()[0])
                    imsave(img_name, img)
                else:
                    val_result['true_positive'] += 1
            # An empty (or exhausted one-shot) loader would otherwise spin
            # forever in the while loop.
            if not got_batch:
                break

        logging.info(time.ctime())
        evaluated = val_result['correct'] + val_result['failed']
        acc = val_result['correct'] / evaluated if evaluated else 0.0
        real_acc = val_result['true_positive'] / evaluated if evaluated else 0.0
        logging.info('acc:{}'.format(acc))
        logging.info('real_acc:{}'.format(real_acc))
        val_result['acc'] = acc
        val_result['real_acc'] = real_acc

        with open('8.jacquard_code_origin/result.txt', 'a') as f:
            f.write(time.ctime())
            f.write('\n ')
            f.write(
                'correct:{0}\n failed:{1}\n acc:{2}\n true_positive:{3}\n false_positive:{4}\n real_acc:{5}\n'
                .format(val_result['correct'], val_result['failed'],
                        val_result['acc'], val_result['true_positive'],
                        val_result['false_positive'], val_result['real_acc']))
    return val_result
def classify_false_positive(net, net_c, device, val_data, optimizer,
                            batches_per_epoch):
    """Jointly train the grasp net and a collision classifier on IoU-correct grasps.

    For each batch, the first predicted grasp with IoU > 0.25 against the
    ground truth is labelled with ``detect_dep`` on the depth mask (collision
    or not). Its RGB overlay is saved under ``false_positive/``. ``net_c`` is
    then trained to predict that label from the grasp's feature map. The grasp
    net's loss and the classifier loss are back-propagated and one optimizer
    step is taken, but only for batches that reach this point.

    Args:
        net: grasp network exposing ``train()`` and ``compute_loss(x, y)``.
        net_c: collision classifier exposing ``train()`` and
            ``compute_loss(x, y)``.
        device: device the tensors are moved to.
        val_data: loader yielding ``(x, y, idx, rot, zoom_factor)``. Its
            ``dataset`` provides ``get_raw_grasps``, ``get_mask_d`` and
            ``get_rgb``.
        optimizer: optimizer stepped after both losses are back-propagated.
        batches_per_epoch: batch budget. The counter is checked right after
            being incremented, so at most ``batches_per_epoch - 1`` batches
            are used.

    Returns:
        0. The statistics are only logged. A prediction counts as classified
        correctly when pred > 0.8 with a collision, or pred < 0.2 without one.
        Ratios are 0.0 when their denominator is zero.
    """
    # Put both networks in training mode.
    net.train()
    net_c.train()

    classify_result = {
        'correct': 0,
        'failed': 0,
        'loss': 0,
        'losses': {},
        'acc': 0.0,
        'c_correct': 0,
        'c_failed': 0,
        'c_acc': 0.0
    }

    batch_idx = 0
    while batch_idx < batches_per_epoch:
        got_batch = False
        for x, y, idx, rot, zoom_factor in val_data:
            got_batch = True
            batch_idx += 1
            if batch_idx >= batches_per_epoch:
                break
            xc = x.to(device)
            yc = [yy.to(device) for yy in y]

            lossdict = net.compute_loss(xc, yc)
            loss = lossdict['loss']
            q_out, ang_out, width_out = post_process(lossdict['pred']['pos'],
                                                     lossdict['pred']['cos'],
                                                     lossdict['pred']['sin'],
                                                     lossdict['pred']['width'])
            grasp_pres = detect_grasps(q_out, ang_out, width_out)
            grasps_true = val_data.dataset.get_raw_grasps(
                idx, rot, zoom_factor)

            # First predicted grasp overlapping a ground truth by IoU > 0.25.
            grasp_pre = next(
                (g for g in grasp_pres if max_iou(g, grasps_true) > 0.25),
                None)
            if grasp_pre is None:
                classify_result['failed'] += 1
                continue

            # Only reached when a feasible grasp was predicted.
            classify_result['correct'] += 1
            # Depth-mask based obstacle check to label the grasp.
            depth_img = val_data.dataset.get_mask_d(idx, rot, zoom_factor)
            gr = grasp_pre.as_gr
            try:
                rgb = val_data.dataset.get_rgb(idx,
                                               rot,
                                               zoom_factor,
                                               normalize=False)
                img = show_grasp(rgb, gr, 255)
                img_name = 'false_positive/{0}_{1}_{2}.png'.format(
                    idx.cpu().data.numpy()[0],
                    rot.cpu().data.numpy()[0],
                    zoom_factor.cpu().data.numpy()[0])
                imsave(img_name, img)
                collision = detect_dep(depth_img=depth_img,
                                       gr0=gr,
                                       edge_width=5)
            except Exception:
                print(idx.cpu().data.numpy()[0], '碰撞检测报错了')
                continue

            # Classifier input: the grasp's feature map; target: collision label.
            features = lossdict['features'].cpu().data.numpy().squeeze()
            x_gr = torch.Tensor(get_gr_feature_map(features, gr1=gr)).to(device)
            y_gr = torch.Tensor([collision]).to(device)

            loss_c_dict = net_c.compute_loss(x_gr, y_gr)
            loss_c = loss_c_dict['loss']

            pred = loss_c_dict['pred'].cpu().data.numpy()
            if (pred > 0.80 and collision) or (pred < 0.20 and not collision):
                classify_result['c_correct'] += 1
            else:
                classify_result['c_failed'] += 1

            optimizer.zero_grad()
            loss.backward()
            loss_c.backward()
            optimizer.step()
        # An empty (or exhausted one-shot) loader would otherwise spin
        # forever in the while loop.
        if not got_batch:
            break

    logging.info(time.ctime())
    evaluated = classify_result['correct'] + classify_result['failed']
    acc = classify_result['correct'] / evaluated if evaluated else 0.0
    logging.info('acc:{}'.format(acc))
    classify_result['acc'] = acc
    classified = classify_result['c_correct'] + classify_result['c_failed']
    classify_result['c_acc'] = (classify_result['c_correct'] / classified
                                if classified else 0.0)

    logging.info('{}/{} classify acc :{}'.format(
        classify_result['c_correct'],
        classified,
        classify_result['c_acc']))

    return 0