def train(training_config):
    """Train a ``TransformerNet`` against the style of a single image.

    Keys read from ``training_config`` in this function: ``learning_rate``,
    ``style_images_dir``, ``style_image_name``, ``height``, ``num_of_epochs``,
    ``content_weight``, ``style_weight``, ``tv_weight`` and
    ``model_binaries_path``. The dict is also handed to
    ``get_training_data_loader`` and ``get_training_metadata``.

    Every 20 batches the function prints the losses, displays and logs the
    current content/stylized batches to TensorBoard (``runs/real`` and
    ``runs/stylized``) and saves a checkpoint containing the training
    metadata, the model ``state_dict`` and the optimizer state.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    perceptual_loss_net = PerceptualLossNet(requires_grad=False).to(device)
    transformer_net = TransformerNet().train().to(device)

    optimizer = torch.optim.Adam(transformer_net.parameters(), lr=training_config['learning_rate'])
    style_img_path = os.path.join(training_config['style_images_dir'], training_config['style_image_name'])
    style_image = prepare_img(style_img_path, training_config['height'], device)
    features_for_style_image = perceptual_loss_net(style_image)
    style_targets = features_for_style_image.style_features

    train_loader = get_training_data_loader(training_config)
    writer_stylized = SummaryWriter('runs/stylized')
    writer_real = SummaryWriter('runs/real')

    # Global optimisation-step counter used as the TensorBoard x-axis.
    step = 0
    for epoch in range(training_config['num_of_epochs']):
        for batch_id, (content_batch, _) in enumerate(train_loader):
            # step1: Calculate the content targets
            content_batch = content_batch.to(device)
            features_for_content_batch = perceptual_loss_net(content_batch)
            content_targets = features_for_content_batch.content_features

            # step2: Calculate the content_features and style_features for stylized_content_batch
            stylized_content_batch = transformer_net(content_batch)
            features_for_stylized_content_batch = perceptual_loss_net(stylized_content_batch)
            content_features = features_for_stylized_content_batch.content_features
            style_features = features_for_stylized_content_batch.style_features

            # step3: calculate content_loss, style_loss, tv_loss
            content_loss = training_config['content_weight'] * calculate_content_loss(content_targets, content_features)
            style_loss = training_config['style_weight'] * calculate_style_loss(style_targets, style_features)
            tv_loss = training_config['tv_weight'] * calculate_tv_loss(stylized_content_batch)

            # step4: calculate total_loss
            total_loss = content_loss + style_loss + tv_loss

            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()  # clear gradients for the next round

            if batch_id % 20 == 0:
                print(f'Epochs: {epoch} | total_loss: {total_loss.item():12.4f} | content_loss: {content_loss.item():12.4f} | style_loss: {style_loss.item():12.4f} | tv_loss: {tv_loss.item():12.4f}')

                display_image(content_batch)
                display_image(stylized_content_batch)

                with torch.no_grad():
                    image_grid_stylized = torchvision.utils.make_grid(stylized_content_batch, normalize=True)
                    image_grid_real = torchvision.utils.make_grid(content_batch, normalize=True)
                    writer_stylized.add_image('stylized image', image_grid_stylized, global_step=step)
                    writer_real.add_image('original image', image_grid_real, global_step=step)

                # The checkpoint file name does not contain the epoch or
                # batch, so each save overwrites the previous one.
                training_state = get_training_metadata(training_config)
                training_state['state_dict'] = transformer_net.state_dict()
                training_state['optimizer_state'] = optimizer.state_dict()
                model_name = f'style_{training_config["style_image_name"].split(".")[0]}_datapoints_{training_state["num_of_datapoints"]}_cw_{str(training_config["content_weight"])}_sw_{str(training_config["style_weight"])}_tw_{str(training_config["tv_weight"])}.pth'
                torch.save(training_state, os.path.join(training_config["model_binaries_path"], model_name))

            # Previously ``step`` stayed at 0, so every logged image landed
            # on the same TensorBoard step.
            step += 1

    # Flush pending events to disk.
    writer_stylized.close()
    writer_real.close()
Ejemplo n.º 2
0
 def __init__(self):
     """Create the transformer network, the feature extractor and the optimizer."""
     # Hyper-parameters.
     self.lr = 1e-3
     self.epoches = 3
     self.content_weight = 1e5
     # Author's note on style weights: style1 -> 1e10, style2 -> 1e11.
     self.style_weight = 1e10

     # Network being trained.
     transformer = TransformerNet()
     self.transformer = transformer.to(device)

     # Feature extractor, switched to eval mode before being moved to `device`.
     feature_net = FeatureNet().eval()
     self.extracter = feature_net.to(device)

     self.optimizer = t.optim.Adam(self.transformer.parameters(), lr=self.lr)
Ejemplo n.º 3
0
def stylize(**kwargs):
    """Stylize one image with a trained ``TransformerNet`` and save it.

    Keyword arguments override attributes on a fresh ``Config`` instance.
    The attributes read here are ``content_path`` (input image),
    ``model_path`` (saved ``state_dict``) and ``result_path`` (output file).
    """
    opt = Config()

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    device = t.device('cuda' if t.cuda.is_available() else 'cpu')

    # Image preprocessing: the tensor is scaled by 255 on the way in and
    # divided by 255 again before saving.
    content_transform = tv.transforms.Compose(
        [tv.transforms.ToTensor(),
         tv.transforms.Lambda(lambda x: x * 255)])
    # The context manager closes the underlying file once it is decoded.
    with PIL.Image.open(opt.content_path) as image_file:
        content_image = content_transform(image_file)
    content_image = content_image.unsqueeze(0).to(device).detach()

    # Model: weights are loaded on the CPU first, then moved to `device`.
    style_model = TransformerNet().eval()
    style_model.load_state_dict(
        t.load(opt.model_path, map_location=t.device('cpu')))
    style_model.to(device)

    # Style transfer and saving. No gradients are needed for inference, so
    # skip building the autograd graph.
    with t.no_grad():
        output = style_model(content_image)
    output_data = output.cpu()[0]
    tv.utils.save_image((output_data / 255).clamp(min=0, max=1),
                        opt.result_path)
def prepare_model(inference_config, device):
    """Return a ``TransformerNet`` on ``device`` with trained weights, in eval mode.

    The checkpoint is read from ``inference_config['model_binaries_path']`` /
    ``inference_config['model_name']`` and must hold the weights under the
    ``'state_dict'`` key.
    """
    stylization_model = TransformerNet()
    stylization_model = stylization_model.to(device)

    checkpoint_path = os.path.join(
        inference_config['model_binaries_path'],
        inference_config['model_name'],
    )
    # Deserialize onto the CPU; load_state_dict copies into the model's parameters.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    stylization_model.load_state_dict(checkpoint['state_dict'], strict=True)

    stylization_model.eval()
    return stylization_model
Ejemplo n.º 5
0
def stylize(args):
    """Stylize ``args.content_image`` and write the result to ``args.output_image``.

    ``args.model`` is either an ``.onnx`` file, handled by
    ``stylize_onnx_caffe2``, or a saved ``TransformerNet`` state dict. In the
    second case a truthy ``args.export_onnx`` exports the model to that path
    and uses the value returned by the export as the output.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    loaded_image = utils.load_image(args.content_image,
                                    scale=args.content_scale)
    to_scaled_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    content_image = to_scaled_tensor(loaded_image).unsqueeze(0).to(device)

    # ONNX models take a separate inference path.
    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
        utils.save_image(args.output_image, output[0])
        return

    with torch.no_grad():
        style_model = TransformerNet().eval()
        state_dict = torch.load(args.model)

        # Drop the deprecated InstanceNorm running_* entries that older
        # checkpoints still carry.
        stale_keys = [
            key for key in state_dict
            if re.search(r'in\d+\.running_(mean|var)$', key)
        ]
        for key in stale_keys:
            del state_dict[key]

        style_model.load_state_dict(state_dict)
        style_model.to(device)

        if not args.export_onnx:
            output = style_model(content_image).cpu()
        else:
            assert args.export_onnx.endswith(
                ".onnx"), "Export model file should end with .onnx"
            exported = torch.onnx._export(style_model, content_image,
                                          args.export_onnx)
            output = exported.cpu()

    utils.save_image(args.output_image, output[0])
Ejemplo n.º 6
0
def stylize(imgname):
    """Stylize the image at ``imgname`` with the model of experiment ``exp_num``.

    Weights are read from ``./models/model_exp_<exp_num>.pt``. The result is
    saved next to the input as ``<name>_stylized_<exp_num><ext>``.

    Relies on the module-level names ``exp_num``, ``use_cuda``,
    ``style_transform``, ``load_image`` and ``save_image``.
    """
    transformer_name = './models/model_exp_{}.pt'.format(exp_num)
    content_image = load_image(imgname)
    # NOTE(review): the content image goes through `style_transform`, as in
    # the original code; confirm this is intended rather than `content_transform`.
    content_image = style_transform(content_image)
    content_image = content_image.unsqueeze(0)
    content_image = Variable(content_image)

    transformer = TransformerNet()
    # Load onto the CPU so a checkpoint saved from a GPU also works on a
    # CPU-only machine; the model is moved to the GPU below if requested.
    model_dict = torch.load(transformer_name, map_location='cpu')
    transformer.load_state_dict(model_dict)
    # NOTE(review): the model is left in its default mode, as before. Call
    # `.eval()` here if TransformerNet has mode-dependent layers.

    if use_cuda:
        transformer.cuda()
        content_image = content_image.cuda()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        o = transformer(content_image)
    y = o.cpu()[0]

    name, backend = os.path.splitext(os.path.basename(imgname))
    save_style_name = os.path.join(
        os.path.dirname(imgname),
        '{}_stylized_{}{}'.format(name, exp_num, backend))
    save_image(save_style_name, y)
Ejemplo n.º 7
0
def test():
    """Evaluate the saved per-fold ``TransformerNet`` classifiers on the test splits.

    For each of the 5 folds the checkpoint
    ``SAVE_DIR + 'data1_false_biobert_<fold>_599'`` is loaded and evaluated one
    sample at a time on ``DATA_path + 'test_<fold>.csv'``. Per-fold accuracy
    and the mean over the folds are printed.
    """
    pretrain, code_id, word_id = source_prepare()
    print('%d different words, %d different codes' % (len(word_id), len(code_id)), flush=True)

    num_folds = 5
    Acc = 0.
    for fold in range(num_folds):
        Net = TransformerNet(Pretrain_type, pretrain, Max_seq_len, Embedding_size, Inner_hid_size, len(code_id), D_k,
                             D_v, dropout_ratio=Dropout, num_layers=Num_layers, num_head=Num_head, Freeze=Freeze_emb).cuda()
        Net.load_state_dict(torch.load(SAVE_DIR + 'data1_false_biobert_' + str(fold) + '_599'))
        Net.eval()

        test_file = DATA_path + 'test_' + str(fold) + '.csv'
        test_data = tokenizer(word_id, code_id, test_file, pretrain_type=Pretrain_type)

        print('Fold %d: %d test data' % (fold, len(test_data.data)))
        print('max length: %d' % test_data.max_length, flush=True)

        test_data.reset_epoch()
        test_correct = 0
        i = 0
        # Pure evaluation: disable autograd so no graph is built per sample.
        with torch.no_grad():
            while not test_data.epoch_finish:
                seq, label, seq_length, mask, seq_pos, standard_emb = test_data.get_batch(1)
                results = Net(seq, seq_pos, standard_emb)
                _, idx = results.max(1)
                test_correct += len((idx == label).nonzero())
                i += 1
        # With batch size 1, the number of batches must equal the number of samples.
        assert i == len(test_data.data)
        test_accuracy = float(test_correct) / float(i)

        print('[fold %d] test: %d correct, %.4f accuracy' % (fold, test_correct, test_accuracy), flush=True)

        Acc += test_accuracy

        # Release the fold's data before the next fold is tokenized.
        del test_data
        gc.collect()

    print('finial validation accuracy: %.4f' % (Acc / num_folds), flush=True)
Ejemplo n.º 8
0
def stylize(args):
    """Stylize ``args.content_image`` on the CPU and save it to ``args.output_image``.

    ``args.model`` is the path of a saved ``TransformerNet`` state dict and
    ``args.content_scale`` is forwarded to ``utils.load_image``.
    """
    device = torch.device("cpu")

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        # Inference: put the network in eval mode.
        style_model = TransformerNet().eval()
        # This function is CPU-only, so remap tensors that were saved from
        # a GPU instead of failing on machines without CUDA.
        state_dict = torch.load(args.model, map_location=device)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image)

    utils.save_image(args.output_image, output[0])
Ejemplo n.º 9
0
class TransModel():
    """Style-transfer trainer around a ``TransformerNet`` and a ``FeatureNet``.

    The transformer is the network being optimised with Adam; the feature
    extractor is kept in eval mode and supplies the features used by the
    content and style losses. Both live on the module-level ``device``.
    """

    def __init__(self):
        """Create the networks, the optimizer and the loss weights."""
        self.transformer = TransformerNet().to(device)
        self.extracter = FeatureNet().eval().to(device)
        self.lr = 1e-3
        self.optimizer = t.optim.Adam(self.transformer.parameters(), self.lr)
        self.content_weight = 1e5
        # Author's note on style weights: style1 -> 1e10, style2 -> 1e11.
        self.style_weight = 1e10
        self.epoches = 3

    def train(self, dataloader, style):
        """Train the transformer on ``dataloader`` against the ``style`` batch.

        ``dataloader`` must yield ``(images, _)`` pairs. At most the first
        101 batches of every epoch are used. After each epoch a three-panel
        figure (input, style, output) is written to ``./dump/<epoch>.png``.
        """
        with t.no_grad():
            features_style = self.extracter(style)
            # Original note: B x 64 x 64 channel-similarity matrices.
            gram_style = [gram_matrix(feat) for feat in features_style]

        for epoch in range(self.epoches):
            for i, (x, _) in tqdm.tqdm(enumerate(dataloader)):

                self.optimizer.zero_grad()

                x = x.to(device)
                y = self.transformer(x)

                # Only batches 0..100 are trained on. The forward pass for
                # batch 101 still runs, so the figure below shows that batch.
                if i > 100:
                    break
                features_y = self.extracter(y)
                features_x = self.extracter(x)

                # Content loss is computed from the relu2 features.
                content_loss = self.content_weight * F.mse_loss(
                    features_y.relu2, features_x.relu2)

                # Pair output/style Gram matrices with zip: the previous index
                # loop reused `i`, shadowing the batch index of the outer loop.
                gram_y = [gram_matrix(feat) for feat in features_y]
                style_loss = 0
                for g_y, g_s in zip(gram_y, gram_style):
                    style_loss += F.mse_loss(g_y, g_s.expand_as(g_y))
                style_loss = self.style_weight * style_loss

                loss = content_loss + style_loss
                loss.backward()
                self.optimizer.step()

            # Always true for now; raise the modulus to dump less often.
            if epoch % 1 == 0:
                plt.figure()

                # NOTE(review): index 1 is used for the input and the output,
                # so the batch must hold at least two images.
                origin_img = x.data.cpu()[1].permute(1, 2, 0)
                style_img = style.cpu()[0].permute(1, 2, 0)
                new_img = y.data.cpu()[1].permute(1, 2, 0)

                plt.subplot(131)
                plt.imshow(origin_img)
                plt.xticks([]), plt.yticks([])
                plt.subplot(132)
                plt.imshow(style_img)
                plt.xticks([]), plt.yticks([])
                plt.subplot(133)
                plt.imshow(new_img)
                plt.xticks([]), plt.yticks([])

                plt.savefig('./dump/' + str(epoch) + '.png')
                plt.close()

    def stylise(self, style, content, save_path):
        """Save a figure with the content image, the style image and the result.

        ``content`` is passed to the transformer as given, so the caller is
        responsible for placing it on the transformer's device.
        """
        plt.figure()

        origin_img = content.cpu()[0].permute(1, 2, 0)
        style_img = style.cpu()[0].permute(1, 2, 0)

        # Inference only: no autograd graph is needed for the forward pass.
        with t.no_grad():
            y = self.transformer(content)
        new_img = y.cpu()[0].permute(1, 2, 0)

        plt.subplot(131)
        plt.imshow(origin_img)
        plt.xticks([]), plt.yticks([])
        plt.subplot(132)
        plt.imshow(style_img)
        plt.xticks([]), plt.yticks([])
        plt.subplot(133)
        plt.imshow(new_img)
        plt.xticks([]), plt.yticks([])

        plt.savefig(save_path)
        plt.close()
Ejemplo n.º 10
0
def train():
    """Train one ``TransformerNet`` classifier per fold with 5-fold validation.

    For every fold the model is trained for ``Epoch`` epochs on
    ``DATA_path + 'trainsplit_<fold>.csv'``. Every ``Val_every`` epochs the
    training and validation accuracy are printed and the weights are saved to
    ``SAVE_DIR + 'data1_false_biobert_<fold>_<epoch>'``. Every
    ``LR_decay_epoch`` epochs the learning rate is decayed through
    ``adjust_learning_rate``. The mean of the last validation accuracy of each
    fold is printed at the end.
    """
    pretrain, code_id, word_id = source_prepare()
    print('%d different words, %d different codes' % (len(word_id), len(code_id)), flush=True)

    criterion = nn.CrossEntropyLoss()

    Acc = 0.
    for fold in range(5):
        Net = TransformerNet(Pretrain_type, pretrain, Max_seq_len, Embedding_size, Inner_hid_size, len(code_id), D_k,
                             D_v, dropout_ratio=Dropout, num_layers=Num_layers, num_head=Num_head, Freeze=Freeze_emb).cuda()
        optimizer = optim.Adam(Net.parameters(), lr=Learning_rate, eps=1e-08, weight_decay=Weight_decay)

        train_file = DATA_path + 'trainsplit_' + str(fold) + '.csv'
        val_file = DATA_path + 'val_' + str(fold) + '.csv'
        train_data = tokenizer(word_id, code_id, train_file, pretrain_type=Pretrain_type)
        val_data = tokenizer(word_id, code_id, val_file, pretrain_type=Pretrain_type)

        print('Fold %d: %d training data, %d validation data' % (fold, len(train_data.data), len(val_data.data)))
        print('max length: %d %d' % (train_data.max_length, val_data.max_length), flush=True)

        for e in range(Epoch):
            train_data.reset_epoch()
            Net.train()
            while not train_data.epoch_finish:
                optimizer.zero_grad()
                seq, label, seq_length, mask, seq_pos, standard_emb = train_data.get_batch(Batch_size)
                results = Net(seq, seq_pos, standard_emb)
                loss = criterion(results, label)
                loss.backward()
                optimizer.step()

            if (e + 1) % Val_every == 0:
                Net.eval()

                # Evaluation only: disable autograd so the forward passes
                # below do not build graphs.
                with torch.no_grad():
                    train_data.reset_epoch()
                    train_correct = 0
                    i = 0
                    while not train_data.epoch_finish:
                        seq, label, seq_length, mask, seq_pos, standard_emb = train_data.get_batch(Batch_size)
                        results = Net(seq, seq_pos, standard_emb)
                        _, idx = results.max(1)
                        train_correct += len((idx == label).nonzero())
                        i += Batch_size
                    # NOTE(review): `i` counts full batches, so it can exceed
                    # the dataset size if the last batch is short. Validation
                    # below divides by the dataset size instead.
                    train_accuracy = float(train_correct) / float(i)

                    val_data.reset_epoch()
                    val_correct = 0
                    i = 0
                    while not val_data.epoch_finish:
                        seq, label, seq_length, mask, seq_pos, standard_emb = val_data.get_batch(Batch_size)
                        results = Net(seq, seq_pos, standard_emb)
                        _, idx = results.max(1)
                        val_correct += len((idx == label).nonzero())
                        i += Batch_size
                    val_accuracy = float(val_correct) / float(len(val_data.data))

                # `loss` is the loss of the last training batch of this epoch.
                print('[fold %d epoch %d] training loss: %.4f, % d correct, %.4f accuracy;'
                      ' validation: %d correct, %.4f accuracy' %
                      (fold, e, loss.item(), train_correct, train_accuracy, val_correct, val_accuracy), flush=True)

                torch.save(Net.state_dict(), SAVE_DIR + 'data1_false_biobert_' + str(fold) + '_' + str(e))

            if (e + 1) % LR_decay_epoch == 0:
                adjust_learning_rate(optimizer, LR_decay)
                print('learning rate decay!', flush=True)

        # NOTE(review): `val_accuracy` only exists once a validation pass has
        # run, i.e. when Epoch >= Val_every.
        Acc += val_accuracy

        # Release the fold's data before the next fold is tokenized.
        del train_data, val_data
        gc.collect()

    print('finial validation accuracy: %.4f' % (Acc / 5))
Ejemplo n.º 11
0
def train(content_img_name=None, style_img_name=None, features=None):
    """Fit a ``TransformerNet`` to one content image and one style image.

    Args:
        content_img_name: Path handed to ``load_image`` for the content image.
        style_img_name: Path handed to ``load_image`` for the style image.
        features: Callable feature extractor. Its outputs are iterated for the
            style loss and element ``[1]`` is used for the content loss.

    Runs ``iteration_total`` optimisation steps, appends the loss lines to
    ``./logs/log_exp_<exp_num>.txt`` and saves the trained ``state_dict`` to
    ``./models/model_exp_<exp_num>.pt``. Relies on the module-level names
    ``exp_num``, ``iteration_total`` and ``use_cuda``.
    """
    transformer = TransformerNet()

    lr = 0.001
    weight_content = 1e5
    weight_style = 1e10
    optimizer = torch.optim.Adam(transformer.parameters(), lr)
    mse_loss = torch.nn.MSELoss()

    style = load_image(style_img_name)
    style = style_transform(style)
    style = style.unsqueeze(0)
    style_v = Variable(style)
    style_v = normalize_batch(style_v)
    # The style targets are constants; compute them without an autograd graph.
    with torch.no_grad():
        features_style = features(style_v)
        gram_style = [gram_matrix(feat) for feat in features_style]

    transformer.train()
    x = load_image(content_img_name)
    x = content_transform(x)
    x = x.unsqueeze(0)
    x = Variable(x)

    if use_cuda:
        transformer.cuda()
        features.cuda()
        x = x.cuda()
        gram_style = [gram.cuda() for gram in gram_style]

    log_name = './logs/log_exp_{}.txt'.format(exp_num)
    log = []

    def _flush_log():
        """Print the buffered log lines, append them to the log file, clear the buffer."""
        text = ''.join(log)
        print(text)
        with open(log_name, 'a') as f:
            f.write(text)
        log.clear()

    # training
    count = 0
    while count < iteration_total:
        optimizer.zero_grad()

        y = transformer(x)
        y = normalize_batch(y)

        # Normalise a copy. The old code rebound `x` to its normalised version
        # on every pass, so the same image was normalised again each
        # iteration. The clone also protects `x` in case normalize_batch
        # works in place.
        x_normalized = normalize_batch(x.clone())

        features_y = features(y)
        features_x = features(x_normalized)

        loss_content = mse_loss(features_y[1], features_x[1])

        loss_style = 0.
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = gram_matrix(ft_y)
            loss_style = loss_style + mse_loss(gm_y, gm_s)

        total_loss = weight_content * loss_content + weight_style * loss_style
        total_loss.backward()
        optimizer.step()

        # log show
        count += 1
        msg = '{}\titeration: {}\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}\n'.format(
            time.ctime(), count, loss_content.item(), loss_style.item(),
            total_loss.item())
        log.append(msg)
        if count % 50 == 0:
            _flush_log()

    # Write the lines left over when iteration_total is not a multiple of 50.
    if log:
        _flush_log()

    # save model
    transformer.eval()
    if use_cuda:
        transformer.cpu()
    save_model_name = './models/model_exp_{}.pt'.format(exp_num)
    torch.save(transformer.state_dict(), save_model_name)
Ejemplo n.º 12
0
def train(**kwargs):
    """Train a ``TransformerNet`` for style transfer with a Vgg16 loss network.

    Keyword arguments override attributes on a fresh ``Config`` instance. The
    attributes read here are ``env``, ``image_size``, ``data_root``,
    ``batch_size``, ``model_path``, ``lr``, ``style_path``, ``epoches``,
    ``content_weight``, ``style_weight``, ``plot_every`` and ``save_every``.

    Running loss averages and sample images are sent to the visualizer every
    ``plot_every`` batches. The transformer ``state_dict`` is saved to
    ``checkpoints/<batch>_style.pth`` every ``save_every`` batches.
    """
    opt = Config()
    for _k, _v in kwargs.items():
        setattr(opt, _k, _v)

    device = t.device("cuda" if t.cuda.is_available() else "cpu")
    vis = utils.Visualizer(opt.env)

    # Data loading
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Style transformation network
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=t.device('cpu')))
    transformer.to(device)

    # Loss network (Vgg16), frozen
    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    # Optimizer
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Style image data
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    # Gram matrices of the style image: fixed targets, so no graph is needed
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]

    # Running loss sums, reset every `plot_every` batches
    style_loss_avg = 0
    content_loss_avg = 0

    for epoch in range(opt.epoches):
        for ii, (x, _) in tqdm(enumerate(dataloader)):

            # Training step
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_x = vgg(x)
            features_y = vgg(y)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu3_3, features_x.relu3_3)

            # style loss
            # The Gram matrices of the generated image must stay in the
            # autograd graph. They used to be computed under no_grad, which
            # cut the style loss off from the transformer, so only the
            # content loss was being optimised.
            style_loss = 0
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            content_loss_avg += content_loss.item()
            style_loss_avg += style_loss.item()

            if (ii + 1) % opt.plot_every == 0:
                vis.plot('content_loss', content_loss_avg / opt.plot_every)
                vis.plot('style_loss', style_loss_avg / opt.plot_every)
                content_loss_avg = 0
                style_loss_avg = 0
                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

            if (ii + 1) % opt.save_every == 0:
                vis.save([opt.env])
                # NOTE(review): the file name holds only the batch index, so
                # later epochs overwrite checkpoints of earlier ones.
                t.save(transformer.state_dict(),
                       'checkpoints/%s_style.pth' % (ii + 1))
Ejemplo n.º 13
0
def train(args):
    """Train a ``TransformerNet`` on the CPU with a Vgg16 perceptual loss.

    Attributes read from ``args``: ``seed``, ``image_size``, ``dataset``,
    ``batch_size``, ``style_image``, ``style_size``, ``lr``, ``epochs``,
    ``content_weight``, ``style_weight`` and ``save_model_dir``.

    The trained ``state_dict`` is saved to
    ``<save_model_dir>/epoch_<epochs>_<content_weight>_<style_weight>.model``.
    """
    device = torch.device("cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # load dataset
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # load style image
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # load network
    transformer = TransformerNet().to(device)
    vgg = Vgg16(requires_grad=False).to(device)

    # define optimizer and loss function
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # The style targets are constants; compute them without an autograd graph.
    with torch.no_grad():
        features_style = vgg(utils.normalize_batch(style))
        gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            count += len(x)
            optimizer.zero_grad()

            image_original = x.to(device)
            # Feed the batch that was moved to `device`. The raw `x` was used
            # before, which only works while the device is the CPU.
            image_transformed = transformer(image_original)

            image_original = utils.normalize_batch(image_original)
            image_transformed = utils.normalize_batch(image_transformed)

            # extract features for compute content loss
            features_original = vgg(image_original)
            features_transformed = vgg(image_transformed)
            content_loss = args.content_weight * mse_loss(features_transformed.relu3_3, features_original.relu3_3)

            # extract features for compute style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_transformed, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # The last batch may be smaller than batch_size, so trim the
                # repeated style Gram matrices to match.
                style_loss += mse_loss(gm_y, gm_s[:len(x), :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            if (batch_id + 1) % 200 == 0:
                print("Epoch {}:[{}/{}]".format(e + 1, count, len(train_dataset)))

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Ejemplo n.º 14
0
def train(args):
    """Train a ``TransformerNet`` with a VGG or ResNet perceptual loss network.

    Attributes read from ``args``: ``cuda``, ``backbone`` (``"vgg"`` or
    ``"resnet"``), ``seed``, ``image_size``, ``dataset``, ``batch_size``,
    ``lr``, ``style_image``, ``style_size``, ``epochs``, ``content_weight``,
    ``style_weight``, ``log_interval``, ``checkpoint_model_dir``,
    ``checkpoint_interval`` and ``save_model_dir``.

    Intermediate checkpoints are written when ``checkpoint_model_dir`` is not
    ``None``. The final ``state_dict`` is saved under ``save_model_dir``.

    Raises:
        ValueError: If ``args.backbone`` is neither ``"vgg"`` nor ``"resnet"``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    if args.backbone == "vgg":
        content_layer = ['relu_4']
        style_layer = ['relu_2', 'relu_4', 'relu_7', 'relu_10']
    elif args.backbone == "resnet":
        content_layer = ["conv_3"]
        style_layer = ["conv_1", "conv_2", "conv_3", "conv_4"]
    else:
        # Fail here with a clear message instead of an unbound-name error
        # a few lines further down.
        raise ValueError("Unsupported backbone: {!r}".format(args.backbone))

    # Union of both layer lists, without duplicates and in first-seen order.
    total_layer = list(dict.fromkeys(content_layer + style_layer))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    if args.backbone == "vgg":
        loss_model = vgg16().eval().to(device)
    else:
        loss_model = resnet().eval().to(device)

    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    # The style targets are constants; compute them without an autograd graph.
    with torch.no_grad():
        feature_style = loss_model(utils.normalize_batch(style), style_layer)
        gram_style = {
            key: utils.gram_matrix(val)
            for key, val in feature_style.items()
        }

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            feature_y = loss_model(y, total_layer)
            feature_x = loss_model(x, content_layer)

            content_loss = 0.
            for layer in content_layer:
                content_loss += args.content_weight * mse_loss(
                    feature_y[layer], feature_x[layer])

            style_loss = 0.
            for name in style_layer:
                gm_y = utils.gram_matrix(feature_y[name])
                # The last batch may be smaller than batch_size, so trim the
                # repeated style Gram matrices to match.
                style_loss += mse_loss(gm_y, gram_style[name][:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()
            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                # Save from the CPU in eval mode, then restore the training setup.
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)