コード例 #1
0
ファイル: boost_train.py プロジェクト: 253681319/mxnet
# content
# Content module, built from the input shape, device context and vgg_params.
content_mod = basic.get_content_module("content", dshape, ctx, vgg_params)

# loss
# get_loss_module returns the loss module together with gscale, which is
# indexed once per style entry below to scale the style weight.
loss, gscale = basic.get_loss_module("loss", dshape, ctx, vgg_params)
# Hand every entry of style_array to the loss module under the name
# "target_gram_<index>".
extra_args = {
    "target_gram_%d" % idx: target for idx, target in enumerate(style_array)
}
# NOTE(review): if `loss` is an mx.mod.Module, the two positional True flags
# are allow_missing and force_init -- confirm against basic.get_loss_module.
loss.set_params(extra_args, {}, True, True)
# One single-element array per style entry holding style_weight / gscale[idx],
# followed by one holding content_weight.
grad_array = [
    mx.nd.ones((1,), ctx) * (float(style_weight) / gscale[idx])
    for idx in range(len(style_array))
]
grad_array.append(mx.nd.ones((1,), ctx) * (float(content_weight)))

# generator
# Four generator modules named g0..g3: gen_v4 builds the first and last,
# gen_v3 builds the two in between.
gens = [
    gen_v4.get_module("g0", dshape, ctx),
    gen_v3.get_module("g1", dshape, ctx),
    gen_v3.get_module("g2", dshape, ctx),
    gen_v4.get_module("g3", dshape, ctx),
]
# Initialise the optimizer of every generator with the same SGD settings.
for gen in gens:
    gen.init_optimizer(
        optimizer='sgd',
        optimizer_params=dict(
            learning_rate=1e-4,
            momentum=0.9,
            wd=5e-3,
            clip_gradient=5.0
        )
    )


# tv-loss
def get_tv_grad_executor(img, ctx, tv_weight):
コード例 #2
0
import numpy as np

# import basic
import data_processing
import gen_v3
import gen_v4

# Shape passed to every generator module and to the content-image
# preprocessing (presumably batch, channels, height, width -- TODO confirm).
dshape = (1, 3, 480, 640)
# Product of the dshape entries as a float; not referenced in the visible
# part of this script.
clip_norm = 1.0 * np.prod(dshape)
# Root directory of the saved parameters; generator i loads from
# "<model_prefix><i>/".
model_prefix = "./model/"
# Device context: GPU 0.
ctx = mx.gpu(0)

# generator
# Four generator modules named g0..g3: gen_v4 builds the first and last,
# gen_v3 builds the two in between.
gens = [
    gen_v4.get_module("g0", dshape, ctx),
    gen_v3.get_module("g1", dshape, ctx),
    gen_v3.get_module("g2", dshape, ctx),
    gen_v4.get_module("g3", dshape, ctx),
]
# Load each generator's saved parameters from its own numbered sub-directory.
# The path is built from model_prefix so the model directory is defined in
# exactly one place; with the value above it expands to the same
# "./model/<i>/v3_0002-0026000.params" paths as before.
for i, gen in enumerate(gens):
    gen.load_params("%s%d/v3_0002-0026000.params" % (model_prefix, i))

# Preprocess the content image; the second argument is the smaller of the
# last two dshape entries (presumably the target short-edge size -- TODO
# confirm against data_processing.PreprocessContentImage).
content_np = data_processing.PreprocessContentImage(
    "../IMG_4343.jpg", min(dshape[2:]), dshape)

# data[0] is the preprocessed input; each generator appends its output.
data = [mx.nd.array(content_np)]
# Chain the generators: each one is fed the most recent entry of `data`,
# and its first output is both kept (copied to the CPU context) and written
# to "out_<i>.jpg".
for i, gen in enumerate(gens):
    batch = mx.io.DataBatch([data[-1]], [0])
    gen.forward(batch, is_train=False)
    new_img = gen.get_outputs()[0]
    data.append(new_img.copyto(mx.cpu()))
    data_processing.SaveImage(new_img.asnumpy(), "out_%d.jpg" % i)