Example #1
0
def run(**kwargs):
    """Apply one random rotation to the canonical box, plot it and save it as a PLY mesh.

    ``kwargs`` is accepted for a uniform entry-point signature and is not read.
    A random 4-vector is normalised to unit length, repeated for the 8 box
    vertices and applied to them with ``qrot``.  The rotated vertices are shown
    in a 3D scatter plot; after ``plt.show()`` returns they are written, together
    with the triangles of ``box_face``, to ``./a.ply`` with ``as_text=True``.
    """
    from .g.box import bv
    from .g.box import box_face
    from matplotlib import pyplot as plt
    from matplotlib import animation
    # presumably kept so the '3d' projection is registered on older matplotlib
    from mpl_toolkits.mplot3d import axes3d as p3
    from util.data.ply import write_ply
    import pandas as pd
    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    axes.axis('equal')
    verts = torch.from_numpy(bv)
    # One random quaternion, scaled to unit norm and shared by all 8 vertices.
    quat = torch.randn([1, 4]).repeat(8, 1)
    quat = quat / quat.pow(2).sum(dim=1, keepdim=True).sqrt()
    verts = qrot(quat, verts)
    axes.scatter(verts[..., 0], verts[..., 1], verts[..., 2])
    plt.show()
    face_dtype = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                           ('i2', np.int32)])
    face = np.zeros(shape=[box_face.shape[0]], dtype=face_dtype)
    # Each PLY face record is (vertex count, three vertex indices).
    face['n'] = 3
    for col, field in enumerate(('i0', 'i1', 'i2')):
        face[field] = box_face[:, col]
    write_ply('./a.ply',
              points=pd.DataFrame(verts.cpu().numpy()),
              faces=pd.DataFrame(face),
              as_text=True)
    plt.show()
Example #2
0
File: bcd.py Project: samhu1989/PON
def write_box(prefix, box):
    """Save every box in ``box`` as a separate PLY file named ``<prefix>_<i>.ply``.

    Args:
        prefix: path prefix; box ``i`` is written to ``prefix + '_%d.ply' % i``.
        box: array-like holding ``box.shape[0]`` boxes; ``box[i, ...]`` becomes
            the vertex DataFrame of file ``i`` (presumably 8 x 3 corners that
            match ``box_face`` -- TODO confirm).

    Every file shares the same triangle list, built from the module-level
    ``box_face``.
    """
    record = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                       ('i2', np.int32)])
    tris = box_face
    face = np.zeros(shape=[tris.shape[0]], dtype=record)
    # PLY face record: vertex count followed by the three vertex indices.
    face['n'] = 3
    face['i0'] = tris[:, 0]
    face['i1'] = tris[:, 1]
    face['i2'] = tris[:, 2]
    for k in range(box.shape[0]):
        out_path = prefix + '_%d.ply' % k
        write_ply(out_path,
                  points=pd.DataFrame(box[k, ...]),
                  faces=pd.DataFrame(face))
Example #3
0
def write_box(box, path):
    """Write all oriented bounding boxes in ``box`` into a single PLY mesh.

    Args:
        box: array whose first axis indexes boxes; each ``box[i, ...]`` is a
            box parameter vector turned into corner points by ``OBB.v2points``.
        path: output PLY file path.

    Box ``i`` contributes the module-level triangle template ``bf`` with its
    vertex indices shifted by ``8 * i``.  This assumes ``OBB.v2points`` returns
    8 corners per box, matching that offset (TODO confirm).

    Raises:
        ValueError: if ``box`` contains no boxes.
    """
    if box.shape[0] == 0:
        raise ValueError('write_box: box holds no boxes; nothing to write')
    obbp = []
    obbf = []
    for i in range(box.shape[0]):
        obbp.append(OBB.v2points(box[i, ...]))
        # Shift the template triangles so they index this box's corners.
        obbf.append(bf + i * 8)
    obbv = np.concatenate(obbp, axis=0)
    fidx = np.concatenate(obbf, axis=0)
    T = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                  ('i2', np.int32)])
    # Size the table by the real triangle count.  The old code assumed 12
    # triangles per box, which left all-zero faces or overflowed whenever
    # ``bf`` did not have exactly 12 rows.
    face = np.zeros(shape=[fidx.shape[0]], dtype=T)
    face['n'] = 3
    face['i0'] = fidx[:, 0]
    face['i1'] = fidx[:, 1]
    face['i2'] = fidx[:, 2]
    write_ply(path,
              points=pd.DataFrame(obbv.astype(np.float32)),
              faces=pd.DataFrame(face))
Example #4
0
def writebox(path, box, colors=None):
    """Write a list of boxes into one PLY mesh, optionally with per-vertex colours.

    Args:
        path: output PLY path.
        box: sequence of per-box vertex arrays, concatenated along axis 0.
            The faces of box ``k`` are shifted by ``8 * k``, so each entry must
            hold exactly 8 vertices in the order ``box_face`` indexes them.
        colors: optional sequence parallel to ``box``.  It is concatenated
            along axis 0 and appended as extra columns after the vertex
            columns, and ``color=True`` is passed to ``write_ply``.
    """
    fidx = box_face
    T = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                  ('i2', np.int32)])
    bn = len(box)
    nf = fidx.shape[0]
    # Copy the template triangles once per box; copy k is shifted by 8 * k so
    # it indexes box k's vertices in the concatenated point list.  Doing this
    # on whole arrays replaces the per-face Python loop and computes in int64,
    # so large box counts cannot overflow a narrow index dtype.
    offsets = np.repeat(np.arange(bn) * 8, nf)[:, np.newaxis]
    tris = np.tile(fidx, (bn, 1)) + offsets
    face = np.zeros(shape=[bn * nf], dtype=T)
    face['n'] = 3
    face['i0'] = tris[:, 0]
    face['i1'] = tris[:, 1]
    face['i2'] = tris[:, 2]
    points = pd.DataFrame(np.concatenate(box, axis=0).astype(np.float32))
    if colors is None:
        write_ply(path, points=points, faces=pd.DataFrame(face))
    else:
        colors = np.concatenate(colors, axis=0)
        pointsc = pd.concat([points, pd.DataFrame(colors)],
                            axis=1,
                            ignore_index=True)
        write_ply(path, points=pointsc, faces=pd.DataFrame(face), color=True)
Example #5
0
File: box.py Project: samhu1989/PON
def run(**kwargs):
    """Plot the canonical box and write it, with its triangles, to ``./a.ply``.

    ``kwargs`` is accepted for a uniform entry-point signature and is not read.
    The 8 corners from ``bv`` are expanded into a batch of two identical copies.
    The first copy is scattered in a 3D plot and written with the triangles of
    the module-level ``box_face``; the plot is then shown.
    """
    from .box import bv
    from matplotlib import pyplot as plt
    from matplotlib import animation
    from mpl_toolkits.mplot3d import axes3d as p3
    import pandas as pd
    from util.data.ply import write_ply
    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    axes.axis('equal')
    # Batch of two identical copies of the 8 box corners: shape (2, 8, 3).
    boxes = torch.from_numpy(bv).view(1, 8, 3).expand(2, 8, 3).contiguous()
    # Drawn but never used; kept so the global RNG stream is unchanged.
    _unused = torch.randn([2, 9])
    first = boxes[0, ...]
    axes.scatter(first[..., 0], first[..., 1], first[..., 2])
    face_dtype = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                           ('i2', np.int32)])
    face = np.zeros(shape=[box_face.shape[0]], dtype=face_dtype)
    face['n'] = 3
    face['i0'], face['i1'], face['i2'] = (box_face[:, 0], box_face[:, 1],
                                          box_face[:, 2])
    write_ply('./a.ply',
              points=pd.DataFrame(first.cpu().numpy()),
              faces=pd.DataFrame(face))
    plt.show()
Example #6
0
def run(**kwargs):
    """Evaluate a trained network on the test split and dump per-sample visualisations.

    ``kwargs['config']`` names a module under ``config.``.  That module's
    namespace is used as the option dict, and every other keyword argument
    that is not ``None`` overrides the matching option.  The options select:
        * the network, ``net.<opt['net']>``;
        * the dataset, ``util.dataset.<opt['dataset']>``;
        * the checkpoint ``opt['model']``, loaded with ``partial_restore``.
    Any failure while loading config, network or dataset is printed with a
    traceback and terminates the process.

    Each test sample gets a folder under ``<dirname(opt['model'])>/view``
    containing:
        * the image, and the network input (the full image, or the image masked
          by both part masks, depending on ``opt['mode']``);
        * both masks;
        * the error ``acc['t']`` in ``_000_log.txt`` and the sample id in
          ``_000_meta.txt``;
        * two two-box PLY meshes, ground truth (``_000_gt.ply``) and prediction
          (``_000_out.ply``), both offset by ``data[-3]``.

    Raises:
        ValueError: if ``opt['model']`` is empty.  Outputs live next to the
            checkpoint, so without one there is no output directory (the old
            code crashed later with ``NameError`` on ``outdir``).
    """
    global iternum
    #get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        opt = config.__dict__
        for k, v in kwargs.items():
            if v is not None:
                opt[k] = v
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()
    iternum = opt['nepoch']
    #get network
    try:
        m = importlib.import_module('net.' + opt['net'])
        net = m.Net(**opt)
        if torch.cuda.is_available():
            net = net.cuda()
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()
    #get dataset
    try:
        m = importlib.import_module('util.dataset.' + opt['dataset'])
        train_data = m.Data(opt, 'train')
        val_data = m.Data(opt, 'test')
        train_load = DataLoader(train_data,
                                batch_size=opt['batch_size'],
                                shuffle=True,
                                num_workers=opt['workers'])
        val_load = DataLoader(val_data,
                              batch_size=opt['batch_size'],
                              shuffle=False,
                              num_workers=opt['workers'])
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()

    print('bs:', opt['batch_size'])

    if opt['model'] == '':
        raise ValueError("opt['model'] must name a checkpoint: results are "
                         "written to <checkpoint dir>/view")
    partial_restore(net, opt['model'])
    print("Previous weights loaded")
    outdir = os.path.dirname(opt['model']) + os.sep + 'view'
    os.makedirs(outdir, exist_ok=True)

    # The network only runs forward here, so no optimizer is built (the old
    # code created one through eval() and never used it).
    fidx = box_face
    T = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                  ('i2', np.int32)])
    # Each mesh holds two boxes (8 vertices each, one after the other), so the
    # second copy of the box triangles is offset by 8.
    tris = np.concatenate([fidx, fidx + 8], axis=0)
    face = np.zeros(shape=[tris.shape[0]], dtype=T)
    face['n'] = 3
    face['i0'] = tris[:, 0]
    face['i1'] = tris[:, 1]
    face['i2'] = tris[:, 2]

    net.eval()
    for i, data in enumerate(val_load, 0):
        # Progress is batch index over number of batches.
        print(i, '/', len(val_load))
        data2cuda(data)
        img = data[0].data.cpu().numpy()
        msks = data[1].data.cpu().numpy()
        mskt = data[2].data.cpu().numpy()
        vgt = data[4].data.cpu().numpy()
        sgt = data[-3].data.cpu().numpy()
        with torch.no_grad():
            out = net(data)
        acc = config.accuracy(data, out)
        ids = data[-2]
        tb = out['tb'].data.cpu().numpy()
        sb = out['sb'].data.cpu().numpy()
        cat = data[-1]
        err = acc['t'].data.cpu().numpy()
        for tagi, tag in enumerate(ids):
            cpath = os.path.join(
                outdir, '_' + cat[tagi] + '_%04d' % i + '_%02d' % tagi)
            os.makedirs(cpath, exist_ok=True)
            im = Image.fromarray((img[tagi, ...] * 255).astype(np.uint8),
                                 mode='RGB')
            im.save(os.path.join(cpath, '_000_img.png'))
            if opt['mode'] == 'full':
                im.save(os.path.join(cpath, '_000_input.png'))
            elif opt['mode'] == 'part':
                # Keep only pixels inside either part mask.  The new trailing
                # axis broadcasts the (H, W) mask over the RGB channels for any
                # image size (previously hard-coded to 224 x 224).
                union = (msks[tagi, ...] + mskt[tagi, ...])[..., np.newaxis]
                iim = img[tagi, ...] * union
                iim = Image.fromarray((iim * 255).astype(np.uint8), mode='RGB')
                iim.save(os.path.join(cpath, '_000_input.png'))
            mks = Image.fromarray((msks[tagi, ...] * 255).astype(np.uint8),
                                  mode='L')
            mks.save(os.path.join(cpath, '_000_msks.png'))
            mkt = Image.fromarray((mskt[tagi, ...] * 255).astype(np.uint8),
                                  mode='L')
            mkt.save(os.path.join(cpath, '_000_mskt.png'))
            ptsa, ptsb = parse(vgt[tagi, ...])
            # Close the text files deterministically instead of leaking handles.
            with open(os.path.join(cpath, '_000_log.txt'), 'w') as logf:
                print('L2:%f' % (err[tagi, ...]), file=logf)
            with open(os.path.join(cpath, '_000_meta.txt'), 'w') as metaf:
                print('id:%s' % (tag,), file=metaf)
            # Per-sample offset, broadcast over every vertex row.
            st = sgt[tagi, :][np.newaxis, :]
            ptgt = np.concatenate([ptsa, ptsb], axis=0)
            ptgt += st
            write_ply(os.path.join(cpath, '_000_gt.ply'),
                      points=pd.DataFrame(ptgt),
                      faces=pd.DataFrame(face))
            ptout = np.concatenate([sb[tagi, ...], tb[tagi, ...]], axis=0)
            ptout += st
            write_ply(os.path.join(cpath, '_000_out.ply'),
                      points=pd.DataFrame(ptout),
                      faces=pd.DataFrame(face))
Example #7
0
def run(**kwargs):
    """Dump the outputs of ``net.interp`` between pairs of image batches.

    ``kwargs['config']`` names a module under ``config.`` whose namespace is
    the option dict; other non-``None`` keyword arguments override its entries.
    Failures while loading config, network or dataset are printed with a
    traceback and terminate the process.  If ``opt['model']`` is non-empty,
    its weights are loaded with ``partial_restore``.

    A sphere point set from ``randsphere2(opt['pts_num'])`` is triangulated
    once, and its faces are reused for every output mesh.  This requires CUDA,
    because the grid is moved with ``.cuda()`` unconditionally.

    Phase 1: for the first full validation batch of each category
    (``data[3][0]``):
        * the images and ground truth of all ``2 * batch_size`` samples go to
          ``./log/interp/<cat>``;
        * for each weight in ``ws``, ``net.interp`` is called with the first
          half of the batch, the second half, the weight and the grid.

    Phase 2: for every unordered pair of collected categories ``(ci, cj)``,
    the first ``batch_size`` samples of each batch are saved and passed to
    ``net.interp``.  Results go to ``./log/interp/<ci>_<cj>``.
    """
    #get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        opt = config.__dict__
        for k, v in kwargs.items():
            if v is not None:
                opt[k] = v
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()
    #get network
    try:
        m = importlib.import_module('net.' + opt['net'])
        net = m.Net(**opt)
        if torch.cuda.is_available():
            net = net.cuda()
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()
    #get dataset
    try:
        m = importlib.import_module('util.dataset.' + opt['dataset'])
        train_data = m.Data(opt, True)
        val_data = m.Data(opt, False)
        train_load = DataLoader(train_data,
                                batch_size=opt['batch_size'],
                                shuffle=True,
                                num_workers=opt['workers'])
        val_load = DataLoader(val_data,
                              batch_size=opt['batch_size'] * 2,
                              shuffle=False,
                              num_workers=opt['workers'])
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()

    #load pre-trained
    if opt['model'] != '':
        partial_restore(net, opt['model'])
        print("Previous weights loaded")

    bs = opt['batch_size']
    grid = randsphere2(opt['pts_num'])
    grid = torch.from_numpy(
        grid.reshape(1, grid.shape[0],
                     grid.shape[1]).astype(np.float32)).cuda()
    grid = grid.repeat(bs, 1, 1)
    hull = triangulateSphere(grid.cpu().data.numpy())
    fidx = hull[0].simplices
    T = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                  ('i2', np.int32)])
    face = np.zeros(shape=[fidx.shape[0]], dtype=T)
    face['n'] = 3
    face['i0'] = fidx[:, 0]
    face['i1'] = fidx[:, 1]
    face['i2'] = fidx[:, 2]

    ws = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0]
    done_cat = {}
    for i, data in enumerate(val_load, 0):
        # Each batch is split into two halves of ``bs`` samples.  A short final
        # batch would index past its end, so skip it.
        if data[0].shape[0] < 2 * bs:
            continue
        if data[3][0] in done_cat:
            continue
        done_cat[data[3][0]] = data
        img = data[0].data.cpu().numpy().transpose((0, 2, 3, 1))
        ygt = data[1].data.cpu().numpy()
        for j in range(2 * bs):
            opath = './log/interp/' + data[3][j]
            os.makedirs(opath, exist_ok=True)
            Image.fromarray((img[j, ...] * 255).astype(
                np.uint8)).save(opath + '/%s.png' % (data[4][j]))
            write_pts2sphere(opath + '/%s_gt.ply' % (data[4][j]), ygt[j, :, :])

        net.eval()
        for w in ws:
            with torch.no_grad():
                data2cuda(data)
                out = net.interp(data[0][0:bs, ...],
                                 data[0][bs:2 * bs, ...], w, grid)
            yout = out['y'].data.cpu().numpy()
            # (The old code re-transposed ``img`` here on every weight; the
            # result was never used, so that stray transpose is gone.)
            for j in range(bs):
                opath = './log/interp/' + data[3][j]
                os.makedirs(opath, exist_ok=True)
                write_ply(opath + '/%s_%s_%f.ply' %
                          (data[4][j], data[4][j + bs], w),
                          points=pd.DataFrame(yout[j, :, :]),
                          faces=pd.DataFrame(face))

    klst = list(done_cat.keys())
    for ni, ci in enumerate(klst):
        datai = done_cat[ci]
        imgi = datai[0].data.cpu().numpy().transpose((0, 2, 3, 1))
        ygti = datai[1].data.cpu().numpy()
        for cj in klst[ni + 1:]:
            dataj = done_cat[cj]
            imgj = dataj[0].data.cpu().numpy().transpose((0, 2, 3, 1))
            ygtj = dataj[1].data.cpu().numpy()
            opath = './log/interp/' + ci + '_' + cj
            os.makedirs(opath, exist_ok=True)
            for j in range(bs):
                Image.fromarray((imgi[j, ...] * 255).astype(
                    np.uint8)).save(opath + '/%s.png' % (datai[4][j]))
                write_pts2sphere(opath + '/%s_gt.ply' % (datai[4][j]),
                                 ygti[j, :, :])
                Image.fromarray((imgj[j, ...] * 255).astype(
                    np.uint8)).save(opath + '/%s.png' % (dataj[4][j]))
                write_pts2sphere(opath + '/%s_gt.ply' % (dataj[4][j]),
                                 ygtj[j, :, :])
            net.eval()
            for w in ws:
                with torch.no_grad():
                    # Move the two batches actually fed to interp; the old code
                    # moved the stale loop variable ``data`` from phase 1.
                    data2cuda(datai)
                    data2cuda(dataj)
                    out = net.interp(datai[0][0:bs, ...],
                                     dataj[0][0:bs, ...], w, grid)
                yout = out['y'].data.cpu().numpy()
                for j in range(bs):
                    write_ply(opath + '/%s_%s_%f.ply' %
                              (datai[4][j], dataj[4][j], w),
                              points=pd.DataFrame(yout[j, :, :]),
                              faces=pd.DataFrame(face))
Example #8
0
def run(**kwargs):
    """Overfit the network on one training batch and dump its predictions at every step.

    ``kwargs['config']`` names a module under ``config.`` whose namespace is
    the option dict; other non-``None`` keyword arguments override its entries.
    Failures while loading config, network or dataset are printed with a
    traceback and terminate the process.

    One batch is taken from the training loader and moved with ``data2cuda``.
    For ``opt['nepoch']`` iterations, the network takes one optimizer step on
    ``config.loss`` for that batch and is then evaluated with
    ``config.accuracy``, which is printed.  Each sample, keyed by its id
    (``data[-2]``), gets a folder ``./log/debug_touchpt/<id>`` containing:
        * on the first iteration, the inputs and the ground-truth box meshes;
        * on every iteration, the predicted ``sb``/``tb`` boxes as
          ``a_<iter>.ply`` and ``b_<iter>.ply``.

    Raises:
        ValueError: if the training loader yields no batch (the old code hit
            ``NameError`` on ``data`` instead).
    """
    global iternum
    #get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        opt = config.__dict__
        for k, v in kwargs.items():
            if v is not None:
                opt[k] = v
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()
    iternum = opt['nepoch']
    #get network
    try:
        m = importlib.import_module('net.' + opt['net'])
        net = m.Net(**opt)
        if torch.cuda.is_available():
            net = net.cuda()
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()
    #get dataset
    try:
        m = importlib.import_module('util.dataset.' + opt['dataset'])
        train_data = m.Data(opt, 'train')
        val_data = m.Data(opt, opt['user_key'])
        train_load = DataLoader(train_data,
                                batch_size=opt['batch_size'],
                                shuffle=True,
                                num_workers=opt['workers'])
        val_load = DataLoader(val_data,
                              batch_size=opt['batch_size'],
                              shuffle=False,
                              num_workers=opt['workers'])
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()

    # Grab exactly one training batch to overfit on.
    data = next(iter(train_load), None)
    if data is None:
        raise ValueError('training loader is empty; no batch to overfit on')
    data2cuda(data)
    img = data[0].data.cpu().numpy()
    msks = data[1].data.cpu().numpy()
    mskt = data[2].data.cpu().numpy()
    vgt = data[4].data.cpu().numpy()
    # Look the optimizer class up by name instead of eval().  Binding the
    # result to ``optimizer`` keeps the ``optim`` module name unshadowed.
    optimizer = getattr(optim, opt['optim'])(config.parameters(net),
                                             lr=opt['lr'],
                                             weight_decay=opt['weight_decay'])
    fidx = box_face
    T = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                  ('i2', np.int32)])
    face = np.zeros(shape=[fidx.shape[0]], dtype=T)
    face['n'] = 3
    face['i0'] = fidx[:, 0]
    face['i1'] = fidx[:, 1]
    face['i2'] = fidx[:, 2]

    debug_path = './log/debug_touchpt'
    os.makedirs(debug_path, exist_ok=True)

    for iteri in range(opt['nepoch']):
        net.train()
        out = net(data)
        loss = config.loss(data, out)
        optimizer.zero_grad()
        loss['overall'].backward()
        optimizer.step()
        net.eval()
        with torch.no_grad():
            out = net(data)
        acc = config.accuracy(data, out)
        print('iteri:', iteri)
        for k, v in acc.items():
            print(k, ':', v)
        ids = data[-2]
        tb = out['tb'].data.cpu().numpy()
        sb = out['sb'].data.cpu().numpy()
        for tagi, tag in enumerate(ids):
            cpath = os.path.join(debug_path, tag)
            os.makedirs(cpath, exist_ok=True)
            if iteri == 0:
                # Inputs and ground truth do not change; save them only once.
                im = Image.fromarray((img[tagi, ...] * 255).astype(np.uint8),
                                     mode='RGB')
                im.save(os.path.join(cpath, 'input.png'))
                mks = Image.fromarray((msks[tagi, ...] * 255).astype(np.uint8),
                                      mode='L')
                mks.save(os.path.join(cpath, 'msks.png'))
                mkt = Image.fromarray((mskt[tagi, ...] * 255).astype(np.uint8),
                                      mode='L')
                mkt.save(os.path.join(cpath, 'mskt.png'))
                ptsa, ptsb = parse(vgt[tagi, ...])
                write_ply(os.path.join(cpath, 'gta.ply'),
                          points=pd.DataFrame(ptsa),
                          faces=pd.DataFrame(face))
                write_ply(os.path.join(cpath, 'gtb.ply'),
                          points=pd.DataFrame(ptsb),
                          faces=pd.DataFrame(face))
            write_ply(os.path.join(cpath, 'a_%d.ply' % iteri),
                      points=pd.DataFrame(sb[tagi, ...]),
                      faces=pd.DataFrame(face))
            write_ply(os.path.join(cpath, 'b_%d.ply' % iteri),
                      points=pd.DataFrame(tb[tagi, ...]),
                      faces=pd.DataFrame(face))
Example #9
0
def writelog(**kwargs):
    """Log one training/validation step, keep the best checkpoints and dump visualisations.

    Keyword arguments read below:
        opt: options dict.  Reads 'nepoch', 'batch_size', 'log', 'net', 'config',
            'mode', 'ply', 'grid_num' and 'as_text'; creates 'log_tmp' on the
            first call and reads it afterwards.
        iepoch: current epoch index (only printed/logged).
        idata: current batch index within the epoch.
        ndata: divided by opt['batch_size'] to get the batch count
            (presumably the dataset size -- confirm with the caller).
        out: network output dict; 'grid_x' and 'y' are read when dumping PLYs.
        net: the network, which is printed, parameter-counted and saved via
            ``state_dict``.
        data: the input batch; positions 0-5 and -1 are read when dumping PLYs.
        meter: serialised with ``NpEncoder``; ``meter['cd'].overall_meter.avg``
            drives checkpoint selection (lower is better).
        istraining: whether this step is a training step.

    Side effects:
        - First call (no 'log_tmp' in opt): creates a timestamped run directory
          and writes options.json and net.txt (architecture + parameter count).
        - Every call: prints a header line and appends it plus the JSON meter to
          log.txt.
        - At the last validation batch: if the 'cd' average is lower than
          ``best[-1]``, replaces that slot's checkpoint, re-sorts the globals
          ``best``/``bestn`` and rewrites best.json.
        - When opt['ply'] is set and not training: writes PLY meshes and PNG
          images for each sample of the batch under <log_tmp>/ply.
    """
    global best
    global bestn
    opt = kwargs['opt']
    iepoch = kwargs['iepoch']
    nepoch = opt['nepoch']
    ib = kwargs['idata']
    # Number of whole batches (integer division drops a partial batch).
    nb = kwargs['ndata'] // opt['batch_size']
    out = kwargs['out']
    net = kwargs['net']
    data = kwargs['data']
    meter = kwargs['meter']
    print('[' + str(datetime.now()) + '][%d/%d,%d/%d]' %
          (iepoch, nepoch, ib, nb) + 'training:' + str(kwargs['istraining']))
    # First call only: create the run directory and record options and the
    # network summary there; later calls reuse opt['log_tmp'].
    if not 'log_tmp' in opt.keys():
        opt['log_tmp'] = opt['log'] + os.sep + opt['net'] + '_' + opt[
            'config'] + '_' + opt['mode'] + '_' + str(datetime.now()).replace(
                ' ', '-').replace(':', '-')
        os.mkdir(opt['log_tmp'])
        with open(opt['log_tmp'] + os.sep + 'options.json', 'w') as f:
            json.dump(opt, f, cls=NpEncoder)
        nparam = 0
        with open(opt['log_tmp'] + os.sep + 'net.txt', 'w') as logtxt:
            print(str(kwargs['net']), file=logtxt)
            for p in parameters(kwargs['net']):
                nparam += torch.numel(p)
            print('nparam:%d' % nparam, file=logtxt)

    with open(opt['log_tmp'] + os.sep + 'log.txt', 'a') as logtxt:
        print('[' + str(datetime.now()) + '][%d/%d,%d/%d]' %
              (iepoch, nepoch, ib, nb) + 'training:' +
              str(kwargs['istraining']),
              file=logtxt)
        print(json.dumps(meter, cls=NpEncoder), file=logtxt)

    # Checkpoint selection at the last validation batch.  Assumes ``best`` is
    # a NumPy array kept sorted ascending, so best[-1] is the worst score kept.
    if not kwargs['istraining'] and ib >= nb - 1:
        if meter['cd'].overall_meter.avg < best[-1]:
            fn = bestn[-1]
            # A falsy name means this slot has no checkpoint file to delete yet.
            if fn:
                os.remove(opt['log_tmp'] + os.sep + fn)
            fn = 'net_' + str(datetime.now()).replace(' ', '-').replace(
                ':', '-') + '.pth'
            best[-1] = meter['cd'].overall_meter.avg
            bestn[-1] = fn
            torch.save(net.state_dict(), opt['log_tmp'] + os.sep + fn)
            # Re-sort scores ascending and reorder the file names to match;
            # this rebinds the globals (fancy indexing requires a NumPy array).
            idx = np.argsort(best)
            best = best[idx]
            bestn = [bestn[x] for x in idx.tolist()]
            bestdict = dict(zip(bestn, best.tolist()))
            print(best)
            print(bestn)
            print(bestdict)
            with open(opt['log_tmp'] + os.sep + 'best.json', 'w') as f:
                json.dump(bestdict, f)

    # Validation-time visual dump: meshes and images for every sample.
    if opt['ply'] and not kwargs['istraining']:
        ply_path = opt['log_tmp'] + os.sep + 'ply'
        if not os.path.exists(ply_path):
            os.mkdir(ply_path)
        x = out['grid_x']
        x = x.data.cpu().numpy()
        y = out['y']
        yout = y.data.cpu().numpy()
        ysrc = data[3]
        ysrc = ysrc.data.cpu().numpy()
        ytgt = data[4]
        ytgt = ytgt.data.cpu().numpy()
        yall = data[5]
        yall = yall.data.cpu().numpy()
        cat = data[-1]
        im = data[0]
        im = im.data.cpu().numpy()
        src = data[1]
        src = src.data.cpu().numpy()
        tgt = data[2]
        tgt = tgt.data.cpu().numpy()
        for i in range(y.shape[0]):
            # NOTE(review): repeat_face is defined elsewhere; presumably it
            # builds triangles for opt['grid_num'] boxes of 8 vertices -- confirm.
            fidx = repeat_face(x[i, ...], opt['grid_num'], 8)
            T = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                          ('i2', np.int32)])
            face = np.zeros(shape=[fidx.shape[0]], dtype=T)
            for fi in range(fidx.shape[0]):
                face[fi] = (3, fidx[fi, 0], fidx[fi, 1], fidx[fi, 2])
            # rotyup is defined elsewhere; presumably it re-orients points for
            # viewing (y-up) -- confirm.
            write_ply(ply_path + os.sep + '_%04d_%03d_%s_allgt.ply' %
                      (ib, i, cat[i]),
                      points=pd.DataFrame(rotyup(yall[i, ...])),
                      faces=pd.DataFrame(face),
                      as_text=opt['as_text'])
            # The src/tgt meshes use only face[0:12] -- presumably the 12
            # triangles of a single box (TODO confirm against repeat_face).
            write_ply(ply_path + os.sep + '_%04d_%03d_%s_src.ply' %
                      (ib, i, cat[i]),
                      points=pd.DataFrame(rotyup(ysrc[i, ...])),
                      faces=pd.DataFrame(face[0:12]),
                      as_text=opt['as_text'])
            write_ply(ply_path + os.sep + '_%04d_%03d_%s_tgt_gt.ply' %
                      (ib, i, cat[i]),
                      points=pd.DataFrame(rotyup(ytgt[i, ...])),
                      faces=pd.DataFrame(face[0:12]),
                      as_text=opt['as_text'])
            write_ply(ply_path + os.sep + '_%04d_%03d_%s_tgt_out.ply' %
                      (ib, i, cat[i]),
                      points=pd.DataFrame(rotyup(yout[i, ...])),
                      faces=pd.DataFrame(face[0:12]),
                      as_text=opt['as_text'])
            # Source points followed by predicted target points in one mesh.
            write_ply(ply_path + os.sep + '_%04d_%03d_%s_allout.ply' %
                      (ib, i, cat[i]),
                      points=pd.DataFrame(
                          rotyup(
                              np.concatenate([ysrc[i, ...], yout[i, ...]],
                                             axis=0))),
                      faces=pd.DataFrame(face),
                      as_text=opt['as_text'])
            # Input image is channel-first; transpose to HWC for PIL.
            img = im[i, ...]
            img = img.transpose((1, 2, 0))
            img = Image.fromarray(np.uint8(255.0 * img))
            img.save(ply_path + os.sep + '_%04d_%03d_%s_input.png' %
                     (ib, i, cat[i]))
            img = src[i, ...]
            img = Image.fromarray(np.uint8(255.0 * img))
            img.save(ply_path + os.sep + '_%04d_%03d_%s_src_msk.png' %
                     (ib, i, cat[i]))
            img = tgt[i, ...]
            img = Image.fromarray(np.uint8(255.0 * img))
            img.save(ply_path + os.sep + '_%04d_%03d_%s_tgt_msk.png' %
                     (ib, i, cat[i]))
Example #10
0
def run(**kwargs):
    """Per category: predict point clouds, then optimise an input image to reproduce the ground truth.

    ``kwargs['config']`` names a module under ``config.`` whose namespace is
    the option dict; other non-``None`` keyword arguments override its entries.
    Failures while loading config, network or dataset are printed with a
    traceback and terminate the process.  If ``opt['model']`` is non-empty,
    its weights are loaded with ``partial_restore``.  This requires CUDA,
    because tensors are moved with ``.cuda()`` unconditionally.

    For the first full validation batch of each category (``data[3][0]``):
        * the prediction ``net(data[0], grid)`` is saved to
          ``./log/debug/<cat>/<id>.ply``, together with the image and the
          ground truth;
        * a free tensor ``X`` shaped like the image batch is optimised with the
          configured optimizer, so that ``net(sigmoid(X), grid)`` matches
          ``data[1]`` under ``distChamfer`` (sum of both directional means);
        * the optimisation runs for at most ``opt['nepoch']`` steps, or until
          the loss drops below 0.001;
        * the optimised image and its prediction are saved as ``<id>_opt.png``
          and ``<id>_opt.ply``.
    """
    #get configuration
    try:
        config = importlib.import_module('config.' + kwargs['config'])
        opt = config.__dict__
        for k, v in kwargs.items():
            if v is not None:
                opt[k] = v
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()
    #get network
    try:
        m = importlib.import_module('net.' + opt['net'])
        net = m.Net(**opt)
        if torch.cuda.is_available():
            net = net.cuda()
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()
    #get dataset
    try:
        m = importlib.import_module('util.dataset.' + opt['dataset'])
        train_data = m.Data(opt, True)
        val_data = m.Data(opt, False)
        train_load = DataLoader(train_data,
                                batch_size=opt['batch_size'],
                                shuffle=True,
                                num_workers=opt['workers'])
        val_load = DataLoader(val_data,
                              batch_size=opt['batch_size'],
                              shuffle=False,
                              num_workers=opt['workers'])
    except Exception as e:
        print(e)
        traceback.print_exc()
        exit()

    #load pre-trained
    if opt['model'] != '':
        partial_restore(net, opt['model'])
        print("Previous weights loaded")

    bs = opt['batch_size']
    grid = randsphere2(opt['pts_num'])
    grid = torch.from_numpy(
        grid.reshape(1, grid.shape[0],
                     grid.shape[1]).astype(np.float32)).cuda()
    grid = grid.repeat(bs, 1, 1)
    hull = triangulateSphere(grid.cpu().data.numpy())
    fidx = hull[0].simplices
    T = np.dtype([("n", np.uint8), ("i0", np.int32), ('i1', np.int32),
                  ('i2', np.int32)])
    face = np.zeros(shape=[fidx.shape[0]], dtype=T)
    face['n'] = 3
    face['i0'] = fidx[:, 0]
    face['i1'] = fidx[:, 1]
    face['i2'] = fidx[:, 2]

    done_cat = []
    for i, data in enumerate(val_load, 0):
        # ``grid`` is built for exactly ``bs`` samples; skip a short final batch.
        if data[0].shape[0] < bs:
            continue
        if data[3][0] in done_cat:
            continue
        done_cat.append(data[3][0])
        net.eval()
        with torch.no_grad():
            data2cuda(data)
            out = net(data[0], grid)
        img = data[0].data.cpu().numpy().transpose((0, 2, 3, 1))
        ygt = data[1].data.cpu().numpy()
        yout = out['y'].data.cpu().numpy()
        for j in range(bs):
            opath = './log/debug/' + data[3][j]
            os.makedirs(opath, exist_ok=True)
            Image.fromarray((img[j, :, :] * 255).astype(
                np.uint8)).save(opath + '/%s.png' % (data[4][j]))
            write_ply(opath + '/%s.ply' % (data[4][j]),
                      points=pd.DataFrame(yout[j, :, :]),
                      faces=pd.DataFrame(face))
            write_pts2sphere(opath + '/%s_gt.ply' % (data[4][j]), ygt[j, :, :])

        # Optimise the image in logit space: sigmoid(X) keeps pixels in (0, 1).
        # X takes the real batch's shape instead of a hard-coded 3 x 224 x 224.
        X = torch.empty(tuple(data[0].shape),
                        dtype=torch.float32).normal_(0, 0.5)
        X.requires_grad_(True)
        optimizer = getattr(optim, opt['optim'])(
            [X], lr=opt['lr'], weight_decay=opt['weight_decay'])

        for step in range(opt['nepoch']):
            optimizer.zero_grad()
            # Pass ``grid`` like every other call to ``net``; it was missing here.
            out = net(torch.sigmoid(X.cuda()), grid)
            dist1, dist2, _, _ = distChamfer(out['y'], data[1])
            loss = torch.mean(dist1) + torch.mean(dist2)
            loss.backward()
            loss_val = loss.item()
            print(step, ':', loss_val)
            if loss_val < 0.001:
                break
            optimizer.step()

        ximg = torch.sigmoid(X).detach().numpy().transpose((0, 2, 3, 1))
        for j in range(bs):
            opath = './log/debug/' + data[3][j]
            os.makedirs(opath, exist_ok=True)
            Image.fromarray((ximg[j, :, :] * 255).astype(
                np.uint8)).save(opath + '/%s_opt.png' % (data[4][j]))

        with torch.no_grad():
            data2cuda(data)
            out = net(torch.sigmoid(X.cuda()), grid)
        yout = out['y'].data.cpu().numpy()
        for j in range(bs):
            opath = './log/debug/' + data[3][j]
            write_ply(opath + '/%s_opt.ply' % (data[4][j]),
                      points=pd.DataFrame(yout[j, :, :]),
                      faces=pd.DataFrame(face))