# Two separate prediction buffers of shape
# (opt.batchSize, 3, opt.imageSize, opt.imageSize); presumably NCHW RGB
# images for cascade outputs 2 and 3 -- confirm against their use sites.
# The legacy torch.FloatTensor(*sizes) constructor leaves the contents
# uninitialized, so they must be written before being read.
imP2PredBatch, imP3PredBatch = (
    Variable(torch.FloatTensor(opt.batchSize, 3,
                               opt.imageSize, opt.imageSize))
    for _ in range(2))

# Refine Network
# One refinement encoder and five decoders, each held in a single-element
# list (presumably so additional cascade stages can be appended later --
# confirm). Keep the construction order unchanged: building each module
# presumably draws initial weights from the global RNG.
# NOTE(review): the `mode` index appears to select which quantity the
# decoder predicts, matching the list it is stored in -- confirm in models.
encoderRefs = [models.refineEncoder()]
albedoRefs = [models.refineDecoder(mode=0)]
normalRefs = [models.refineDecoder(mode=1)]
roughRefs = [models.refineDecoder(mode=2)]
depthRefs = [models.refineDecoder(mode=3)]
envRefs = [models.refineEnvDecoder()]

# Global illumination: two independent globalIllumination networks
# ("1to2" and "2to3"), built left to right in that order.
globIllu1to2, globIllu2to3 = (models.globalIllumination(),
                              models.globalIllumination())

# Restore pretrained weights for the refinement encoder and albedo decoder
# from '<opt.modelRoot>/<name><opt.cascadeLevel>_<opt.epochId>.pth'.
# map_location remaps every storage to CPU memory. Without it, torch.load
# restores tensors onto the device they were saved from: that raises on a
# CPU-only machine and otherwise places a temporary copy on the saving GPU.
# load_state_dict then copies the tensors into the module's existing
# parameters, so the module's own device placement is unaffected.
# The lambda form (rather than the string 'cpu') works on old PyTorch too.
encoderRefs[0].load_state_dict(
    torch.load('{0}/encoderRefs{1}_{2}.pth'.format(opt.modelRoot,
                                                   opt.cascadeLevel,
                                                   opt.epochId),
               map_location=lambda storage, loc: storage))
albedoRefs[0].load_state_dict(
    torch.load('{0}/albedoRefs{1}_{2}.pth'.format(opt.modelRoot,
                                                  opt.cascadeLevel,
                                                  opt.epochId),
               map_location=lambda storage, loc: storage))
normalRefs[0].load_state_dict(
    torch.load('{0}/normalRefs{1}_{2}.pth'.format(opt.modelRoot,
예제 #2
0
# Refine networks, each wrapped in nn.DataParallel over opt.deviceIds and
# held in a single-element list. Construction order (encoder, albedo,
# normal, rough, depth, env) is significant if module init consumes the
# global RNG, so keep it as written.
encoderRefs = [
    nn.DataParallel(models.refineEncoder(), device_ids=opt.deviceIds)]
albedoRefs = [
    nn.DataParallel(models.refineDecoder(mode=0), device_ids=opt.deviceIds)]
normalRefs = [
    nn.DataParallel(models.refineDecoder(mode=1), device_ids=opt.deviceIds)]
roughRefs = [
    nn.DataParallel(models.refineDecoder(mode=2), device_ids=opt.deviceIds)]
depthRefs = [
    nn.DataParallel(models.refineDecoder(mode=3), device_ids=opt.deviceIds)]
envRefs = [
    nn.DataParallel(models.refineEnvDecoder(), device_ids=opt.deviceIds)]

# Rendering layer, configured with the GPU id and the CUDA on/off flag
# taken from the command-line options.
renderLayer = models.renderingLayer(gpuId=opt.gpuId, isCuda=opt.cuda)

# Global illumination: two independent globalIllumination networks
# ("1to2" and "2to3"), built left to right in that order.
globIllu1to2, globIllu2to3 = (models.globalIllumination(),
                              models.globalIllumination())
#########################################

#########################################
# Load weight of network
# Global-illumination weights are read from
# '<opt.modelRootGlob>/globIllu1to2_<opt.epochIdGlob>.pth'.
# This script can run with opt.cuda disabled, so storages are remapped to
# CPU: without map_location, a checkpoint saved from a GPU would be
# restored onto that GPU and fail on a CPU-only host. load_state_dict
# copies the loaded tensors into the module's existing parameters, so the
# module's device placement is unaffected. The lambda form is accepted by
# all PyTorch versions, including those predating string map_location.
globIllu1to2.load_state_dict(
    torch.load('{0}/globIllu1to2_{1}.pth'.format(opt.modelRootGlob,
                                                 opt.epochIdGlob),
               map_location=lambda storage, loc: storage))
globIllu2to3.load_state_dict(
    torch.load('{0}/globIllu2to3_{1}.pth'.format(opt.modelRootGlob,