Ejemplo n.º 1
0
def getOpt(optDe):
    """Build an ``Option`` from the ``mode_switch`` entry named by ``optDe``.

    ``optDe.get('model')`` selects the entry (default ``'dehaze'``). The entry
    must unpack into exactly six values: the weight path followed by the
    ``modelDef``, ``ram``, ``padding``, ``align`` and ``prepare`` attributes.
    The model is then initialised and stored on ``opt.modelCached``.

    Raises:
        KeyError: if the selected model name is not a key of ``mode_switch``.
    """
    modelName = optDe.get('model', 'dehaze')
    opt = Option()
    entry = mode_switch[modelName]
    (weightPath, opt.modelDef, opt.ram,
     opt.padding, opt.align, opt.prepare) = entry
    opt.model = weightPath
    opt.modelCached = initModel(opt, weightPath, modelName)
    return opt
Ejemplo n.º 2
0
def newOpt(func, ramCoef, align=32, padding=45, scale=1, **_):
  """Wrap an already-initialised model callable in a fresh ``Option``.

  Args:
    func: the callable stored as ``modelCached``.
    ramCoef: value stored as ``ramCoef``.
    align, padding, scale: tiling parameters copied onto the option.
    **_: any other keyword arguments are accepted and ignored.

  Returns:
    The new ``Option``. Its ``squeeze`` and ``unsqueeze`` hooks are both
    ``identity``.
  """
  opt = Option()
  opt.modelCached, opt.ramCoef = func, ramCoef
  opt.align, opt.padding, opt.scale = align, padding, scale
  # No dimension juggling for wrapped models: both hooks pass data through.
  opt.squeeze = opt.unsqueeze = identity
  return opt
Ejemplo n.º 3
0
def newOpt(f, ramType):
  """Wrap model callable ``f`` in an ``Option`` with fixed tiling settings.

  The wrapped callable returns ``f``'s result inside a 1-tuple.
  ``ramCoef`` is read from the module-level ``ramCoef`` table, which holds
  two entries per run type. ``ramType`` (0 or 1 at the visible call sites)
  picks one of them.
  """
  opt = Option()

  def wrapped(x):
    # NOTE(review): presumably callers expect a tuple of outputs.
    return (f(x),)

  opt.modelCached = wrapped
  slot = config.getRunType() * 2 + ramType
  opt.ramCoef = ramCoef[slot]
  opt.align, opt.padding = 32, 45
  opt.squeeze = opt.unsqueeze = identity
  return opt
Ejemplo n.º 4
0
def getOpt(optDN):
  """Build the ``Option`` for the model named by ``optDN['model']``.

  Returns ``None`` when the name is not a key of ``mode_switch``.
  Otherwise the entry supplies the following, by index:
    0: the weight path, passed to ``Option``;
    1: ``modelDef``;
    2: per-run-type RAM coefficients;
    3: an optional dimension for ``squeeze``/``unsqueeze``.
  ``cropsize`` comes from ``config.getConfig()``: index 1 for names starting
  with ``'lite'``, otherwise index 2. ``padding`` is always 15.
  """
  model = optDN['model']
  if model not in mode_switch:
    return None
  entry = mode_switch[model]
  opt = Option(entry[0])
  opt.modelDef = entry[1]

  opt.ramCoef = entry[2][config.getRunType()]
  cropIndex = 1 if model.startswith('lite') else 2
  opt.cropsize = config.getConfig()[cropIndex]
  opt.modelCached = initModel(opt, opt.model, 'DN' + model)
  dim = entry[3]
  if dim:
    # A truthy dimension installs squeeze/unsqueeze hooks on that axis.
    opt.fixChannel = 0
    opt.squeeze = lambda t: t.squeeze(dim)
    opt.unsqueeze = lambda t: t.unsqueeze(dim)
  opt.padding = 15
  return opt
Ejemplo n.º 5
0
def getOpt(option):
  """Build the ``Option`` for a slow-motion interpolation request.

  Loads the module-level ``modelPath`` checkpoint. Initialises
  ``flowComp`` from its ``state_dictFC`` entry and ``ArbTimeFlowIntrp``
  from its ``state_dictAT`` entry. Wraps them in sub-options via
  ``newOpt`` (RAM slot 0 and 1 respectively).

  Args:
    option: mapping that must contain ``'sf'``, the slow-motion factor.

  Returns:
    The populated ``Option``. It has ``firstTime``/``notLast`` set to 1,
    ``batchSize`` set to 0 and ``flowBackWarp`` set to ``None``.

  Raises:
    KeyError: if ``option`` has no ``'sf'`` entry.
    RuntimeError: if ``option['sf']`` is below 2. This is checked before
      any weights are loaded.
  """
  sf = option['sf']
  # Reject a bad factor before touching the checkpoint. Reading the state
  # dict and initialising both networks is the expensive part, and it would
  # be thrown away on a request we are going to refuse.
  if sf < 2:
    raise RuntimeError('Error: --sf/slomo factor has to be at least 2')
  # NOTE(review): modelPath is a module-level name not visible in this block.
  opt = Option(modelPath)
  # Initialize model
  stateDict = getStateDict(modelPath)
  flowComp = initModel(opt, stateDict['state_dictFC'], 'flowComp', getFlowComp)
  ArbTimeFlowIntrp = initModel(opt, stateDict['state_dictAT'], 'ArbTimeFlowIntrp', getFlowIntrp)
  opt.sf = sf
  opt.firstTime = 1
  opt.notLast = 1
  opt.batchSize = 0
  opt.flowBackWarp = None
  opt.optFlow = newOpt(flowComp, 0)
  opt.optArb = newOpt(ArbTimeFlowIntrp, 1)
  return opt
Ejemplo n.º 6
0
def getOptS(modelPath, modules, ramCoef):
  """Load several sub-models from one checkpoint onto a single ``Option``.

  Args:
    modelPath: checkpoint path, passed to ``Option`` and ``getStateDict``.
    modules: ordered mapping of attribute name -> spec dict. Each spec needs
      ``'weight'`` (a key into the state dict). It may also carry ``'f'``
      (constructor, default 0), ``'ramCoef'`` (per-run-type list),
      ``'align'``/``'padding'``/``'scale'`` and ``'outShape'``.
    ramCoef: fallback flat table. It is indexed by
      ``getRunType() * len(modules) + position`` when a spec has no
      ``'ramCoef'``.

  Returns:
    The ``Option``, with one attribute per module name. A spec with
    ``'outShape'`` gets a ``newOpt`` wrapper. Otherwise the bare model is
    stored with its ``ramCoef`` attribute set.
  """
  opt = Option(modelPath)
  weights = getStateDict(modelPath)
  opt.modules = modules
  opt.ramOffset = config.getRunType() * len(modules)
  for position, name in enumerate(modules):
    spec = modules[name]
    weightKey = spec['weight']
    constructor = spec.get('f', 0)
    if 'ramCoef' in spec:
      coef = spec['ramCoef'][config.getRunType()]
    else:
      coef = ramCoef[opt.ramOffset + position]
    tiling = {k: spec[k] for k in ('align', 'padding', 'scale') if k in spec}
    model = initModel(opt, weights[weightKey], name, constructor)
    if 'outShape' not in spec:
      model.ramCoef = coef
      vars(opt)[name] = model
    else:
      vars(opt)[name] = newOpt(model, coef, **tiling)
  return opt
Ejemplo n.º 7
0
def getOpt(optSR):
    """Build the super-resolution ``Option`` described by ``optSR``.

    The ``mode_switch`` key is ``optSR['model'] + str(optSR['scale'])``.
    Returns ``None`` if that key is unknown. ``ensemble`` is taken from
    ``optSR`` only when it is present and within 0..7. Otherwise
    ``config.ensembleSR`` is used.
    """
    opt = Option()
    opt.mode = optSR['model']
    opt.scale = optSR['scale']
    key = opt.mode + str(opt.scale)
    if key not in mode_switch:
        return None
    # Models whose name does not start with 'gan' get hooks on axis 1.
    if not opt.mode.startswith('gan'):
        opt.squeeze = lambda t: t.squeeze(1)
        opt.unsqueeze = lambda t: t.unsqueeze(1)
    opt.padding = 5 if opt.scale != 3 else 9
    entry = mode_switch[key]
    opt.model, opt.modelDef = entry[0], entry[1]
    useRequested = 'ensemble' in optSR and 0 <= optSR['ensemble'] <= 7
    opt.ensemble = optSR['ensemble'] if useRequested else config.ensembleSR

    opt.ramCoef = entry[2][config.getRunType()]
    opt.cropsize = config.getConfig()[0]
    opt.modelCached = initModel(opt, opt.model, 'SR' + key)
    return opt
Ejemplo n.º 8
0
import os
from time import perf_counter

import numpy as np
import torch

from imageProcess import toTorch, readFile, initModel, toFloat, toOutput, ensemble, writeFile, Option
from config import config
from moire_obj import Net
# Stem of the weight file name; formatted into the checkpoint path below.
modelName = 'moire_obj'
# When True, weights load from 'test/<modelName>.pth' instead of
# 'model/demoire/<modelName>.pth'.
test = False
# Directory whose entries are each read and run through the model.
inputFolder = '../test-pics'
# Reference images, joined as refFile + '/' + <input file name>. Any falsy
# value (0 here) disables this, and each input is used as its own reference.
# NOTE(review): the old commented-out value 'test/1566005911.7879605_ci.png'
# names a single file, but the loop treats refFile as a directory. Confirm
# which is intended.
refFile = 0


def context():
    """No-op context callback passed to ``readFile(context=...)``."""
    return None


# Weights come from 'test/' when `test` is set, else 'model/demoire/'.
opt = Option('{}/{}.pth'.format('test' if test else 'model/demoire', modelName))
opt.padding, opt.ramCoef, opt.align = 31, 1 / 8000.0, 128
opt.modelCached = initModel(opt, weights=opt.model, f=lambda _: Net())


def toTorch(x):
    """Convert image ``x`` to a tensor with its last axis moved to the front.

    The result uses ``config``'s dtype and device and is divided by 256.
    NOTE(review): presumably ``x`` is an HWC image, giving CHW out.
    """
    tensor = torch.from_numpy(np.array(x)).permute(2, 0, 1)
    return tensor.to(dtype=config.dtype(), device=config.device()) / 256


# Accumulated perf_counter deltas spent in ensemble(opt), in seconds.
time = 0.0
for pic in os.listdir(inputFolder):
    original = toTorch(readFile(context=context)(inputFolder + '/' + pic))
    ref = toTorch(readFile(context=context)(refFile + '/' +
                                            pic)) if refFile else original
    start = perf_counter()
    y = ensemble(opt)(original)
    time += perf_counter() - start
    print(pic, float(y.mean(dtype=torch.float)),