# Load Style Target
style_img = Image.open(opts.style_image).convert('RGB')
with torch.no_grad():
    style_img_tensor = transforms.Compose([
        transforms.Resize(opts.image_size * 2),
        transforms.ToTensor(),
        tensor_normalizer()
    ])(style_img).unsqueeze(0)
    style_img_tensor = style_img_tensor.to(device)
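
# Note: `tensor_normalizer` above is defined earlier in this script; a minimal
# sketch, assuming standard ImageNet normalization statistics, would be:
#
#     def tensor_normalizer():
#         return transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                     std=[0.229, 0.224, 0.225])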

# Precalculate the Gram matrices of the style image
# (torch.no_grad() replaces the old `volatile` flag:
#  http://pytorch.org/docs/master/notes/autograd.html#volatile)
with torch.no_grad():
    style_loss_features = loss_network(style_img_tensor)
    gram_style = [gram_matrix(y) for y in style_loss_features]
print('VGG-19 layers used for the style loss:', style_loss_features._fields)
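
# Note: `gram_matrix` is assumed to be the usual batched Gram computation over
# flattened feature maps, roughly:
#
#     def gram_matrix(y):
#         b, ch, h, w = y.size()
#         features = y.view(b, ch, h * w)
#         return features.bmm(features.transpose(1, 2)) / (ch * h * w)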

# for i in range(len(style_loss_features)):
#     tmp = style_loss_features[i].cpu().numpy()
#     print(i, np.mean(tmp), np.std(tmp))

# for i in range(len(style_loss_features)):
#     print(i, gram_style[i].numel(), gram_style[i].size())

# Train the Transformer
torch.set_default_tensor_type('torch.FloatTensor')
mse_loss = torch.nn.MSELoss()
# l1_loss = torch.nn.L1Loss()
def train(transformer, loss_network, gram_style, gram_matrix, train_loader,
          content_weight, regularization, style_weights, log_interval,
          optimizer, device, steps, base_steps=0):
    transformer.train()
    count = 0
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_reg_loss = 0.
    while True:
        for x, _ in train_loader:
            count += 1
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)

            # Content target: run the (detached) input through the loss
            # network without tracking gradients.
            with torch.no_grad():
                xc = x.detach()
                features_xc = loss_network(xc)
                f_xc_c = features_xc[2]

            features_y = loss_network(y)

            content_loss = content_weight * mse_loss(features_y[2], f_xc_c)

            # Total variation regularization: L1 penalty on differences
            # between horizontally and vertically adjacent output pixels,
            # which encourages spatially smooth images.
            reg_loss = regularization * (
                torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            style_loss = 0.
            for l, weight in enumerate(style_weights):
                gram_s = gram_style[l]
                gram_y = gram_matrix(features_y[l])
                style_loss += float(weight) * mse_loss(
                    gram_y, gram_s.expand_as(gram_y))

            total_loss = content_loss + style_loss + reg_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()
            agg_reg_loss += reg_loss.item()

            if count % log_interval == 0:
                mesg = "{} [{}/{}] content: {:.2f}  style: {:.2f}  reg: {:.2f} total: {:.6f}".format(
                    time.ctime(), count, steps,
                    agg_content_loss / log_interval,
                    agg_style_loss / log_interval, agg_reg_loss / log_interval,
                    (agg_content_loss + agg_style_loss + agg_reg_loss) /
                    log_interval)
                print(mesg)
                agg_content_loss = 0.
                agg_style_loss = 0.
                agg_reg_loss = 0.
                transformer.eval()
                with torch.no_grad():
                    y = transformer(x)
                save_debug_image(
                    x, y,
                    "./fast-neural-style/debug_{}/{}.png".format(
                        opts.style_name, base_steps + count))
                transformer.train()

            if count >= steps:
                return
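
# A hypothetical invocation of the parameterized trainer above (all values
# are illustrative only):
#
#     optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-3)
#     train(transformer, loss_network, gram_style, gram_matrix, train_loader,
#           content_weight=1.0, regularization=1e-6,
#           style_weights=[1e5] * len(gram_style), log_interval=100,
#           optimizer=optimizer, device=device, steps=10000)

# The variant below adds a noise-based stability loss; the uppercase
# hyperparameters (CONTENT_WEIGHT, REGULARIZATION, STYLE_WEIGHTS, NOISE_P,
# NOISE_STD, NOISE_WEIGHT, LOG_INTERVAL) are module-level constants defined
# elsewhere in the script.
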
def train(steps, base_steps=0):
    transformer.train()
    count = 0
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_reg_loss = 0.
    agg_stable_loss = 0.
    while True:
        for x, _ in train_loader:
            count += 1
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
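            # Randomly corrupt a fraction NOISE_P of the input elements with
            # Gaussian noise (std NOISE_STD, clamped to [-1, 1]); the
            # transformer should produce a similar output for this noisy copy.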
            with torch.no_grad():
                mask = torch.bernoulli(
                    torch.ones_like(x, device=device, dtype=torch.float) *
                    NOISE_P)
                noise = torch.normal(
                    torch.zeros_like(x),
                    torch.ones_like(x, device=device, dtype=torch.float) *
                    NOISE_STD).clamp(-1, 1)
                # print((noise * mask).sum())
            y_noise = transformer(x + noise * mask)

            with torch.no_grad():
                xc = x.detach()
                features_xc = loss_network(xc)

            features_y = loss_network(y)

            f_xc_c = features_xc[2]

            content_loss = CONTENT_WEIGHT * mse_loss(features_y[2], f_xc_c)

            reg_loss = REGULARIZATION * (
                torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            style_loss = 0.
            for l, weight in enumerate(STYLE_WEIGHTS):
                gram_s = gram_style[l]
                gram_y = gram_matrix(features_y[l])
                style_loss += float(weight) * mse_loss(
                    gram_y, gram_s.expand_as(gram_y))

            # Stability loss: the output for the noisy input should match the
            # (detached) output for the clean input, so gradients flow only
            # through the noisy branch.
            stability_loss = NOISE_WEIGHT * mse_loss(y_noise.view(-1),
                                                     y.view(-1).detach())

            total_loss = content_loss + style_loss + reg_loss + stability_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()
            agg_reg_loss += reg_loss.item()
            agg_stable_loss += stability_loss.item()

            if count % LOG_INTERVAL == 0:
                mesg = "{} [{}/{}] content: {:.2f}  style: {:.2f}  reg: {:.2f} stable: {:.2f} total: {:.6f}".format(
                    time.ctime(), count, steps,
                    agg_content_loss / LOG_INTERVAL,
                    agg_style_loss / LOG_INTERVAL, agg_reg_loss / LOG_INTERVAL,
                    agg_stable_loss / LOG_INTERVAL,
                    (agg_content_loss + agg_style_loss + agg_reg_loss +
                     agg_stable_loss) / LOG_INTERVAL)
                print(mesg)
                agg_content_loss = 0.
                agg_style_loss = 0.
                agg_reg_loss = 0.
                agg_stable_loss = 0.
                transformer.eval()
                with torch.no_grad():
                    y = transformer(x)
                save_debug_image(x, y, y_noise.detach(),
                                 "../debug/{}.png".format(base_steps + count))
                transformer.train()

            if count >= steps:
                return
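
# A hypothetical driver for the stability-loss trainer (names and step counts
# are illustrative only):
#
#     train(steps=20000)
#     torch.save(transformer.state_dict(), "transformer_stable.pth")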