Esempio n. 1
0
def run(**kwargs):
    """Render predicted vs. ground-truth curves for every sample of one split.

    Steps:
      1. Import ``config.<kwargs['config']>`` and use its module ``__dict__``
         as the option dict, overwriting entries with every keyword argument
         whose value is not None (this mutates the config module's globals).
      2. Build ``net.<opt['net']>.Net(**opt)`` (moved to CUDA when available)
         and the ``util.dataset.<opt['dataset']>`` train/val splits.
      3. If ``opt['model']`` is non-empty, restore weights from it and save
         figures to ``<dirname(model)>/view_<user_key>/``.
      4. Iterate the train split when ``'train'`` occurs in ``opt['user_key']``,
         otherwise the val split.  For every sample draw the image with its 2D
         box corners and the curves projected through ``proj(mv(...))``, plus
         two 3D views (azimuth 0 and 90).  Figures are shown interactively
         when ``opt['ply']`` is truthy.

    A failure in step 1 or 2 prints the error and traceback and terminates the
    process with exit status 1.
    """

    def _abort(err):
        # Report the exception being handled and stop with a failure status;
        # a bare ``exit()`` reported success (status 0) to the calling shell.
        print(err)
        traceback.print_exc()
        raise SystemExit(1)

    def _scaled_angles(v):
        # Undo the label scaling: multiply by pi, and the second component by
        # a further 2 (presumably angles normalised to units of pi / 2*pi --
        # TODO confirm).  Works on a copy so the batch array is not mutated.
        a = np.array(v, copy=True)
        a *= np.pi
        a[1] *= 2
        return a

    def _draw_3d_view(fig, pos, azim, src, tgt, curve_gt, curve_pred):
        # One 3D panel: translucent target (blue) and source (green) boxes,
        # ground-truth (black) and predicted (red) curves, and the two groups
        # of four corner points of each box.
        ax = fig.add_subplot(pos, projection='3d')
        ax.view_init(elev=20, azim=azim)
        ax.plot_trisurf(tgt[..., 0], tgt[..., 1], tri, tgt[..., 2],
                        color=(0, 0, 1, 0.1))
        ax.plot_trisurf(src[..., 0], src[..., 1], tri, src[..., 2],
                        color=(0, 1, 0, 0.1))
        ax.plot(curve_gt[:, 0], curve_gt[:, 1], curve_gt[:, 2], color='k')
        ax.plot(curve_pred[:, 0], curve_pred[:, 1], curve_pred[:, 2],
                color='r')
        ax.scatter(src[0:4, 0], src[0:4, 1], src[0:4, 2], color='b', marker='*')
        ax.scatter(src[4:8, 0], src[4:8, 1], src[4:8, 2], color='c', marker='*')
        ax.scatter(tgt[0:4, 0], tgt[0:4, 1], tgt[0:4, 2], color='k', marker='x')
        ax.scatter(tgt[4:8, 0], tgt[4:8, 1], tgt[4:8, 2], color='r', marker='x')
        ax.set_aspect('equal', adjustable='box')

    # get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        opt = config.__dict__
        for k, v in kwargs.items():
            if v is not None:
                opt[k] = v
    except Exception as e:
        _abort(e)
    # get network
    try:
        m = importlib.import_module('net.' + opt['net'])
        net = m.Net(**opt)
        if torch.cuda.is_available():
            net = net.cuda()
    except Exception as e:
        _abort(e)
    # get dataset
    try:
        m = importlib.import_module('util.dataset.' + opt['dataset'])
        train_data = m.Data(opt, True)
        val_data = m.Data(opt, False)
        train_load = DataLoader(train_data, batch_size=opt['batch_size'],
                                shuffle=True, num_workers=opt['workers'])
        val_load = DataLoader(val_data, batch_size=opt['batch_size'],
                              shuffle=False, num_workers=opt['workers'])
    except Exception as e:
        _abort(e)

    save = opt['model'] != ''
    if save:
        partial_restore(net, opt['model'])
        print("Previous weights loaded")
        # os.path.join keeps a bare model filename relative to the cwd; the
        # former ``dirname + os.sep + ...`` produced '/view_...' at the root.
        outdir = os.path.join(os.path.dirname(opt['model']),
                              'view_' + opt['user_key'])
        os.makedirs(outdir, exist_ok=True)
    load = train_load if 'train' in opt['user_key'] else val_load

    tri = box_face
    net.eval()  # inference only; no need to reset it on every batch
    for i, data in enumerate(load, 0):
        # progress over the split actually being iterated (in batches)
        print(i, '/', len(load))
        data2cuda(data)
        with torch.no_grad():
            out = net(data)
        (img, box2d_src, box3d_src, box2d_tgt, box3d_tgt, r,
         gts) = [t.data.cpu().numpy() for t in data[:7]]
        y = out['y'].data.cpu().numpy()

        # Visit every sample; the former fixed 4-column grid silently skipped
        # the last ``batch_size % 4`` samples of each batch.
        for ni in range(box3d_src.shape[0]):
            src3d = box3d_src[ni]
            tgt3d = box3d_tgt[ni]
            c3d1 = recon(src3d, r[ni, ...], _scaled_angles(gts[ni, ...]))
            c3d2 = recon(src3d, r[ni, ...], _scaled_angles(y[ni, ...]))

            fig = plt.figure(figsize=(48, 16))
            _draw_3d_view(fig, 132, 0, src3d, tgt3d, c3d1, c3d2)
            _draw_3d_view(fig, 133, 90, src3d, tgt3d, c3d1, c3d2)
            # 2D panel: input image, box corners and projected curves
            ax = fig.add_subplot(131)
            ax.set_aspect('equal', adjustable='box')
            ax.imshow(img[ni, ...])
            ax.scatter(box2d_src[ni, 0:4, 0], box2d_src[ni, 0:4, 1],
                       color='b', marker='*')
            ax.scatter(box2d_src[ni, 4:8, 0], box2d_src[ni, 4:8, 1],
                       color='c', marker='*')
            ax.scatter(box2d_tgt[ni, 0:4, 0], box2d_tgt[ni, 0:4, 1],
                       color='k', marker='x')
            ax.scatter(box2d_tgt[ni, 4:8, 0], box2d_tgt[ni, 4:8, 1],
                       color='r', marker='x')
            ax.set_aspect('equal', adjustable='box')
            c2d1 = proj(mv(c3d1))
            c2d2 = proj(mv(c3d2))
            ax.plot(c2d1[:, 0], c2d1[:, 1], color='k')
            ax.plot(c2d2[:, 0], c2d2[:, 1], color='r')
            if save:
                plt.savefig(os.path.join(outdir, "_%04d_%04d.png" % (i, ni)))
            if opt['ply']:
                plt.show()
            plt.close(fig)
    #run the code
    '''
Esempio n. 2
0
def run(**kwargs):
    """Export per-sample images, masks, logs and PLY meshes for the test split.

    Imports ``config.<kwargs['config']>`` and uses its module ``__dict__`` as the
    option dict, overwritten by every non-None keyword argument.  Sets the
    global ``iternum`` from ``opt['nepoch']``.  Builds ``net.<opt['net']>.Net``
    and the ``util.dataset.<opt['dataset']>`` 'train'/'test' splits, and restores
    weights from ``opt['model']`` when it is non-empty.  Then, for every sample
    of the test loader, it writes into
    ``<dirname(model)>/view/_<cat>_<batch:04d>_<index:02d>/``:

    * ``_000_img.png`` -- the input image; ``_000_input.png`` -- the same image
      (``opt['mode'] == 'full'``) or the image multiplied by ``msks + mskt``
      (``opt['mode'] == 'part'``),
    * ``_000_msks.png`` / ``_000_mskt.png`` -- the two masks,
    * ``_000_log.txt`` -- ``acc['t']`` as returned by ``config.accuracy``,
    * ``_000_meta.txt`` -- the sample id,
    * ``_000_gt.ply`` / ``_000_out.ply`` -- ground-truth and predicted box
      pairs, both offset by the sample's ``data[-3]`` entry.

    A failure while loading config, network or dataset prints the error and
    traceback and terminates the process with exit status 1.
    """
    global iternum

    def _abort(err):
        # Report the exception being handled and stop with a failure status;
        # a bare ``exit()`` reported success (status 0) to the calling shell.
        print(err)
        traceback.print_exc()
        raise SystemExit(1)

    # get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        opt = config.__dict__
        for k, v in kwargs.items():
            if v is not None:
                opt[k] = v
    except Exception as e:
        _abort(e)
    iternum = opt['nepoch']
    # get network
    try:
        m = importlib.import_module('net.' + opt['net'])
        net = m.Net(**opt)
        if torch.cuda.is_available():
            net = net.cuda()
    except Exception as e:
        _abort(e)
    # get dataset
    try:
        m = importlib.import_module('util.dataset.' + opt['dataset'])
        train_data = m.Data(opt, 'train')
        val_data = m.Data(opt, 'test')
        train_load = DataLoader(train_data,
                                batch_size=opt['batch_size'],
                                shuffle=True,
                                num_workers=opt['workers'])
        val_load = DataLoader(val_data,
                              batch_size=opt['batch_size'],
                              shuffle=False,
                              num_workers=opt['workers'])
    except Exception as e:
        _abort(e)

    print('bs:', opt['batch_size'])

    if opt['model'] != '':
        partial_restore(net, opt['model'])
        print("Previous weights loaded")
        # os.path.join keeps a bare model filename relative to the cwd; the
        # former ``dirname + os.sep + 'view'`` produced '/view' at the root.
        outdir = os.path.join(os.path.dirname(opt['model']), 'view')
        os.makedirs(outdir, exist_ok=True)
    # NOTE(review): with an empty opt['model'] ``outdir`` is never bound and the
    # export loop below raises NameError on the first sample.

    # The optimizer is not used by this export routine; it is still built so
    # an invalid ``opt['optim']`` fails here as before.  ``eval`` resolves the
    # module-level ``optim`` name -- only safe with trusted config values.
    optimizer = eval('optim.' + opt['optim'])(config.parameters(net),
                                              lr=opt['lr'],
                                              weight_decay=opt['weight_decay'])

    # PLY face records: a vertex count (always 3) and three vertex indices.
    # The second half repeats every triangle offset by 8, addressing the
    # second box in the concatenated point arrays below (presumably 8
    # corners per box -- TODO confirm against ``parse``).
    fidx = box_face
    T = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                  ('i2', np.int32)])
    n_tri = fidx.shape[0]
    face = np.zeros(shape=[2 * n_tri], dtype=T)
    face['n'] = 3
    for col, field in enumerate(('i0', 'i1', 'i2')):
        face[field][:n_tri] = fidx[:, col]
        face[field][n_tri:] = fidx[:, col] + 8
    faces = pd.DataFrame(face)

    net.eval()  # inference only; no need to reset it on every batch
    for i, data in enumerate(val_load, 0):
        print(i, '/', len(val_load))
        data2cuda(data)
        img = data[0].data.cpu().numpy()
        msks = data[1].data.cpu().numpy()
        mskt = data[2].data.cpu().numpy()
        vgt = data[4].data.cpu().numpy()
        sgt = data[-3].data.cpu().numpy()
        with torch.no_grad():
            out = net(data)
        acc = config.accuracy(data, out)
        sample_ids = data[-2]
        cat = data[-1]
        tb = out['tb'].data.cpu().numpy()
        sb = out['sb'].data.cpu().numpy()
        err = acc['t'].data.cpu().numpy()

        for tagi in range(len(sample_ids)):
            cpath = os.path.join(
                outdir, '_' + cat[tagi] + '_%04d' % i + '_%02d' % tagi)
            os.makedirs(cpath, exist_ok=True)
            im = Image.fromarray((img[tagi, ...] * 255).astype(np.uint8),
                                 mode='RGB')
            im.save(os.path.join(cpath, '_000_img.png'))
            if opt['mode'] == 'full':
                im.save(os.path.join(cpath, '_000_input.png'))
            elif opt['mode'] == 'part':
                # Give the mask sum a trailing channel axis matching the
                # image's HxW so it broadcasts over RGB (was fixed at 224x224).
                mask = (msks[tagi, ...] + mskt[tagi, ...]).reshape(
                    img.shape[1:3] + (1,))
                iim = img[tagi, ...] * mask
                iim = Image.fromarray((iim * 255).astype(np.uint8), mode='RGB')
                iim.save(os.path.join(cpath, '_000_input.png'))
            mks = Image.fromarray((msks[tagi, ...] * 255).astype(np.uint8),
                                  mode='L')
            mks.save(os.path.join(cpath, '_000_msks.png'))
            mkt = Image.fromarray((mskt[tagi, ...] * 255).astype(np.uint8),
                                  mode='L')
            mkt.save(os.path.join(cpath, '_000_mskt.png'))
            # Close the text files deterministically instead of leaking the
            # handles and relying on garbage collection to flush them.
            with open(os.path.join(cpath, '_000_log.txt'), 'w') as logf:
                print('L2:%f' % (err[tagi, ...]), file=logf)
            with open(os.path.join(cpath, '_000_meta.txt'), 'w') as metaf:
                print('id:%s' % (sample_ids[tagi]), file=metaf)
            st = sgt[tagi, :][np.newaxis, :]
            ptsa, ptsb = parse(vgt[tagi, ...])
            ptgt = np.concatenate([ptsa, ptsb], axis=0)
            ptgt += st
            write_ply(os.path.join(cpath, '_000_gt.ply'),
                      points=pd.DataFrame(ptgt),
                      faces=faces)
            ptout = np.concatenate([sb[tagi, ...], tb[tagi, ...]], axis=0)
            ptout += st
            write_ply(os.path.join(cpath, '_000_out.ply'),
                      points=pd.DataFrame(ptout),
                      faces=faces)
Esempio n. 3
0
def run(**kwargs):
    """Render a five-panel prediction figure for every sample of one split.

    Imports ``config.<kwargs['config']>`` and uses its module ``__dict__`` as the
    option dict, overwritten by every non-None keyword argument.  Sets the
    global ``iternum`` from ``opt['nepoch']``.  Builds ``net.<opt['net']>.Net``
    (on CUDA when available) and the ``util.dataset.<opt['dataset']>``
    train/val splits, and optionally restores weights from ``opt['model']``.
    It then iterates the train split when ``'train'`` occurs in
    ``opt['user_key']``, otherwise the val split, printing the split's
    ``datapath[0]``.

    Panels per sample:
      1. input image with projected ground-truth (black) / predicted (red) curves
      2. ``norm_im`` of the first three channels of ``out['dmap']``
      3. ``norm_im`` of the remaining channels of ``out['dmap']``
      4. 3D boxes and curves (azimuth 90)
      5. 3D boxes and curves (default view)

    The predicted-curve line artists and the two dmap images are appended to
    the module-level list ``pv``, which must exist before the call.  Figures
    are saved to ``<dirname(model)>/view_<user_key>/`` when a model path is
    given and shown when ``opt['ply']`` is truthy.

    A failure while loading config, network or dataset prints the error and
    traceback and terminates the process with exit status 1.
    """
    global iternum
    # NOTE(review): ``pv`` keeps references to artists of figures that are
    # closed below, so memory grows with the number of rendered samples.
    global pv

    def _abort(err):
        # Report the exception being handled and stop with a failure status;
        # a bare ``exit()`` reported success (status 0) to the calling shell.
        print(err)
        traceback.print_exc()
        raise SystemExit(1)

    def _scaled_angles(v):
        # Multiply by pi, and the second component by a further 2 (presumably
        # angles normalised to units of pi / 2*pi -- TODO confirm).  Works on a
        # copy so the batch array is not mutated.
        a = np.array(v, copy=True)
        a *= np.pi
        a[1] *= 2
        return a

    def _draw_3d_view(ax, src, tgt, curve_gt, curve_pred, azim=None):
        # Translucent target (blue) and source (green) boxes plus predicted
        # (red, recorded in ``pv``) and ground-truth (black) curves.
        if azim is not None:
            ax.view_init(elev=20, azim=azim)
        ax.set_aspect('equal', adjustable='box')
        ax.plot_trisurf(tgt[..., 0], tgt[..., 1], tri, tgt[..., 2],
                        color=(0, 0, 1, 0.1))
        ax.plot_trisurf(src[..., 0], src[..., 1], tri, src[..., 2],
                        color=(0, 1, 0, 0.1))
        pv.extend(ax.plot(curve_pred[:, 0], curve_pred[:, 1],
                          curve_pred[:, 2], color='r'))
        ax.plot(curve_gt[:, 0], curve_gt[:, 1], curve_gt[:, 2], color='k')

    # get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        opt = config.__dict__
        for k, v in kwargs.items():
            if v is not None:
                opt[k] = v
    except Exception as e:
        _abort(e)
    iternum = opt['nepoch']
    # get network
    try:
        m = importlib.import_module('net.' + opt['net'])
        net = m.Net(**opt)
        if torch.cuda.is_available():
            net = net.cuda()
    except Exception as e:
        _abort(e)
    # get dataset
    try:
        m = importlib.import_module('util.dataset.' + opt['dataset'])
        train_data = m.Data(opt, True)
        val_data = m.Data(opt, False)
        train_load = DataLoader(train_data,
                                batch_size=opt['batch_size'],
                                shuffle=True,
                                num_workers=opt['workers'])
        val_load = DataLoader(val_data,
                              batch_size=opt['batch_size'],
                              shuffle=False,
                              num_workers=opt['workers'])
    except Exception as e:
        _abort(e)

    if opt['model'] != '':
        partial_restore(net, opt['model'])
        print("Previous weights loaded")
    if 'train' in opt['user_key']:
        load = train_load
        print(train_data.datapath[0])
    else:
        load = val_load
        print(val_data.datapath[0])

    save = opt['model'] != ''
    if save:
        # os.path.join keeps a bare model filename relative to the cwd; the
        # former ``dirname + os.sep + ...`` produced '/view_...' at the root.
        outdir = os.path.join(os.path.dirname(opt['model']),
                              'view_' + opt['user_key'])
        os.makedirs(outdir, exist_ok=True)

    tri = box_face
    net.eval()  # inference only; no need to reset it on every batch
    for i, data in enumerate(load, 0):
        data2cuda(data)
        with torch.no_grad():
            out = net(data)
        img = data[0].data.cpu().numpy()
        box_src = data[1].data.cpu().numpy()
        box_tgt = data[2].data.cpu().numpy()
        center = data[3].data.cpu().numpy()
        r = data[4].data.cpu().numpy()
        gt = data[5].data.cpu().numpy()
        yout = out['y'].data.cpu().numpy()
        dmap = out['dmap'].permute(0, 2, 3, 1).data.cpu().numpy()

        # Visit every sample; the former sqrt-sized grid silently skipped the
        # remainder whenever the batch size was not col * row.
        for ni in range(box_src.shape[0]):
            fig = plt.figure(figsize=(48, 16))
            c3d = recon(center[ni, ...], r[ni, ...], _scaled_angles(gt[ni, ...]))
            c3do = recon(center[ni, ...], r[ni, ...],
                         _scaled_angles(yout[ni, ...]))
            c2do = proj(mv(c3do))
            c2d = proj(mv(c3d))
            # panel 1: image with projected curves
            ax = fig.add_subplot(1, 5, 1)
            ax.imshow(img[ni, ...])
            pv.extend(ax.plot(c2do[:, 0], c2do[:, 1], color='r'))
            ax.plot(c2d[:, 0], c2d[:, 1], color='k')
            # panels 2-3: the two channel groups of the dense map
            ax = fig.add_subplot(1, 5, 2)
            pv.append(ax.imshow(norm_im(dmap[ni, :, :, :3])))
            ax = fig.add_subplot(1, 5, 3)
            pv.append(ax.imshow(norm_im(dmap[ni, :, :, 3:])))
            # panels 4-5: 3D boxes and curves
            _draw_3d_view(fig.add_subplot(1, 5, 4, projection='3d'),
                          box_src[ni], box_tgt[ni], c3d, c3do, azim=90)
            _draw_3d_view(fig.add_subplot(1, 5, 5, projection='3d'),
                          box_src[ni], box_tgt[ni], c3d, c3do)
            if save:
                plt.savefig(os.path.join(outdir, "_%04d_%04d.png" % (i, ni)))
            if opt['ply']:
                plt.show()
            plt.close(fig)

    #run the code
    '''
Esempio n. 4
0
def run(**kwargs):
    """Run ``infer_box`` on a few ``.h5`` test samples per category.

    Imports ``config.<kwargs['config']>`` and uses its module ``__dict__`` as the
    option dict, overwritten by every non-None keyword argument.  Builds
    ``net.<opt['net']>.Net`` (on CUDA when available) and the
    ``util.dataset.<opt['dataset']>`` split named by ``opt['user_key']``, and
    restores weights from ``opt['model']`` when it is non-empty.

    For every sub-directory ``c`` of ``<opt['data_path']>/test``, the first 11
    ``.h5`` files (in ``os.listdir`` order) are read.  For each one, the
    ``_input.png`` image, the ground-truth box (``_gt.ply``) and the
    ``infer_box`` result (``_out.ply``) are written to
    ``<dirname(model)>/view/_<c>/_<NNNN>/``.

    A failure while loading config, network or dataset prints the error and
    traceback and terminates the process with exit status 1.
    """

    def _abort(err):
        # Report the exception being handled and stop with a failure status;
        # a bare ``exit()`` reported success (status 0) to the calling shell.
        print(err)
        traceback.print_exc()
        raise SystemExit(1)

    # get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        opt = config.__dict__
        for k, v in kwargs.items():
            if v is not None:
                opt[k] = v
    except Exception as e:
        _abort(e)
    # get network
    try:
        m = importlib.import_module('net.' + opt['net'])
        net = m.Net(**opt)
        if torch.cuda.is_available():
            net = net.cuda()
    except Exception as e:
        _abort(e)
    # get dataset
    try:
        m = importlib.import_module('util.dataset.' + opt['dataset'])
        train_data = m.Data(opt, opt['user_key'])
        train_load = DataLoader(train_data, batch_size=opt['batch_size'],
                                shuffle=False, num_workers=opt['workers'])
    except Exception as e:
        _abort(e)

    if opt['model'] != '':
        partial_restore(net, opt['model'])
        print("Previous weights loaded")
        # os.path.join keeps a bare model filename relative to the cwd; the
        # former ``dirname + os.sep + 'view'`` produced '/view' at the root.
        outdir = os.path.join(os.path.dirname(opt['model']), 'view')
        os.makedirs(outdir, exist_ok=True)
    # NOTE(review): with an empty opt['model'] ``outdir`` is never bound and the
    # loop below raises NameError.

    root = os.path.join(opt['data_path'], 'test')
    for c in os.listdir(root):
        path = os.path.join(root, c)
        if not os.path.isdir(path):
            # Only category sub-directories hold samples; an empty '_<name>'
            # output folder used to be created for stray files as well.
            continue
        cout = os.path.join(outdir, '_' + c)
        os.makedirs(cout, exist_ok=True)
        cnt = 0
        for f in os.listdir(path):
            if not f.endswith('.h5'):
                continue
            fopath = os.path.join(cout, '_%04d' % cnt)
            # ``with`` closes the HDF5 file even when reading a dataset fails.
            with h5py.File(os.path.join(path, f), 'r') as h5f:
                img = np.array(h5f['img'])
                msk = np.array(h5f['msk'])
                touch = np.array(h5f['touch'])
                box = np.array(h5f['box'])
            os.makedirs(fopath, exist_ok=True)
            Image.fromarray((img * 255.0).astype(np.uint8), mode='RGB').save(
                os.path.join(fopath, '_input.png'))
            write_box(box, os.path.join(fopath, '_gt.ply'))
            obox = infer_box(net, img, msk, touch, box)
            write_box(obox, os.path.join(fopath, '_out.ply'))
            cnt += 1
            if cnt > 10:  # at most 11 samples (indices 0..10) per category
                break
Esempio n. 5
0
def run(**kwargs):
    """Visualise single-dimension latent traversals of ``net.decode``.

    Imports ``config.<kwargs['config']>`` and uses its module ``__dict__`` as the
    option dict, overwritten by every non-None keyword argument.  Builds
    ``net.<opt['net']>.Net`` (on CUDA when available) and the
    ``util.dataset.<opt['dataset']>`` training split (unshuffled), and
    optionally restores weights from ``opt['model']``.

    For every batch, sample ``bi = int(opt['user_key'])`` is taken:
      * its full vector ``data[0][bi, :]`` is parsed into two boxes and saved as
        ``_<batch>_<cat>_input.png``;
      * for every latent dimension ``zi`` and each of 20 values evenly spaced
        in [-3, 3], a one-hot latent with that value is decoded together with
        ``data[0][bi, opt['part_idx']]``.  The resulting boxes are saved as
        ``_<batch>_<cat>_output/_<zi>_<value>.png``.

    Images go under ``<dirname(model)>/view_<bi>/``.  A failure while loading
    config, network or dataset prints the error and traceback and terminates
    the process with exit status 1.
    """

    def _abort(err):
        # Report the exception being handled and stop with a failure status;
        # a bare ``exit()`` reported success (status 0) to the calling shell.
        print(err)
        traceback.print_exc()
        raise SystemExit(1)

    def _draw_pair(fig, pos, ptsa, ptsb, azim=None):
        # Add 3D subplot (1, 2, pos) with both boxes inside [-1, 1]^3.  The
        # data's y and z columns are swapped when plotting, so data-y is drawn
        # on matplotlib's vertical z axis.
        ax = fig.add_subplot(1, 2, pos, projection='3d')
        if azim is not None:
            ax.view_init(elev=20, azim=azim)
        ax.set_aspect('equal', adjustable='box')
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])
        ax.plot_trisurf(ptsa[..., 0], ptsa[..., 2], tri, ptsa[..., 1],
                        color=(0, 0, 1, 0.1))
        ax.plot_trisurf(ptsb[..., 0], ptsb[..., 2], tri, ptsb[..., 1],
                        color=(0, 1, 0, 0.1))

    def _render(ptsa, ptsb, path):
        # Two views of the same box pair (azimuth 90 and default), saved to path.
        fig = plt.figure(figsize=(9.6, 4.8))
        _draw_pair(fig, 1, ptsa, ptsb, azim=90)
        _draw_pair(fig, 2, ptsa, ptsb)
        plt.savefig(path)
        plt.close(fig)

    # get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        opt = config.__dict__
        for k, v in kwargs.items():
            if v is not None:
                opt[k] = v
    except Exception as e:
        _abort(e)
    # get network
    try:
        m = importlib.import_module('net.' + opt['net'])
        net = m.Net(**opt)
        if torch.cuda.is_available():
            net = net.cuda()
    except Exception as e:
        _abort(e)
    # get dataset
    try:
        m = importlib.import_module('util.dataset.' + opt['dataset'])
        train_data = m.Data(opt, True)
        train_load = DataLoader(train_data,
                                batch_size=opt['batch_size'],
                                shuffle=False,
                                num_workers=opt['workers'])
    except Exception as e:
        _abort(e)

    if opt['model'] != '':
        partial_restore(net, opt['model'])
        print("Previous weights loaded")

    bi = int(opt['user_key'])
    if opt['model'] != '':
        # os.path.join keeps a bare model filename relative to the cwd; the
        # former ``dirname + os.sep + ...`` produced '/view_<bi>' at the root.
        outdir = os.path.join(os.path.dirname(opt['model']), 'view_%d' % bi)
        os.makedirs(outdir, exist_ok=True)
    # NOTE(review): with an empty opt['model'] ``outdir`` is never bound and the
    # loop below raises NameError.

    idx = opt['part_idx']
    tri = box_face
    zs = np.linspace(-3, 3, 20).astype(np.float32)  # traversal values

    net.eval()  # inference only; no need to reset it on every batch
    for i, data in enumerate(train_load, 0):
        print(i, '/',
              len(train_data) // opt['batch_size'])
        data2cuda(data)
        xx = data[0][bi, idx].contiguous().view(1, -1)
        ox = data[0][bi, :].contiguous().view(1, -1)
        cat = data[1][bi]
        print(cat)

        ptsa, ptsb = parse(ox.data.cpu().numpy()[0, :])
        _render(ptsa, ptsb,
                os.path.join(outdir, '_%d_%s_input.png' % (i, cat)))

        outdiri = os.path.join(outdir, '_%d_%s_output' % (i, cat))
        os.makedirs(outdiri, exist_ok=True)
        for zi in range(opt['z_size']):
            for zval in zs:
                z = np.zeros([1, opt['z_size']], dtype=np.float32)
                z[0, zi] = zval
                # Follow the input's device; an unconditional ``.cuda()``
                # broke CPU-only runs although the net itself supports them.
                z = torch.from_numpy(z).to(xx.device)
                with torch.no_grad():
                    r = net.decode(xx, z)
                ptsa, ptsb = parse(r.data.cpu().numpy()[0, :])
                _render(ptsa, ptsb,
                        os.path.join(outdiri, '_%d_%f.png' % (zi, zval)))
Esempio n. 6
0
def run(**kwargs):
    global iternum
    #get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        opt = config.__dict__
        for k in kwargs.keys():
            if not kwargs[k] is None:
                opt[k] = kwargs[k]
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()
    #get network
    try:
        m = importlib.import_module('net.' + opt['touch_net'])
        touchnet = m.Net(**opt)
        #
        m = importlib.import_module('net.' + opt['box_net'])
        boxnet = m.Net(**opt)
        #
        m = importlib.import_module('net.' + opt['touchpt_net'])
        touchptnet = m.Net(**opt)
        #
        if torch.cuda.is_available():
            touchnet = touchnet.cuda()
            touchnet.eval()
            boxnet = boxnet.cuda()
            boxnet.eval()
            touchptnet = touchptnet.cuda()
            touchptnet.eval()
        #
        partial_restore(touchnet, opt['touch_model'])
        partial_restore(boxnet, opt['box_model'])
        partial_restore(touchptnet, opt['touchpt_model'])
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()
    #get dataset
    dpath = os.path.join(opt['data_path'], 'test')
    opath = './log/joinN'
    if not os.path.exists(opath):
        os.makedirs(opath)
    cat_lst = os.listdir(dpath)
    #cat_lst = ['Chair','Table'];
    eval_log = open('./log/joinN/' + opt['mode'] + '_eval.csv', 'w')
    eval_all_sum = 0.0
    eval_all_cnt = 0.0
    for cat in cat_lst:
        eval_cat_sum = 0.0
        eval_cat_cnt = 0.0
        cpath = os.path.join(opt['data_path'], 'test', cat)
        copath = os.path.join(opath, 'joinN_' + cat + '_' + opt['mode'])
        if not os.path.exists(copath):
            os.mkdir(copath)
        slst = os.listdir(cpath)
        for f in slst:
            id = os.path.basename(f).split('.')[-2]
            fopath = os.path.join(copath, '_' + id)
            if not os.path.exists(fopath):
                os.mkdir(fopath)
            logf = open(os.path.join(fopath, '_000_000_log.txt'), 'w')
            h5f = h5py.File(os.path.join(cpath, f), 'r')
            img = np.array(h5f['img'])
            msk = np.array(h5f['msk'])
            smsk = np.array(h5f['smsk'])
            box = np.array(h5f['box'])
            num = box.shape[0]
            bdata = []
            msk_lst = []
            imgs_lst = []
            imgt_lst = []
            smsk_lst = []
            tmsk_lst = []
            box_lst = []
            #for each part
            rate = opt['user_rate']
            for i in range(num):
                msk_rate = (np.sum(msk[i, ...]) / np.sum(smsk[i, ...]))
                if msk_rate > rate:
                    msk_lst.append(msk[i, ...])
                    imgs, mski = msk_center(img, msk[i, ...])
                    imgs_lst.append(imgs)
                    smsk_lst.append(mski)
                    imgt_lst.append(imgs)
                    tmsk_lst.append(mski)
                    box_lst.append(box[i, ...])
                    im = Image.fromarray((imgs_lst[-1] * 255).astype(np.uint8),
                                         'RGB')
                    im.save(
                        os.path.join(fopath,
                                     '_000_%03d_img.png' % (len(imgs_lst))))
                    imt = Image.fromarray(
                        (smsk_lst[-1] * 255).astype(np.uint8), 'L')
                    imt.save(
                        os.path.join(fopath,
                                     '_001_%03d_msk.png' % (len(smsk_lst))))
            #
            imgs = np.stack(imgs_lst, axis=0)
            smsk = np.stack(smsk_lst, axis=0)
            imgt = np.stack(imgt_lst, axis=0)
            tmsk = np.stack(tmsk_lst, axis=0)
            bdata.append(torch.from_numpy(imgs).cuda())
            bdata.append(torch.from_numpy(smsk).cuda())
            bdata.append(torch.from_numpy(imgt).cuda())
            bdata.append(torch.from_numpy(tmsk).cuda())
            with torch.no_grad():
                boxout = boxnet(bdata)
            #
            size = np.sum(np.sum(bdata[1].data.cpu().numpy() > 0, axis=1),
                          axis=1)
            print(size)
            undone_queue = []
            for idx, v in enumerate(size):
                heapq.heappush(undone_queue, (-v, idx))
            done_queue = []
            msk_in = []
            box_out = []
            box_color = []
            box_gt = []
            s_gt = []
            s_out = []
            t_gt = []
            t_out = []
            r1_gt = []
            r1_out = []
            r2_gt = []
            r2_out = []
            baset = np.zeros([len(box_lst), 3], dtype=np.float32)
            bases = np.zeros([len(box_lst), 1], dtype=np.float32)
            while len(undone_queue) > 0:
                if len(done_queue) > 0:
                    itop = heapq.heappop(done_queue)
                    ci = itop[1]
                    bt = baset[ci, :][np.newaxis, :]
                    print('[done:%d' % ci, file=logf, end='')
                else:
                    itop = heapq.heappop(undone_queue)
                    ci = itop[1]
                    bo = boxout['sb'].data.cpu().numpy()[ci, ...]
                    bgt, rs, t = parsegt(box_lst[ci])
                    add_gt_for_eval(box_lst[ci], s_gt, t_gt, r1_gt, r2_gt)
                    add_out_for_eval(
                        boxout['ss'].data[ci, ...] * rs,
                        torch.from_numpy(t.astype(np.float32)).cuda(),
                        boxout['sr1'].data[ci, ...], boxout['sr2'].data[ci,
                                                                        ...],
                        s_out, t_out, r1_out, r2_out)
                    baset[ci, :] = t[0, :]
                    bases[ci, :] = rs
                    bt = t
                    box_out.append(bo * rs + bt)
                    box_gt.append(bgt)
                    box_color.append(red_box)
                    msk_in.append(msk_lst[ci])
                    print('[undone:%d' % ci, file=logf, end='')
                unfinish = []
                while len(undone_queue) > 0:
                    tdata = []
                    tptdata = []
                    jtop = heapq.heappop(undone_queue)
                    cj = jtop[1]
                    pimg, psmsk, ptmsk = msk_pair_center(
                        img, msk_lst[ci], msk_lst[cj])
                    pimg = torch.from_numpy(pimg).cuda().unsqueeze(0)
                    psmsk = torch.from_numpy(psmsk).cuda().unsqueeze(0)
                    ptmsk = torch.from_numpy(ptmsk).cuda().unsqueeze(0)
                    tdata.append(pimg)
                    tdata.append(psmsk)
                    tdata.append(ptmsk)
                    with torch.no_grad():
                        touchout = touchnet(tdata)
                    if touchout['y'].data.cpu().numpy()[0][0] > 0.5:
                        im = Image.fromarray((pimg.data.cpu().numpy()[0, ...] *
                                              255).astype(np.uint8))
                        im.save(
                            os.path.join(fopath,
                                         '_003_%03d_img.png' % (len(box_out))))
                        allmsk = (psmsk + ptmsk).data.cpu().numpy()[0, ...]
                        imt = Image.fromarray((allmsk * 255).astype(np.uint8))
                        imt.save(
                            os.path.join(fopath,
                                         '_004_%03d_msk.png' % (len(box_out))))
                        bgt, rsgt, _ = parsegt(box_lst[cj])
                        tptdata.append(pimg)
                        tptdata.append(psmsk)
                        tptdata.append(ptmsk)
                        tptdata.append(None)
                        vec = np.zeros([1, 21], dtype=np.float32)
                        rs = bases[ci, :]
                        vec[0, :3] = rs * boxout['ss'].data.cpu().numpy()[ci,
                                                                          ...]
                        vec[0, 3:6] = boxout['sr1'].data.cpu().numpy()[ci, ...]
                        vec[0, 6:9] = boxout['sr2'].data.cpu().numpy()[ci, ...]
                        vec[0,
                            9:12] = rs * boxout['ss'].data.cpu().numpy()[cj,
                                                                         ...]
                        vec[0, 15:18] = boxout['sr1'].data.cpu().numpy()[cj,
                                                                         ...]
                        vec[0, 18:21] = boxout['sr2'].data.cpu().numpy()[cj,
                                                                         ...]
                        tptdata.append(torch.from_numpy(vec).cuda())
                        with torch.no_grad():
                            tptout = touchptnet(tptdata)
                        heapq.heappush(done_queue, jtop)
                        t = tptout['t'].data.cpu().numpy()
                        bo = tptout['tb'].data.cpu().numpy()[0, ...]
                        box_out.append(bo + bt)
                        baset[cj, :] = (t + bt)[0, :]
                        sout = boxout['ss'].data.cpu().numpy()[
                            cj, ...] * tptout['ts'].data.cpu().numpy()[0, ...]
                        bases[cj, :] = np.max(np.abs(sout))
                        box_gt.append(bgt)
                        box_color.append(blue_box)
                        msk_in.append(bdata[1].data.cpu().numpy()[cj, ...])
                        add_gt_for_eval(box_lst[cj], s_gt, t_gt, r1_gt, r2_gt)
                        add_out_for_eval(
                            boxout['ss'].data[cj, ...] *
                            tptout['ts'].data[0, ...],
                            torch.from_numpy(
                                (t + bt).astype(np.float32)).cuda(),
                            boxout['sr1'].data[cj,
                                               ...], boxout['sr2'].data[cj,
                                                                        ...],
                            s_out, t_out, r1_out, r2_out)
                        print('->%d' % cj, file=logf, end='')
                    else:
                        unfinish.append(jtop)
                for p in unfinish:
                    heapq.heappush(undone_queue, p)
                print(']', file=logf)
            im = Image.fromarray((img * 255).astype(np.uint8))
            im.save(os.path.join(fopath, '_001_000_im.png'))
            writegt(fopath, box_gt)
            writeout(fopath, box_out, box_color, msk_in)
            val = eval(s_out, t_out, r1_out, r2_out, s_gt, t_gt, r1_gt, r2_gt)
            v = float(val.data.cpu().numpy())
            json.dump({'bcd': v}, open(os.path.join(fopath, 'meta.json'), 'w'))
            if cat != 'TrainChair':
                eval_all_sum += val.data.cpu().numpy()
                eval_all_cnt += 1.0
                eval_cat_sum += val.data.cpu().numpy()
                eval_cat_cnt += 1.0
        if cat != 'TrainChair':
            print(cat + ',%f' % (eval_cat_sum / eval_cat_cnt), file=eval_log)
            eval_log.flush()
    print('inst mean,%f' % (eval_all_sum / eval_all_cnt), file=eval_log)
    eval_log.close()
# Example no. 7
# 0
def run(**kwargs):
    """Predict a box for every mask image found in ``opt['data_path']``.

    Loads the configuration module ``config.<kwargs['config']>`` and writes
    every non-None keyword argument over its entries. It then builds the
    networks named by ``opt['touch_net']``, ``opt['box_net']`` and
    ``opt['touchpt_net']``, puts them in eval mode, and restores their weights
    from ``opt['touch_model']``, ``opt['box_model']`` and
    ``opt['touchpt_model']`` via ``partial_restore``.

    For every ``*_im.png`` file in ``opt['data_path']`` it collects the
    ``*_msk.png`` files whose id (the text before the first ``_``) equals the
    image's id. It runs ``boxnet`` on the stacked (image, mask, mask) batch and
    writes ``boxout['sb'][i]`` for each mask with ``writebox`` to the mask's
    filename with ``_msk.png`` replaced by ``_box.ply``.

    Exits the process with status 1 if configuration or network setup fails.
    """
    import sys  # local import: only needed for the failure exit path

    global iternum  # NOTE(review): not used in this body; kept for compatibility
    # get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        # The config module's namespace doubles as the option dict, so the
        # overrides below are written straight into the module's globals.
        opt = config.__dict__
        for k, v in kwargs.items():
            if v is not None:
                opt[k] = v
    except Exception as e:
        print(e)
        traceback.print_exc()
        # Non-zero status so callers/shells can detect the failure
        # (a bare exit() reports success).
        sys.exit(1)
    # get network
    try:
        m = importlib.import_module('net.' + opt['touch_net'])
        touchnet = m.Net(**opt)
        m = importlib.import_module('net.' + opt['box_net'])
        boxnet = m.Net(**opt)
        m = importlib.import_module('net.' + opt['touchpt_net'])
        touchptnet = m.Net(**opt)
        if torch.cuda.is_available():
            touchnet = touchnet.cuda()
            boxnet = boxnet.cuda()
            touchptnet = touchptnet.cuda()
        # Inference only: switch to eval mode on every device, not just when
        # CUDA is available.
        touchnet.eval()
        boxnet.eval()
        touchptnet.eval()
        partial_restore(touchnet, opt['touch_model'])
        partial_restore(boxnet, opt['box_model'])
        partial_restore(touchptnet, opt['touchpt_model'])
    except Exception as e:
        print(e)
        traceback.print_exc()
        sys.exit(1)
    # Input tensors must live on the same device as the networks.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # get dataset
    dpath = opt['data_path']
    im_lst = os.listdir(dpath)
    for im in im_lst:
        if im.endswith('_im.png'):
            with Image.open(os.path.join(dpath, im)) as pil_im:
                image = np.array(pil_im).astype(np.float32) / 255.0
            # Compare whole ids rather than testing `id in name`, which also
            # matched other images' masks (e.g. id '1' inside '11_..._msk.png').
            im_id = im.split('_')[0]
            bdata = []
            img_lst = []
            nm_lst = []
            smsk_lst = []
            tmsk_lst = []
            for msk in im_lst:
                if msk.endswith('_msk.png') and msk.split('_')[0] == im_id:
                    nm_lst.append(msk)
                    img_lst.append(image)
                    with Image.open(os.path.join(dpath, msk)) as pil_msk:
                        mskimg = np.array(pil_msk).astype(np.float32) / 255.0
                    # smsk and tmsk are both filled with the same mask.
                    smsk_lst.append(mskimg)
                    tmsk_lst.append(mskimg)
                    print(mskimg.shape)
            if not nm_lst:
                # np.stack raises ValueError on an empty list; nothing to do.
                print('no masks found for %s' % im)
                continue
            img = np.stack(img_lst, axis=0)
            smsk = np.stack(smsk_lst, axis=0)
            tmsk = np.stack(tmsk_lst, axis=0)
            bdata.append(torch.from_numpy(img).to(device))
            bdata.append(torch.from_numpy(smsk).to(device))
            bdata.append(torch.from_numpy(tmsk).to(device))
            with torch.no_grad():
                boxout = boxnet(bdata)
            bx = boxout['sb'].data.cpu().numpy()
            for i, nm in enumerate(nm_lst):
                writebox(os.path.join(dpath, nm.replace('_msk.png', '_box.ply')),
                         [bx[i, ...]])