def evaluate():
    """Run one content/style image pair through a saved model and write the output.

    The config object is built from the ``config_path`` attribute returned by
    ``Options().parse()``. It supplies the input image paths and sizes, the
    network width (``ngf``), the parameter file to load (``val_model``) and
    the output path (``output_img``).
    """
    # Use a GPU when MXNet reports at least one; otherwise fall back to CPU 0.
    ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu(0)

    # Configuration and logging setup.
    cli_args = Options().parse()
    cfg = Configs(cli_args.config_path)
    logging.basicConfig(level=logging.INFO)

    # Inputs: the content image is loaded with keep_asp=True, and the style
    # image goes through preprocess_batch before it is fed to the network.
    content = tensor_load_rgbimage(
        cfg.content_image, ctx, size=cfg.val_img_size, keep_asp=True)
    style = preprocess_batch(
        tensor_load_rgbimage(cfg.style_image, ctx, size=cfg.val_style_size))

    # Build the network and restore its parameters onto the chosen device.
    net = Net(ngf=cfg.ngf)
    net.collect_params().load(cfg.val_model, ctx=ctx)

    # Forward pass; only the first element of the output is written out.
    stylized = net(content, style)
    tensor_save_bgrimage(stylized[0], cfg.output_img)
    logging.info("Save img to {}".format(cfg.output_img))
def train():
    """Train the style-transfer network and save its parameters after every epoch.

    The config object is built from the ``config_path`` attribute returned by
    ``Options().parse()``. For each content batch the function:

    1. fetches a style image from the style dataset (indexed by batch id),
    2. extracts VGG features of the style and content inputs,
    3. runs the style model and computes a content loss on ``y_features[1]``
       plus a Gram-matrix style loss summed over every VGG feature level,
    4. back-propagates the combined loss and takes an Adam step.

    Running averages of both losses are logged every ``cfg.log_interval``
    iterations. At the end of each epoch the parameters are written to
    ``cfg.save_model_dir``, which is created (including missing parent
    directories) if it does not exist yet.

    Raises:
        RuntimeError: If MXNet reports no GPU device.
    """
    if mx.context.num_gpus() <= 0:
        raise RuntimeError('There is no GPU device!')
    ctx = mx.gpu()

    # loading configs
    args = Options().parse()
    cfg = Configs(args.config_path)
    # set logging level
    logging.basicConfig(level=logging.INFO)
    # set random seed
    # NOTE(review): only NumPy's RNG is seeded here; MXNet's own RNG is not
    # seeded in this function -- confirm whether that is intended.
    np.random.seed(cfg.seed)

    # build dataset and loader
    content_dataset = ImageFolder(cfg.content_dataset, cfg.img_size, ctx=ctx)
    style_dataset = StyleLoader(cfg.style_dataset, cfg.style_size, ctx=ctx)
    content_loader = gluon.data.DataLoader(content_dataset,
                                           batch_size=cfg.batch_size,
                                           last_batch='discard')

    # Feature extractor used for both loss terms.
    vgg = Vgg16()
    vgg._init_weights(fixed=True, pretrain_path=cfg.vgg_check_point, ctx=ctx)

    # Either continue from a saved parameter file or start from fresh weights.
    style_model = Net(ngf=cfg.ngf)
    if cfg.resume is not None:
        print("Resuming from {} ...".format(cfg.resume))
        style_model.collect_params().load(cfg.resume, ctx=ctx)
    else:
        style_model.initialize(mx.initializer.MSRAPrelu(), ctx=ctx)
    print("Style model:")
    print(style_model)

    # build trainer
    lr_sche = mx.lr_scheduler.FactorScheduler(
        step=170000,
        factor=0.1,
        base_lr=cfg.base_lr
        #warmup_begin_lr=cfg.base_lr/3.0,
        #warmup_steps=300,
    )
    opt = mx.optimizer.Optimizer.create_optimizer('adam', lr_scheduler=lr_sche)
    trainer = gluon.Trainer(style_model.collect_params(), optimizer=opt)

    # Gluon's L2Loss carries a 1/2 factor; the explicit "2 *" below cancels it.
    loss_fn = gluon.loss.L2Loss()

    logging.info("Start training with total {} epoch".format(cfg.total_epoch))
    iteration = 0
    total_time = 0.0
    batches_per_epoch = len(content_loader)
    num_batch = batches_per_epoch * cfg.total_epoch
    for epoch in range(cfg.total_epoch):
        sum_content_loss = 0.0
        sum_style_loss = 0.0
        for batch_id, content_imgs in enumerate(content_loader):
            iteration += 1
            s = time.time()
            style_image = style_dataset.get(batch_id)

            # Style targets: Gram matrices of the VGG features of the style
            # image. A copy is passed to the VGG preprocessing so the tensor
            # fed to the style model is preprocessed separately.
            style_vgg_input = subtract_imagenet_mean_preprocess_batch(
                style_image.copy())
            style_image = preprocess_batch(style_image)
            style_features = vgg(style_vgg_input)
            style_features = [
                style_model.gram.gram_matrix(mx.nd, f) for f in style_features
            ]

            # Content target: the VGG feature at index 1 of the content batch.
            content_vgg_input = subtract_imagenet_mean_preprocess_batch(
                content_imgs.copy())
            content_features = vgg(content_vgg_input)[1]

            with autograd.record():
                y = style_model(content_imgs, style_image)
                y = subtract_imagenet_mean_batch(y)
                y_features = vgg(y)

                content_loss = 2 * cfg.content_weight * loss_fn(
                    y_features[1], content_features)
                style_loss = 0.0
                for y_feature, style_gram in zip(y_features, style_features):
                    gram_y = style_model.gram.gram_matrix(mx.nd, y_feature)
                    _, C, _ = style_gram.shape
                    # Repeat the style Gram matrix along the batch axis so it
                    # lines up with gram_y's leading dimension.
                    gram_s = mx.nd.expand_dims(style_gram, 0).broadcast_to((
                        gram_y.shape[0],
                        1,
                        C,
                        C,
                    ))
                    style_loss = style_loss + 2 * cfg.style_weight * loss_fn(
                        gram_y, gram_s)
                total_loss = content_loss + style_loss
                total_loss.backward()

            trainer.step(cfg.batch_size)
            # Wait for queued operations to finish so the measured time covers
            # the actual computation of this iteration.
            mx.nd.waitall()
            e = time.time()
            total_time += e - s
            # NOTE(review): only element 0 of each loss is accumulated; if the
            # loss holds one value per sample this tracks the first sample of
            # every batch rather than the batch mean -- confirm intent.
            sum_content_loss += content_loss[0]
            sum_style_loss += style_loss[0]
            if iteration % cfg.log_interval == 0:
                # Average seconds per iteration since training started; the
                # ETA extrapolates it over the remaining iterations.
                itera_sec = total_time / iteration
                eta_str = str(
                    datetime.timedelta(seconds=int((num_batch - iteration) *
                                                   itera_sec)))
                mesg = "{} Epoch [{}]:\t[{}/{}]\tTime:{:.2f}s\tETA:{}\tlr:{:.4f}\tcontent: {:.3f}\tstyle: {:.3f}\ttotal: {:.3f}".format(
                    time.strftime("%H:%M:%S",
                                  time.localtime()), epoch + 1, batch_id + 1,
                    batches_per_epoch, itera_sec, eta_str,
                    trainer.optimizer.learning_rate,
                    sum_content_loss.asnumpy()[0] / (batch_id + 1),
                    sum_style_loss.asnumpy()[0] / (batch_id + 1),
                    (sum_content_loss + sum_style_loss).asnumpy()[0] /
                    (batch_id + 1))
                logging.info(mesg)
                ctx.empty_cache()

        # Checkpoint name: epoch number, timestamp (spaces replaced by
        # underscores) and the two loss weights.
        save_model_filename = "Epoch_" + "_".join([
            str(epoch + 1),
            str(time.ctime()).replace(' ', '_'),
            str(cfg.content_weight),
            str(cfg.style_weight),
        ]) + ".params"
        # makedirs (not mkdir) so a nested save directory with missing parents
        # does not abort the run after an epoch of training.
        if not os.path.isdir(cfg.save_model_dir):
            os.makedirs(cfg.save_model_dir)
        save_model_path = os.path.join(cfg.save_model_dir, save_model_filename)
        logging.info("Saving parameters to {}".format(save_model_path))
        style_model.collect_params().save(save_model_path)