Example #1
def evalByTrain(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    # evaluation dataset (different from the training dataset)
    train_dataset = datasets.ImageFolder(args.eval_dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet()
    state_dict = torch.load(args.model)
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]
    transformer.load_state_dict(state_dict)
    transformer.to(device)

    #use loaded model instead
    #transformer = TransformerNet().to(device)

    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # eval once
    for e in range(1):
        transformer.eval()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            #optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            #total_loss.backward()
            #optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

        mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
            time.ctime(), e + 1, count, len(train_dataset),
            agg_content_loss / (batch_id + 1), agg_style_loss / (batch_id + 1),
            (agg_content_loss + agg_style_loss) / (batch_id + 1))
        print(mesg)
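Every example here leans on utils.gram_matrix and utils.normalize_batch without showing them. A minimal sketch of what these helpers typically look like, following the reference PyTorch fast-neural-style implementation (the repos above may define them slightly differently):

import torch

def gram_matrix(y):
    # Gram matrix of a batch of feature maps: (b, ch, h, w) -> (b, ch, ch)
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    return features.bmm(features_t) / (ch * h * w)

def normalize_batch(batch):
    # assumes the batch is in [0, 255]; rescale to [0, 1] and apply the ImageNet mean/std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch.div_(255.0)
    return (batch - mean) / std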
Example #2
def train():
    device = torch.device("cuda")

    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])

    train_dataset = datasets.ImageFolder(dataset_path, transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style = Image.open(style_image_path)
    style = transform(style)
    style = style.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    #model_fcrn = FCRN_for_transfer(batch_size = batch_size , requires_grad=False).to(device)
    #model_fcrn_par = torch.load(FCRN_path)
    #start_epoch = model_fcrn_par['epoch']
    #model_fcrn.load_state_dict(model_fcrn_par['state_dict'])
    #print("=> loaded checkpoint '{}' (epoch {})".format(FCRN_path, model_fcrn_par['epoch']))

    for e in range(epochs):
        transformer.train()
        agg_content_loss = 0.
        #agg_depth_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            #depth_y = model_fcrn(y)
            #depth_x = model_fcrn(x)

            content_loss = content_weight * mse_loss(features_y.relu2_2,
                                                     features_x.relu2_2)
            #depth_loss = depth_weight*mse_loss(depth_y,depth_x)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= style_weight

            #total_loss = content_loss + depth_loss + style_loss
            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            #agg_depth_loss += depth_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if checkpoint_model_dir is not None and (
                    batch_id + 1) % checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            content_weight) + "_" + str(style_weight) + ".model"
    save_model_path = os.path.join(save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
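The ".model" file saved above is later consumed by a separate stylize step. A minimal inference sketch under the same conventions as Examples #1 and #4 (TransformerNet and utils come from the same repo; content_image_path, model_path and output_path are placeholder names, not defined anywhere above):

import re
import torch
from PIL import Image
from torchvision import transforms

def stylize(content_image_path, model_path, output_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    content_image = Image.open(content_image_path).convert("RGB")
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image).unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(model_path)
        # drop deprecated InstanceNorm running_* keys, as in Example #1
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device).eval()
        output = style_model(content_image).cpu()
    # utils.save_image(filename, data) as in the reference implementation
    utils.save_image(output_path, output[0])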
Example #3
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([transforms.Scale(args.image_size),
                                    transforms.CenterCrop(args.image_size),
                                    transforms.ToTensor(),
                                    transforms.Lambda(lambda x: x.mul(255))])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)

    transformer = TransformerNet()
    if (args.premodel != ""):
        transformer.load_state_dict(torch.load(args.premodel))
        print("load pretrain model:"+args.premodel)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style = style.repeat(args.batch_size, 1, 1, 1)
    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]


    hori=0 
    writer = SummaryWriter(args.logdir,comment=args.logdir)
    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        agg_cate_loss = 0.
        agg_cam_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()
            y = transformer(x)  
            xc = Variable(x.data.clone(), volatile=True)
            #print(y.size()) #(4L, 3L, 224L, 224L)

            
            # Calculate focus loss and category loss
            y_cam = utils.depreprocess_batch(y)
            y_cam = utils.subtract_mean_std_batch(y_cam) 
            
            xc_cam = utils.depreprocess_batch(xc)
            xc_cam = utils.subtract_mean_std_batch(xc_cam)
            

            del features_blobs[:]
            logit_x = net(xc_cam)
            logit_y = net(y_cam)
            
            label=[]
            cam_loss = 0
            for i in range(len(xc_cam)):
                h_x = F.softmax(logit_x[i])
                probs_x, idx_x = h_x.data.sort(0, True)
                label.append(idx_x[0])
                
                h_y = F.softmax(logit_y[i])
                probs_y, idx_y = h_y.data.sort(0, True)
                
                x_cam = returnCAM(features_blobs[0][i], weight_softmax, idx_x[0])
                x_cam = Variable(x_cam.data,requires_grad = False)
 
                y_cam = returnCAM(features_blobs[1][i], weight_softmax, idx_y[0])
                
                cam_loss += mse_loss(y_cam, x_cam)
            
            #the focus loss
            cam_loss *= 80
            #the category loss
            label = Variable(torch.LongTensor(label),requires_grad = False).cuda()
            cate_loss = 10000 * torch.nn.CrossEntropyLoss()(logit_y,label)
         
         

           
            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            #f_xc_c = Variable(features_xc[1].data, requires_grad=False)
            #content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)


            f_xc_c = Variable(features_xc[2].data, requires_grad=False)
            content_loss = args.content_weight * mse_loss(features_y[2], f_xc_c)
            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(gram_y, gram_s[:n_batch, :, :])
            # sum all four losses and backpropagate
            total_loss = style_loss + content_loss  + cam_loss + cate_loss
            total_loss.backward()
            optimizer.step()

            #something for display
            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]
            agg_cate_loss += cate_loss.data[0]
            agg_cam_loss += cam_loss.data[0]
            
            writer.add_scalar("Loss_Cont", agg_content_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_Style", agg_style_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_CAM", agg_cam_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_Cate", agg_cate_loss / (batch_id + 1), hori)
            hori += 1
            
            if (batch_id + 1) % args.log_interval == 0:
               mesg = "{}Epoch{}:[{}/{}] content:{:.2f} style:{:.2f} cate:{:.2f} cam:{:.2f}  total:{:.2f}".format(
                    time.strftime("%a %H:%M:%S"),e + 1, count, len(train_dataset),
                                 agg_content_loss / (batch_id + 1),
                                 agg_style_loss / (batch_id + 1),
                                 agg_cate_loss / (batch_id + 1),
                                 agg_cam_loss / (batch_id + 1),
                                 (agg_content_loss + agg_style_loss + agg_cate_loss + agg_cam_loss ) / (batch_id + 1)
               )
               print(mesg)
               
            if (batch_id + 1) % 2500 == 0:    
                transformer.eval()
                transformer.cpu()
                save_model_filename = "epoch_" + str(e+1) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
                    args.content_weight) + "_" + str(args.style_weight) + ".model"
                save_model_path = os.path.join(args.save_model_dir, save_model_filename)
                torch.save(transformer.state_dict(), save_model_path)
                transformer.cuda()
                transformer.train()
                print("saved at ",count)
    
    
    
    
    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    
    writer.close()
    print("\nDone, trained model saved at", save_model_path)
Example #4
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
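Accessing features_y.relu2_2 above works because Vgg16 returns a namedtuple of intermediate activations. A sketch of the usual definition, following the reference fast-neural-style implementation (the repos above may slice VGG differently):

from collections import namedtuple
import torch
from torchvision import models

class Vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)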
Example #5
def train(args):
    # device object that torch.Tensor will be allocated on (CPU or GPU)
    device = torch.device("cuda" if args.cuda else "cpu")
    # initialize the random seed
    np.random.seed(args.seed)
    # seed the CPU random number generator
    torch.manual_seed(args.seed)
    """
        将多个transform组合起来使用
    """
    transform = transforms.Compose([
        # 重新设定大小
        transforms.Resize(args.image_size),
        # 将给定的Image进行中心切割
        transforms.CenterCrop(args.image_size),
        # 把Image转成张量Tensor格式,大小范围为[0,1]
        transforms.ToTensor(),
        # 使用lambd作为转换器
        transforms.Lambda(lambda x: x.mul(255))
    ])
    # 使用ImageFolder数据加载器,传入数据集的路径
    # transform:一个函数,原始图片作为输入,返回一个转换后的图片
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    # 把上一步做成的数据集放入Data.DataLoader中,可以生成一个迭代器
    # batch_size:int,每个batch加载多少样本
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)
    # move the TransformerNet model to the device
    transformer = TransformerNet().to(device)
    # use Adam as the optimizer
    optimizer = Adam(transformer.parameters(), args.lr)
    # mean squared error loss
    mse_loss = torch.nn.MSELoss()
    # move the (frozen) Vgg16 model to the device
    vgg = Vgg16(requires_grad=False).to(device)
    # preprocessing for the style image
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    # load the style image
    style = utils.load_image(args.style_image, size=args.style_size)
    # preprocess the style image
    style = style_transform(style)
    # repeat(*sizes) tiles the tensor along the given dimensions
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    # normalize the style batch and extract VGG features
    features_style = vgg(utils.normalize_batch(style))
    # compute the Gram matrix of each style feature map
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # training loop
    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            # zero the gradients, i.e. reset d(loss)/d(weight) to 0
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)
            # content loss
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)
            # style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight
            # total loss
            total_loss = content_loss + style_loss
            # backpropagate
            total_loss.backward()
            # update the parameters
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # print progress; args.log_interval is one of the arguments defined at the top
            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)
            # save an intermediate checkpoint of the trained style model
            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #6
                        type=str,
                        required=True,
                        help="Path to checkpoint model")
    args = parser.parse_args()
    print(args)

    os.makedirs("images/outputs", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = style_transform()

    # Define model and load model checkpoint
    transformer = TransformerNet().to(device)
    transformer.load_state_dict(torch.load(args.checkpoint_model))
    transformer.eval()

    stylized_frames = []
    for frame in tqdm.tqdm(extract_frames(args.video_path),
                           desc="Processing frames"):
        # Prepare input frame
        image_tensor = Variable(transform(frame)).to(device).unsqueeze(0)
        # Stylize image
        with torch.no_grad():
            stylized_image = transformer(image_tensor)
        # Add to frames
        stylized_frames += [deprocess(stylized_image)]

    # Create video from frames
    video_name = args.video_path.split("/")[-1].split(".")[0]
    writer = skvideo.io.FFmpegWriter(f"new-{video_name}.mp4")
Example #7
# load TransformerNet parameters from a checkpoint file
if resume_TransformerNet_from_file:
    if os.path.isfile(TransformerNet_path):
        print("=> loading checkpoint '{}'".format(TransformerNet_path))
        TransformerNet_par = torch.load(TransformerNet_path)
        for k in list(TransformerNet_par.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del TransformerNet_par[k]
        model_TransformerNet.load_state_dict(TransformerNet_par)
        print("=> loaded checkpoint '{}'".format(TransformerNet_path))
    else:
        print("=> no checkpoint found at '{}'".format(TransformerNet_path))

model_FCRN.eval()
model_TransformerNet.eval()
with torch.no_grad():
    device = torch.device("cuda")

    input = content_image.resize((304, 228), Image.ANTIALIAS)
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        #transforms.Lambda(lambda x: x.mul(255))
    ])
    input_for_style = input_transform(input)
    input_for_depth = transforms.ToTensor()(input)
    input_for_style = input_for_style.unsqueeze(0).to(device)
    input_for_depth = input_for_depth.unsqueeze(0).to(device)

    output_depth = model_FCRN(input_for_depth)
    output_style = model_TransformerNet(input_for_style)
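Example #7 stops after computing output_style. A minimal post-processing sketch (assumption: since the mul(255) Lambda above is commented out, the network output is treated as lying roughly in [0, 1]; torchvision's save_image is used here rather than the repo's own utils):

from torchvision.utils import save_image

# clamp to the displayable range and write the stylized frame to disk
save_image(output_style.clamp(0, 1).cpu(), "stylized.png")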
Example #8
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size), # the shorter side is resized to match image_size
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(), # to tensor [0,1]
        transforms.Lambda(lambda x: x.mul(255)) # convert back to [0, 255]
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) # to provide a batch loader

    
    style_image = [f for f in os.listdir(args.style_image)]
    style_num = len(style_image)
    print(style_num)

    transformer = TransformerNet(style_num=style_num).to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.Resize(args.style_size), 
        transforms.CenterCrop(args.style_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    style_batch = []

    for i in range(style_num):
        style = utils.load_image(args.style_image + style_image[i], size=args.style_size)
        style = style_transform(style)
        style_batch.append(style)

    style = torch.stack(style_batch).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            
            if n_batch < args.batch_size:
                break # skip to the next epoch when not enough images are left in the last batch of the current epoch

            count += n_batch
            optimizer.zero_grad() # initialize with zero gradients

            batch_style_id = [i % style_num for i in range(count-n_batch, count)]
            y = transformer(x.to(device), style_id = batch_style_id)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y.to(device))
            features_x = vgg(x.to(device))
            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[batch_style_id, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_').replace(':', '') + "_" + str(int(
        args.content_weight)) + "_" + str(int(args.style_weight)) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
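Several of these examples call utils.load_image. A sketch of the typical helper from the reference implementation (actual repos may differ; Image.ANTIALIAS assumes an older Pillow release):

from PIL import Image

def load_image(filename, size=None, scale=None):
    img = Image.open(filename).convert("RGB")
    if size is not None:
        img = img.resize((size, size), Image.ANTIALIAS)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)),
                         Image.ANTIALIAS)
    return img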
Example #9
def train(args):
    serialNumFile = "serialNum.txt"
    serial = 0
    if os.path.isfile(serialNumFile):
        with open(serialNumFile, "r") as t:
            serial = int(t.read())

    serial += 1
    with open(serialNumFile, "w") as t:
        t.write(str(serial))

    if args.mysql:
        cnx = mysql.connector.connect(user='******',
                                      database='midburn',
                                      password='******')
        cursor = cnx.cursor()
    location = args.dataset.split("/")
    if location[-1] == "":
        location = location[-2]
    else:
        location = location[-1]
    save_model_filename = str(serial) + "_" + extractName(
        args.style_image) + "_" + str(args.epochs) + "_" + str(
            int(args.content_weight)) + "_" + str(int(
                args.style_weight)) + "_size_" + str(
                    args.image_size) + "_dataset_" + str(location) + ".model"
    print(save_model_filename)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    m_epoch = 0
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        #kwargs = {'num_workers': 0, 'pin_memory': False}
        kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        kwargs = {}

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              **kwargs)

    transformer = TransformerNet()
    #transformer = ResNeXtNet()
    transformer_type = transformer.__class__.__name__
    optimizer = Adam(transformer.parameters(), args.lr)
    if args.l1:
        loss_criterion = torch.nn.L1Loss()
    else:
        loss_criterion = torch.nn.MSELoss()
    loss_type = loss_criterion.__class__.__name__

    if args.visdom:
        vis = VisdomLinePlotter("Style Transfer: " + transformer_type)
    else:
        vis = None

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    if args.model is not None:
        transformer.load_state_dict(torch.load(args.model))
        save_model_filename = save_model_filename + "@@@@@@" + str(
            int(getEpoch(args.model)) + int(args.epochs))
        m_epoch += int(getEpoch(args.model))
        print("loaded model\n")

    for param in vgg.parameters():
        param.requires_grad = False

    with torch.no_grad():
        style = utils.tensor_load_rgbimage(args.style_image,
                                           size=args.style_size)
        style = style.repeat(args.batch_size, 1, 1, 1)
        style = utils.preprocess_batch(style)
        if args.cuda:
            style = style.cuda()

        style = utils.subtract_imagenet_mean_batch(style)
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]
        del features_style
        del style

    # TODO: scheduler and style-loss criterion unused at the moment
    scheduler = StepLR(optimizer, step_size=15000 // args.batch_size)
    style_loss_criterion = torch.nn.CosineSimilarity()
    total_count = 0

    if args.mysql:
        q1 = ("REPLACE INTO `images`(`name`) VALUES ('" + args.style_image +
              "')")
        cursor.execute(q1)
        cnx.commit()
        imgId = cursor.lastrowid

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0

        for batch_id, (x, _) in enumerate(train_loader):

            n_batch = len(x)
            count += n_batch
            total_count += n_batch
            optimizer.zero_grad()
            x = utils.preprocess_batch(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(x)

            features_y = vgg(y)
            f_xc_c = vgg.content_features(xc)

            content_loss = args.content_weight * loss_criterion(
                features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += loss_criterion(gram_y, gram_s[:n_batch, :, :])
                #style_loss -= style_loss_criterion(gram_y, gram_s[:n_batch, :, :])

            style_loss *= args.style_weight
            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()
            # TODO: enable
            #scheduler.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                if args.mysql:
                    q1 = (
                        "REPLACE INTO `statistics`(`imgId`,`epoch`, `iteration_id`, `content_loss`, `style_loss`, `loss`) VALUES ("
                        + str(imgId) + "," + str(int(e) + m_epoch) + "," +
                        str(batch_id) + "," + str(agg_content_loss /
                                                  (batch_id + 1)) + "," +
                        str(agg_style_loss / (batch_id + 1)) + "," + str(
                            (agg_content_loss + agg_style_loss) /
                            (batch_id + 1)) + ")")
                    cursor.execute(q1)
                    cnx.commit()
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}\n".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                sys.stdout.flush()
                print(mesg)
            if vis is not None:
                vis.plot(loss_type, "Content Loss", total_count,
                         content_loss.item())
                vis.plot(loss_type, "Style Loss", total_count,
                         style_loss.item())
                vis.plot(loss_type, "Total Loss", total_count,
                         total_loss.item())

    # save model
    transformer.eval()
    transformer.cpu()

    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #10
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    # log content and style weight parameters
    if hvd.rank() == 0:
        run.log('content_weight', float(args.content_weight))
        run.log('style_weight', float(args.style_weight))

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_dataset = datasets.ImageFolder(args.dataset, transform)

    # Horovod: partition dataset among workers using DistributedSampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              sampler=train_sampler,
                              **kwargs)

    transformer = TransformerNet().to(device)

    # Horovod: broadcast parameters from rank 0 to all other processes
    hvd.broadcast_parameters(transformer.state_dict(), root_rank=0)
    # Horovod: scale learning rate by the number of GPUs
    optimizer = Adam(transformer.parameters(), args.lr * hvd.size())
    # Horovod: wrap optimizer with DistributedOptimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=transformer.named_parameters())
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    print("starting training...")
    for e in range(args.epochs):
        print("epoch {}...".format(e))
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                avg_content_loss = agg_content_loss / (batch_id + 1)
                avg_style_loss = agg_style_loss / (batch_id + 1)
                avg_total_loss = (agg_content_loss +
                                  agg_style_loss) / (batch_id + 1)
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_sampler),
                    avg_content_loss, avg_style_loss, avg_total_loss)
                print(mesg)

                # log the losses to the run history
                run.log('avg_content_loss', float(avg_content_loss))
                run.log('avg_style_loss', float(avg_style_loss))
                run.log('avg_total_loss', float(avg_total_loss))

            if hvd.rank() == 0 and args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    if hvd.rank() == 0:
        transformer.eval().cpu()
        if args.export_to_onnx:
            # export model to ONNX format
            dummy_input = torch.randn(1, 3, 1024, 1024, device='cpu')
            save_model_path = os.path.join(args.save_model_dir,
                                           '{}.onnx'.format(args.model_name))
            torch.onnx.export(transformer, dummy_input, save_model_path)
        else:
            save_model_path = os.path.join(args.save_model_dir,
                                           '{}.pth'.format(args.model_name))
            torch.save(transformer.state_dict(), save_model_path)

        print("\nDone, trained model saved at", save_model_path)
Example #11
def train(args):
    device = "cuda"
    np.random.seed(args.seed)
    # load path of train images
    train_images = os.listdir(args.dataset)
    train_images = [
        image for image in train_images if not image.endswith("txt")
    ]
    random.shuffle(train_images)
    images_num = len(train_images)
    print("dataset size: %d" % images_num)
    # Initialize transformer net, optimizer, and loss function
    transformer = TransformerNet().to("cuda")

    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = flow.nn.MSELoss()

    if args.load_checkpoint_dir is not None:
        state_dict = flow.load(args.load_checkpoint_dir)
        transformer.load_state_dict(state_dict)
        print("successfully load checkpoint from " + args.load_checkpoint_dir)

    # load pretrained vgg16
    if args.vgg == "vgg19":
        vgg = vgg19(pretrained=True)
    else:
        vgg = vgg16(pretrained=True)
    vgg = VGG_WITH_FEATURES(vgg.features, requires_grad=False)
    vgg.to("cuda")

    style_image = utils.load_image(args.style_image)
    style_image_recover = recover_image(style_image)
    features_style = vgg(
        utils.normalize_batch(flow.Tensor(style_image).to("cuda")))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.0
        agg_style_loss = 0.0
        count = 0
        for i in range(images_num):
            image = load_image("%s/%s" % (args.dataset, train_images[i]))
            n_batch = 1
            count += n_batch

            x_gpu = flow.tensor(image, requires_grad=True).to("cuda")
            y_origin = transformer(x_gpu)

            x_gpu = utils.normalize_batch(x_gpu)
            y = utils.normalize_batch(y_origin)

            features_x = vgg(x_gpu)
            features_y = vgg(y)
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)
            style_loss = 0.0
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            agg_content_loss += content_loss.numpy()
            agg_style_loss += style_loss.numpy()
            if (i + 1) % args.log_interval == 0:
                if args.style_log_dir is not None:
                    y_recover = recover_image(y_origin.numpy())
                    image_recover = recover_image(image)
                    result = np.concatenate(
                        (style_image_recover, image_recover), axis=1)
                    result = np.concatenate((result, y_recover), axis=1)
                    cv2.imwrite(args.style_log_dir + str(i + 1) + ".jpg",
                                result)
                    print(args.style_log_dir + str(i + 1) + ".jpg" + " saved")
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(),
                    e + 1,
                    count,
                    images_num,
                    agg_content_loss / (i + 1),
                    agg_style_loss / (i + 1),
                    (agg_content_loss + agg_style_loss) / (i + 1),
                )
                print(mesg)

            if (args.checkpoint_model_dir is not None
                    and (i + 1) % args.checkpoint_interval == 0):
                transformer.eval()
                ckpt_model_filename = ("CW_" + str(int(args.content_weight)) +
                                       "_lr_" + str(args.lr) + "ckpt_epoch" +
                                       str(e) + "_" + str(i + 1))
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                flow.save(transformer.state_dict(), ckpt_model_path)
                transformer.train()

    # save model
    transformer.eval()
    save_model_filename = ("CW_" + str(args.content_weight) + "_lr_" +
                           str(args.lr) + "sketch_epoch_" + str(args.epochs) +
                           "_" + str(time.ctime()).replace(" ", "_") + "_" +
                           str(args.content_weight) + "_" +
                           str(args.style_weight))
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    flow.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
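Example #11 targets OneFlow rather than PyTorch; its imports are not shown. Presumably something along these lines (the module paths are assumptions inferred from the calls used above):

import oneflow as flow                        # flow.tensor, flow.load, flow.save, flow.nn.MSELoss
from oneflow.optim import Adam                # optimizer used above
from flowvision.models import vgg16, vgg19    # pretrained VGG backbones (assumed source)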
Example #12
def train(args):
    log(json.dumps({"type": "status_update", "status": "Setting up training"}))
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    log(json.dumps({"type": "status_update", "status": "Loading dataset"}))

    train_dataset = datasets.ImageFolder(args.dataset, transform)
    log(
        json.dumps({
            "type": "dataset_info",
            "dataset_length": len(train_dataset) * args.epochs
        }))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    log(json.dumps({"type": "status_update", "status": "Dataset loaded"}))

    log(json.dumps({"type": "status_update", "status": "Loading image"}))

    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    log(json.dumps({"type": "status_update", "status": "Image loaded"}))

    log(json.dumps({"type": "status_update", "status": "Training setup done"}))

    progress_count = 0

    log(json.dumps({"type": "status_update", "status": "Starting training"}))

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            log(json.dumps({
                "type": "training_progress",
                "progress": str(progress_count),
                "percent": str(round(
                    progress_count / (len(train_dataset) * args.epochs) * 100, 2))
            }))

            progress_count = progress_count + args.batch_size

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                if args.name is None:
                    ckpt_model_filename = str(
                        os.path.normpath(os.path.basename(
                            args.style_image))[0:int(
                                os.path.
                                normpath(os.path.basename(args.style_image)).
                                rfind("."))]) + "_" + str(batch_id +
                                                          1) + ".pth"
                else:
                    ckpt_model_filename = str(
                        args.name) + "_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    log(json.dumps({"type": "status_update", "status": "training done"}))

    # save model
    log(json.dumps({"type": "status_update", "status": "saving model"}))
    transformer.eval().cpu()
    if args.name is None:
        save_model_filename = str(
            os.path.normpath(os.path.basename(args.style_image))[0:int(
                os.path.normpath(os.path.basename(args.style_image)).rfind(".")
            )]) + ".pth"
    else:
        save_model_filename = str(args.name + ".pth")

    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    log(json.dumps({"type": "status_update", "status": "model saved"}))
Example #13
def train(args):
    """
    Trains the models
    :param args: parameters
    :return: saves the model and checkpoints
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    image_transform = transforms.Compose([
        transforms.Resize(
            args.image_size),  # the shorter side is resized to match image_size
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),  # to tensor [0,1]
        transforms.Lambda(lambda x: x.mul(255))  # convert back to [0, 255]
    ])
    train_dataset = datasets.ImageFolder(args.dataset, image_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)  # to provide a batch loader

    style_image = [f for f in os.listdir(args.style_image)]
    style_num = len(style_image)
    print(style_num)

    transformer = TransformerNet(style_number=style_num).to(device)
    adam_optimizer = Adam(transformer.parameters(), learning_rate)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.Resize(args.style_size),
        transforms.CenterCrop(args.style_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    style_batch = []

    for i in range(style_num):
        if ".ipynb" not in style_image[i]:
            style = utils.load_image(args.style_image + style_image[i],
                                     size=args.style_size)
            style = style_transform(style)
            print(style.shape, style_image[i])
            style_batch.append(style)

    style = torch.stack(style_batch).to(device)
    # print("After stack")
    features_style = vgg(utils.normalize_batch(style))
    # print("After feature style")
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # print("starting epochs")
    for e in range(args.epochs):
        with open('/home/sbanda/Fall20-DL-CG/Project3/log.txt', 'a') as reader:
            reader.write("Epoch " + str(e) + ":->\n")
        transformer.train()
        aggregate_content_loss = 0.
        aggregate_style_loss = 0.
        counter = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            print(batch_id)
            if n_batch < args.batch_size:
                break

            counter += n_batch
            # Initialize gradients to zero
            adam_optimizer.zero_grad()

            batch_style_id = [
                i % style_num for i in range(counter - n_batch, counter)
            ]
            y = transformer(x.to(device), style_id=batch_style_id)

            x = utils.normalize_batch(x)
            y = utils.normalize_batch(y)

            features_x = vgg(x.to(device))
            features_y = vgg(y.to(device))

            content_loss = content_weight * mse_loss(features_y.relu2_2,
                                                     features_x.relu2_2)

            style_loss = 0.
            for feature_y, gm_style in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(feature_y)
                style_loss += mse_loss(gm_y, gm_style[batch_style_id, :, :])

            style_loss *= style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            adam_optimizer.step()

            aggregate_content_loss += content_loss.item()
            aggregate_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, counter, len(train_dataset),
                    aggregate_content_loss / (batch_id + 1),
                    aggregate_style_loss / (batch_id + 1),
                    (aggregate_content_loss + aggregate_style_loss) /
                    (batch_id + 1))
                with open('/home/sbanda/Fall20-DL-CG/Project3/log.txt',
                          'a') as reader:
                    reader.write(mesg + "\n")
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_').replace(':', '') + "_" + str(
            int(content_weight)) + "_" + str(int(style_weight)) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #14
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    training_set = np.loadtxt(args.dataset, dtype=np.float32)
    training_set_size = training_set.shape[1]
    num_batch = int(training_set_size / args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = np.loadtxt(args.style_image, dtype=np.float32)
    style = style.reshape((1, 1, args.style_size_x, args.style_size_y))
    style = torch.from_numpy(style)
    style = style.repeat(args.batch_size, 3, 1, 1)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # Hard data
    if args.hard_data:
        hard_data = np.loadtxt(args.hard_data_file)
        # if not isinstance(hard_data[0], list):
        #     hard_data = [hard_data]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        # for batch_id, (x, _) in enumerate(train_loader):
        for batch_id in range(num_batch):
            x = training_set[:, batch_id * args.batch_size:(batch_id + 1) *
                             args.batch_size]
            n_batch = x.shape[1]
            count += n_batch
            x = x.transpose()
            x = x.reshape((n_batch, 1, args.image_size_x, args.image_size_y))

            # plt.imshow(x[0,:,:,:].squeeze(0))
            # plt.show()
            x = torch.from_numpy(x).float()

            optimizer.zero_grad()

            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            if args.hard_data:
                hard_data_loss = 0
                num_hard_data = 0
                for hd in hard_data:
                    # hd is (x, y, value); cast the pixel indices to int since np.loadtxt returns floats
                    hard_data_loss += args.hard_data_weight * (
                        y[:, 0, int(hd[1]), int(hd[0])] -
                        hd[2] * 255.0).norm()**2 / n_batch
                    num_hard_data += 1
                hard_data_loss /= num_hard_data

            y = y.repeat(1, 3, 1, 1)
            # x = Variable(utils.preprocess_batch(x))

            # xc = x.data.clone()
            # xc = xc.repeat(1, 3, 1, 1)
            # xc = Variable(xc, volatile=True)

            y = utils.subtract_imagenet_mean_batch(y)
            # xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            # features_xc = vgg(xc)

            # f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            # content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            # total_loss = content_loss + style_loss

            total_loss = style_loss

            if args.hard_data:
                total_loss += hard_data_loss

            total_loss.backward()
            optimizer.step()

            # agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                if args.hard_data:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\thard_data: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, num_batch,
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        hard_data_loss.data[0],
                        (agg_content_loss + agg_style_loss) / (batch_id + 1))
                else:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, num_batch,
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #15
0
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=0,
                              pin_memory=True)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        print("Number of Data : {}".format(len(train_dataset)))
        print("Number of Batch : {}".format(len(train_loader)))
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        agg_tv_loss = 0.
        count = 0
        for batch_id, (x, _) in tqdm(enumerate(train_loader)):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            # Content Loss
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # Style Loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            # Total Variance Loss
            tv_loss = 1e-7 * (
                torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            total_loss = content_loss + style_loss + tv_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()
            agg_tv_loss += tv_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                ctime = datetime.today().strftime('%Y.%m.%d %H:%M')
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttv: {:.6f}\ttotal: {:.6f}".format(
                    ctime, e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    agg_tv_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss + agg_tv_loss) /
                    (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    curr_time = datetime.today().strftime("%Y%m%d_%H%M")
    save_model_filename = "epoch_" + str(
        args.epochs) + "_" + curr_time + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # DATA
    # Transform and Dataloader for COCO dataset
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),  # / 255.
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # MODEL
    # Define Image Transformation Network with MSE loss and Adam optimizer
    transformer = TransformerNet().to(device)
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(transformer.parameters(), args.learning_rate)

    # Pretrained VGG
    vgg = VGG16(requires_grad=False).to(device)

    # FEATURES
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    # Load the style image
    style = Image.open(args.style)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # Compute the style features
    features_style = vgg(normalize_batch(style))

    # Loop through VGG style layers to calculate Gram Matrix
    gram_style = [gram_matrix(y) for y in features_style]

    # TRAIN
    for epoch in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.

        for batch_id, (x, _) in tqdm(enumerate(train_loader), unit='batch'):
            x = x.to(device)
            n_batch = len(x)

            optimizer.zero_grad()

            # Parse throught Image Transformation network
            y = transformer(x)
            y = normalize_batch(y)
            x = normalize_batch(x)

            # Parse through VGG layers
            features_y = vgg(y)
            features_x = vgg(x)

            # Calculate content loss
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # Calculate style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # Monitor
            if (batch_id + 1) % args.log_interval == 0:
                tqdm.write('[{}] ({})\t'
                           'content: {:.6f}\t'
                           'style: {:.6f}\t'
                           'total: {:.6f}'.format(
                               epoch + 1, batch_id + 1,
                               agg_content_loss / (batch_id + 1),
                               agg_style_loss / (batch_id + 1),
                               (agg_content_loss + agg_style_loss) /
                               (batch_id + 1)))

            # Checkpoint
            if (batch_id + 1) % args.save_interval == 0:
                # eval mode
                transformer.eval().cpu()
                style_name = args.style.split('/')[-1].split('.')[0]
                checkpoint_file = os.path.join(args.checkpoint_dir,
                                               '{}.pth'.format(style_name))

                tqdm.write('Checkpoint {}'.format(checkpoint_file))
                torch.save(transformer.state_dict(), checkpoint_file)

                # back to train mode
                transformer.to(device).train()
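Once the checkpoint above is written, stylizing a single image mirrors the training-time preprocessing. A hedged usage sketch; the file paths are placeholders, and the TransformerNet import follows the module name used elsewhere in these examples.

# Hypothetical inference sketch for a saved checkpoint (paths are placeholders)
import torch
from PIL import Image
from torchvision import transforms
from transformer_net import TransformerNet

def stylize(checkpoint_path, content_path, output_path, device="cpu"):
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    content = preprocess(Image.open(content_path).convert("RGB")).unsqueeze(0).to(device)

    model = TransformerNet().to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    with torch.no_grad():
        out = model(content).squeeze(0).clamp(0, 255).cpu()
    Image.fromarray(out.permute(1, 2, 0).byte().numpy()).save(output_path)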
Example #17
0
def run_train(args):

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print('running training process...')

    style_image = load_image(args.style_image,
                             mask=False,
                             size=args.image_style_size,
                             scale=args.style_scale,
                             square=True)
    style_image = preprocess(style_image)

    #save_image('style_image.png',style_image)

    cnn = None
    if args.loss_model == 'vgg19':
        cnn = models.vgg19(pretrained=True).features.to(device).eval()
    elif args.loss_model == 'vgg16':
        cnn = models.vgg16(pretrained=True).features.to(device).eval()

    # get tranform net, content losses, style losses
    loss_net, content_losses, style_losses, tv_loss = build_loss_model(
        cnn, args, style_image)
    #print(loss_net)

    #collect space back
    cnn = None
    del cnn

    if args.backend == 'cudnn':
        torch.backends.cudnn.enabled = True

    #this is to define the inchannels of transferm_net
    in_channels = 3

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
    ])

    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transform_net = TransformerNet(in_channels).to(device)
    mse_loss = nn.MSELoss()

    optimizer = optim.Adam(transform_net.parameters(), lr=args.learning_rate)

    iteration = [0]
    while iteration[0] <= args.epochs - 1:
        transform_net.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            stloss = 0.
            ctloss = 0.
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            #stack color_content_masks into x as input
            x, x_ori = x.to(device), x.to(device).clone()
            x = preprocess(x)
            x_ori = preprocess(x_ori)

            #save_image('content_x1.png', x[3].unsqueeze(0))
            #assert 0 == 1

            #forward input to transform_net
            y = transform_net(x)

            #compute pixel loss
            pixloss = 0.
            if args.pixel_weight > 0:
                pixloss = mse_loss(x, y) * args.pixel_weight

            #compute content loss and style loss

            for ctl in content_losses:
                ctl.mode = 'capture'
            for stl in style_losses:
                stl.mode = 'None'
            loss_net(x_ori)

            for ctl in content_losses:
                ctl.mode = 'loss'
            for stl in style_losses:
                stl.mode = 'loss'
            loss_net(y)
            for ctl in content_losses:
                ctloss += mse_loss(ctl.target, ctl.input) * args.content_weight
            for stl in style_losses:
                local_G = gram_matrix(stl.input)
                stloss += mse_loss(local_G, stl.target) * args.style_weight
            if tv_loss is not None:
                tvloss = tv_loss.loss
            else:
                tvloss = 0.

            loss = ctloss + stloss + pixloss  #+ tvloss

            loss.backward()
            optimizer.step()

            agg_content_loss += ctloss.item()
            agg_style_loss += stloss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}, Epoch {}:\t[{}/{}], content: {:.6f}, style: {:.6f}, total: {:.6f}".format(
                    time.ctime(), iteration[0], count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)
            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transform_net.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    iteration[0] + 1) + "_batch_id_" + str(batch_id +
                                                           1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transform_net.state_dict(), ckpt_model_path)
                transform_net.to(device).train()

        iteration[0] += 1

    #save final model
    transform_net.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_content_" + str(
            args.content_weight) + "_style_" + str(
                args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transform_net.state_dict(), save_model_path)

    print("\n training process is Done!, trained model saved at",
          save_model_path)
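Example #17 reads tv_loss.loss from a module built by build_loss_model but never shows its definition. One plausible implementation, offered only as an assumption about how such a layer might work, is a pass-through module that records the total-variation loss of whatever flows through it.

# Hypothetical TVLoss layer (assumption about build_loss_model's internals, not the project's code)
import torch
import torch.nn as nn

class TVLoss(nn.Module):
    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight
        self.loss = 0.0

    def forward(self, x):
        # total variation: absolute differences between neighbouring pixels
        diff_h = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]).sum()
        diff_v = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]).sum()
        self.loss = self.weight * (diff_h + diff_v)
        return x  # pass-through so the layer can sit inside the loss network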
import cv2
import numpy as np
import torch
from imutils import paths
from transformer_net import TransformerNet
from PIL import Image
from torchvision import transforms

model = TransformerNet()
model.load_state_dict(torch.load('checkpoints/GodBearer.pth'))
model.cuda()
model.eval()

trm = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Lambda(lambda x: x * 255)])

cap = cv2.VideoCapture(0)
while True:
    success, img = cap.read()
    img = Image.fromarray(img).resize((512, 512))
    img = np.array(img)
    cv2.imshow("before", img)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(img).resize((256, 256))

    img = trm(img).cuda()
    t_img = model(img.unsqueeze(0)).squeeze(0).cpu()

    t_img /= 255
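    # --- Hedged continuation (not in the original snippet): convert the stylized frame back
    # --- to BGR, display it, and allow 'q' to quit. Window name and exit key are assumptions.
    out = t_img.clamp(0, 1).permute(1, 2, 0).detach().numpy()
    out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
    cv2.imshow("after", out)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()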
Example #19
0
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("Loading data")
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    print "Building the model"
    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    def multiply(loss, weight):
        return loss * weight

    def add(loss1, loss2):
        return loss1 + loss2

    metrics_names = ['Content Loss', 'Style Loss', 'Total Loss']
    with missinglink_project.create_experiment(
        transformer,
        display_name='Style Transfer PyTorch',
        optimizer=optimizer,
        train_data_object=train_loader,
        metrics={metrics_names[0]: multiply, metrics_names[1]: multiply, metrics_names[2]: add}
    ) as experiment:
        (wrapped_content_loss,
         wrapped_style_loss,
         wrapped_total_loss) = [experiment.metrics[metric_name] for metric_name in metrics_names]

        print("Starting to train")
        for e in experiment.epoch_loop(args.epochs):
            transformer.train()
            agg_content_loss = 0.
            agg_style_loss = 0.
            count = 0
            for batch_id, (x, _) in experiment.batch_loop(iterable=train_loader):
                n_batch = len(x)
                count += n_batch
                optimizer.zero_grad()
                x = Variable(x)
                if args.cuda:
                    x = x.cuda()

                y = transformer(x)

                y = utils.normalize_batch(y)
                x = utils.normalize_batch(x)

                features_y = vgg(y)
                features_x = vgg(x)

                content_loss = mse_loss(features_y.relu2_2, features_x.relu2_2)
                content_loss = wrapped_content_loss(content_loss, args.content_weight)

                style_loss = 0.
                for ft_y, gm_s in zip(features_y, gram_style):
                    gm_y = utils.gram_matrix(ft_y)
                    style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
                style_loss = wrapped_style_loss(style_loss, args.style_weight)

                total_loss = wrapped_total_loss(content_loss, style_loss)
                total_loss.backward()
                optimizer.step()

                agg_content_loss += content_loss.data[0]
                agg_style_loss += style_loss.data[0]

                if (batch_id + 1) % args.log_interval == 0:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, len(train_dataset),
                                      agg_content_loss / (batch_id + 1),
                                      agg_style_loss / (batch_id + 1),
                                      (agg_content_loss + agg_style_loss) / (batch_id + 1)
                    )
                    print(mesg)

                if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                    transformer.eval()
                    if args.cuda:
                        transformer.cpu()
                    ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                    ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                    torch.save(transformer.state_dict(), ckpt_model_path)
                    if args.cuda:
                        transformer.cuda()
                    transformer.train()

        # save model
        transformer.eval()
        if args.cuda:
            transformer.cpu()
        save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
        save_model_path = os.path.join(args.save_model_dir, save_model_filename)
        torch.save(transformer.state_dict(), save_model_path)

        print("\nDone, trained model saved at", save_model_path)
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        # utils.RGB2LAB(),
        transforms.ToTensor(),
        # utils.LAB2Tensor(),
    ])
    pert_transform = transforms.Compose([utils.ColorPerturb()])
    trainset = utils.FlatImageFolder(args.dataset, transform, pert_transform)
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=4)
    model = TransformerNet()
    if args.gpus is not None:
        model = nn.DataParallel(model, device_ids=args.gpus)
    else:
        model = nn.DataParallel(model)
    if args.resume:
        state_dict = torch.load(args.resume)
        model.load_state_dict(state_dict)

    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    criterion = nn.MSELoss()

    start_time = datetime.now()

    for e in range(args.epochs):
        model.train()
        count = 0
        acc_loss = 0.0
        for batchi, (pert_img, ori_img) in enumerate(trainloader):
            count += len(pert_img)
            if args.cuda:
                pert_img = pert_img.cuda(non_blocking=True)
                ori_img = ori_img.cuda(non_blocking=True)

            optimizer.zero_grad()

            rec_img = model(pert_img)
            loss = criterion(rec_img, ori_img)
            loss.backward()
            optimizer.step()

            acc_loss += loss.item()
            if (batchi + 1) % args.log_interval == 0:
                mesg = '{}\tEpoch {}: [{}/{}]\ttotal loss: {:.6f}'.format(
                    time.ctime(), e + 1, count, len(trainset),
                    acc_loss / (args.log_interval))
                print(mesg)
                acc_loss = 0.0

        if args.checkpoint_dir and e + 1 != args.epochs:
            model.eval().cpu()
            ckpt_filename = 'ckpt_epoch_' + str(e + 1) + '.pth'
            ckpt_path = osp.join(args.checkpoint_dir, ckpt_filename)
            torch.save(model.state_dict(), ckpt_path)
            model.cuda().train()
            print('Checkpoint model at epoch %d saved' % (e + 1))

    model.eval().cpu()
    if args.save_model_name:
        model_filename = args.save_model_name
    else:
        model_filename = "epoch_" + str(args.epochs) + "_" + str(
            time.ctime()).replace(' ', '_') + ".model"
    model_path = osp.join(args.save_model_dir, model_filename)
    torch.save(model.state_dict(), model_path)

    end_time = datetime.now()

    print('Finished training after %s, trained model saved at %s' %
          (end_time - start_time, model_path))
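Because this example wraps TransformerNet in nn.DataParallel before saving, every key in the checkpoint carries a 'module.' prefix. Loading that file into a bare TransformerNet later needs the prefix stripped; a hedged helper sketch:

# Hypothetical loader for a checkpoint saved from an nn.DataParallel-wrapped model
import torch
from transformer_net import TransformerNet

def load_dataparallel_checkpoint(path):
    state_dict = torch.load(path, map_location='cpu')
    # keys look like 'module.conv1.weight' when saved from nn.DataParallel
    cleaned = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    model = TransformerNet()
    model.load_state_dict(cleaned)
    return model.eval()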
Example #21
0
def NeuralStyle_init(weight_path, alpha):
    model = TransformerNet(alpha)
    model.load_state_dict(torch.load(weight_path))
    model.cuda()
    model.eval()
    return model
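A hedged usage sketch for Example #21; the weight path, alpha value, and input size below are placeholders, not values from the original project.

# Hypothetical usage of NeuralStyle_init (all concrete values are placeholders)
import torch

model = NeuralStyle_init('checkpoints/example_style.pth', alpha=1.0)
frame = torch.rand(1, 3, 256, 256).mul(255).cuda()  # dummy 0-255 input batch
with torch.no_grad():
    stylized = model(frame)
print(stylized.shape)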
Example #22
0
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style = style.repeat(args.batch_size, 1, 1, 1)
    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            xc = Variable(x.data.clone(), volatile=True)

            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            content_loss = args.content_weight * mse_loss(
                features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #23
0
def train(args):
    if not os.path.exists(args.save_model_dir):
        os.makedirs(args.save_model_dir)


    device = torch.device("cuda" if args.is_cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    # print(style.size)
    # ss('yo')
    style = style_transform(style)  # apply the ToTensor / mul(255) style transform
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    # style = style.repeat(2,1,1,1).to(device)
    # print(style.shape)
    # print()
    # ss('ho')
    features_style = vgg(utils.normalize_batch(style))
    # print(features_style.relu4_3.shape)
    # for i in features_style:
    #     print(i.shape)
    # ss('normalize')
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # for i in gram_style:
        # print(i.shape)
    # ss('main: gram style')
    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            # print(n_batch)
            # ss('hi')
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            # print(x.shape)
            # print(x[0,0,0,:])
            # ss('in epoch, batch')
            y = transformer(x)
            # ss('in epoch, batch')
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # if (batch_id + 1) % args.log_interval == 0:
            if True:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)
            if args.is_quickrun:
                if count > 10:
                    break
            # if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
            #     transformer.eval().cpu()
            #     ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
            #     ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
            #     torch.save(transformer.state_dict(), ckpt_model_path)
            #     transformer.to(device).train()
            if (e % 50 == 0) or (e > 400 and e % 10 == 0):
                # utils.save_image(args.save_model_dir+'/imgs/npepoch_{}.png'.format(e), y[0].detach().cpu())
                # torchvision.utils.save_image(y, './imgs/epoch_{}.png'.format(e), normalize=True)
                torchvision.utils.save_image(y, './imgs/before/epoch_{}.png'.format(e), normalize=True)
                y = y.clamp(0, 255)
                torchvision.utils.save_image(y, './imgs/non/epoch_{}.png'.format(e))
                torchvision.utils.save_image(y, './imgs/after/epoch_{}.png'.format(e), normalize=True)
            # ss('yo')
    # save model
    transformer.eval().cpu()

    save_model_filename = "style_"+args.style_name+"_epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #24
0
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 12, 'pin_memory': False}
    else:
        kwargs = {}
    from transform.color_op import Linearize, SRGB2XYZ, XYZ2CIE

    RGB2YUV = transforms.Compose([
        Linearize(),
        SRGB2XYZ(),
        XYZ2CIE()
    ])

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        RGB2YUV,  # already a transforms.Compose instance; don't call it here
        transforms.ToTensor(),
        # transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)

    transformer = TransformerNet(in_channels=2, out_channels=1)  # input: LS, predict: M
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    transformer = nn.DataParallel(transformer)

    if args.cuda:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is requested, but related driver/device is not set properly.")
        transformer.cuda()

    for e in range(args.epochs):
        transformer.train()
        # agg_content_loss = 0.
        # agg_style_loss = 0.
        count = 0
        for batch_id, (imgs, _) in enumerate(train_loader):
            n_batch = len(imgs)
            count += n_batch
            optimizer.zero_grad()
            # Input: first and last channels (L and S)
            x = torch.cat([imgs[:, :1, :, :].clone(), imgs[:, -1:, :, :].clone()], dim=1)
            # Target: middle channel (M)
            gt = imgs[:, 1:2, :, :].clone()

            if args.cuda:
                x = x.cuda()
                gt = gt.cuda()

            y = transformer(x)

            total_loss = mse_loss(y, gt)
            total_loss.backward()
            optimizer.step()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  total_loss.item()
                )
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    os.makedirs(args.save_model_dir, exist_ok=True)
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #25
0
def train(args):
    # make sure each time we train, if args.seed stays the same, then
    # the random number we get is same as last time we train.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))  # 0-1 to 0-255
    ])
    # note the order: give where the images at; load the images and transform; give the batch size
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # TODO: in transformernet
    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # TODO: relus in vgg16
    vgg = Vgg16(requires_grad=False)

    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    # style2 = utils.load_image(args.style_image2, size=args.style_size)
    style = style_transform(style)
    # style2 = style_transform(style2)

    # repeat the style tensor 4 times
    style = style.repeat(args.batch_size, 1, 1, 1)
    # style2 = style2.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()
        # style2 = style2.cuda()


    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    # style_v2 = Variable(style2)
    # style_v2 = utils.normalize_batch(style_v2)
    # features_style2 = vgg(style_v2)
    # to determine style loss, make use of gram matrix
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # gram_style2 = [utils.gram_matrix(y) for y in features_style2]


    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()  # pytorch accumulates gradients, making them zero for each minibatch
            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            # forward pass
            y = transformer(x)  # after transformer - y

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            # TODO: mse_loss of which relu could be modified
            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
                # style_loss += mse_loss(gm_y, gm_s2[:n_batch, :, :])

            style_loss *= args.style_weight

            total_loss = content_loss + style_loss

            # backward pass
            total_loss.backward()  # this simply computes the gradients for each learnable parameters

            # update weights
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                msg = "Epoch "+str(e + 1)+" "+str(count)+"/"+str(len(train_dataset))
                msg += " content loss : "+str(agg_content_loss / (batch_id + 1))
                msg += " style loss : " +str(agg_style_loss / (batch_id + 1))
                msg += " total loss : " +str((agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(msg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval()
                if args.cuda:
                    transformer.cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                if args.cuda:
                    transformer.cuda()
                transformer.train()

    # save model
    transformer.eval()
    if args.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #26
0
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #27
0
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            style_loss = style_loss + mse_loss(gm_y, gm_s[:n_batch, :, :])

        style_loss = style_loss * style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        agg_content_loss = agg_content_loss + content_loss.item()
        agg_style_loss = agg_style_loss + style_loss.item()

        if (batch_id + 1) % log_interval == 0:
            message = '{}\tEpoch {}:\t[{}/{}]\tContent Loss: {:.2f}\tStyle Loss: {:.2f}\tTotal Loss: {:.2f}'.format(
                time.ctime(), epoch + 1, count, len(train_dataset),
                agg_content_loss / (batch_id + 1),
                agg_style_loss / (batch_id + 1),
                (agg_content_loss + agg_style_loss) / (batch_id + 1))

            print(message)

transformer.eval().cpu()
#weights_filename = 'candy.pth'
weights_filename = 'skyscraper_single.pth'
#weights_filename = 'ocean_single.pth'
#weights_filename = 'hot-spring.pth'
#weights_filename = 'desert-sand.pth'
save_model_path = os.path.join(save_model_dir, weights_filename)
torch.save(transformer.state_dict(), save_model_path)