コード例 #1
0
def main():
    """Run super-resolution inference on the images stored in ``to_sr.npz``.

    Reads ``load_step`` and ``net_name`` from ``task.train.json``. Builds
    ``nets.SRNet3`` when ``net_name == 'SRNet3'`` and ``nets.SRNet4`` otherwise.
    Then feeds the archive's ``data0``..``data3`` arrays through
    ``net.predict`` in batches of ``batch_size``. The ``'inference'`` and
    ``'interp'`` outputs are concatenated along axis 0 and saved to
    ``sr.npy`` and ``it.npy``.

    Raises:
        ValueError: if ``to_sr.npz`` contains no images.
    """
    enter_debug()
    with open('task.train.json', 'r') as fin:
        cfgs = json.load(fin)
    load_step = cfgs['load_step']
    is_net3 = cfgs['net_name'] == 'SRNet3'
    net_cls = nets.SRNet3 if is_net3 else nets.SRNet4

    # NOTE(review): the net is constructed with this batch size; assuming
    # predict() needs exactly that many samples per feed, a short final
    # batch is padded below instead of being dropped -- confirm.
    batch_size = 2
    net = net_cls(filenames=['data.sino8x.json', 'net.srnet1.json'],
                  batch_size=batch_size, low_shape=[16, 80],
                  high_shape=[128, 640], nb_down_sample=2,
                  load_step=load_step)
    net.init()

    # Materialize every member once and close the archive. Indexing an
    # NpzFile re-reads (and decompresses) the whole member on each access.
    with np.load('to_sr.npz') as archive:
        data0 = archive['data0']
        data1 = archive['data1']
        data2 = archive['data2']
        data3 = archive['data3']
    nb_images = data0.shape[0]
    if nb_images == 0:
        raise ValueError('to_sr.npz contains no images')

    def take(arr, start, stop):
        """Return arr[start:stop], padded to batch_size by repeating its last row."""
        chunk = arr[start:stop]
        missing = batch_size - chunk.shape[0]
        if missing > 0:
            filler = np.repeat(chunk[-1:], missing, axis=0)
            chunk = np.concatenate([chunk, filler], axis=0)
        return chunk

    srs = []
    its = []
    for start in tqdm(range(0, nb_images, batch_size)):
        stop = min(start + batch_size, nb_images)
        nb_valid = stop - start
        # NOTE(review): 'data' is fed from data2 (same as 'data2'), not
        # data3 -- kept as in the original; confirm this is intended.
        feed = {
            'data3': take(data3, start, stop),
            'data': take(data2, start, stop),
            'data2': take(data2, start, stop),
            'data1': take(data1, start, stop),
            'data0': take(data0, start, stop),
            'label': take(data0, start, stop),
        }
        pred = net.predict(feed)
        # Drop outputs that belong to padding rows (assumes the batch is
        # axis 0 of each output, as the concatenation below also assumes).
        srs.append(pred['inference'][:nb_valid])
        its.append(pred['interp'][:nb_valid])
    pred_sr = np.concatenate(srs, axis=0)
    pred_it = np.concatenate(its, axis=0)
    print(pred_sr.shape)
    np.save('sr.npy', pred_sr)
    np.save('it.npy', pred_it)
コード例 #2
0
ファイル: sr.py プロジェクト: Hong-Xiang/XLearning
import matplotlib

# Select the non-interactive 'agg' backend before matplotlib.pyplot is
# imported below, so plotting works without a display.
matplotlib.use('agg')
import matplotlib.pyplot as plt

from xlearn.dataset.mnist import MNISTImage, MNIST2
from xlearn.dataset.sinogram import Sinograms
from xlearn.dataset.flickr import Flickr25k

from xlearn.nets.super_resolution import SRNetInterp, SRSimple, SRF3D, SRClassic

from xlearn.utils.general import with_config, empty_list, enter_debug
from xlearn.utils.image import subplot_images

# Module-level configuration dictionary; only 'is_cl' is seeded here.
c = {'is_cl': False}

# Runs at import time. NOTE(review): enter_debug comes from
# xlearn.utils.general; its exact effect is not visible here -- confirm.
enter_debug()


def init():
    """Build the network and dataset selected by the module-level config ``c``.

    ``define_net()`` is called only when ``c['net']`` is set to something other
    than None. ``define_dataset()`` is called only when ``c['dataset']`` is.
    A missing key is treated the same as None rather than raising
    ``KeyError``; the module seeds only ``'is_cl'`` in ``c``.

    Returns:
        tuple: ``(net, dataset)``; either element is None when not configured.
    """
    net = define_net() if c.get('net') is not None else None
    dataset = define_dataset() if c.get('dataset') is not None else None
    return net, dataset


def define_dataset():
コード例 #3
0
ファイル: main.py プロジェクト: Hong-Xiang/XLearning
def xln(cfg, debug):
    """Optionally switch into debug mode.

    ``cfg`` is accepted but not referenced by this body. When ``debug`` is
    truthy, print a banner and call ``enter_debug()``. Always returns None.
    """
    if not debug:
        return
    print("ENTER DEBUG MODE")
    enter_debug()
コード例 #4
0
def xln(config, cfg, debug):
    """Attach ``cfg`` to ``config``, load it, then optionally enter debug mode.

    Args:
        config: object with a writable ``config`` attribute and a ``load()``
            method.
        cfg: value stored on ``config.config`` before ``load()`` is called.
        debug: when truthy, print a banner and call ``enter_debug()``.

    Returns:
        None.
    """
    config.config = cfg
    config.load()
    if not debug:
        return
    print("ENTER DEBUG MODE")
    enter_debug()