def evaluate(args):
    content_image = utils.tensor_load_rgbimage(args.content_image,
                                               size=args.content_size,
                                               keep_asp=True)
    content_image = content_image.unsqueeze(0)
    style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style = style.unsqueeze(0)
    style = utils.preprocess_batch(style)

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    style_model = HangSNetV1()
    style_model.load_state_dict(torch.load(args.model))

    if args.cuda:
        style_model.cuda()
        vgg.cuda()
        content_image = content_image.cuda()
        style = style.cuda()

    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    content_image = Variable(utils.preprocess_batch(content_image))
    target = Variable(gram_style[2].data, requires_grad=False)
    style_model.setTarget(target)

    output = style_model(content_image)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
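
These examples all build style targets with utils.gram_matrix, whose body is not shown in this listing. A minimal sketch of the standard normalized Gram computation, assuming the batched (N, C, H, W) features returned by Vgg16:

import torch

def gram_matrix(y):
    # y: (N, C, H, W) VGG feature batch -> (N, C, C) Gram matrices
    n, c, h, w = y.size()
    features = y.view(n, c, h * w)                 # flatten spatial dims
    gram = features.bmm(features.transpose(1, 2))  # batched F @ F^T
    return gram / (c * h * w)                      # normalize by feature count
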
Example 2
def optimize(args):
    """    Gatys et al. CVPR 2017
    ref: Image Style Transfer Using Convolutional Neural Networks
    """
    # load the content and style target
    content_image = utils.tensor_load_rgbimage(args.content_image,
                                               size=args.content_size,
                                               keep_asp=True)
    content_image = content_image.unsqueeze(0)
    content_image = Variable(utils.preprocess_batch(content_image),
                             requires_grad=False)
    content_image = utils.subtract_imagenet_mean_batch(content_image)
    style_image = utils.tensor_load_rgbimage(args.style_image,
                                             size=args.style_size)
    style_image = style_image.unsqueeze(0)
    style_image = Variable(utils.preprocess_batch(style_image),
                           requires_grad=False)
    style_image = utils.subtract_imagenet_mean_batch(style_image)

    # load the pre-trained vgg-16 and extract features
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))
    if args.cuda:
        content_image = content_image.cuda()
        style_image = style_image.cuda()
        vgg.cuda()
    features_content = vgg(content_image)
    f_xc_c = Variable(features_content[1].data, requires_grad=False)
    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # init optimizer
    output = Variable(content_image.data, requires_grad=True)
    optimizer = Adam([output], lr=args.lr)
    mse_loss = torch.nn.MSELoss()
    # optimizing the images
    for e in range(args.iters):
        utils.imagenet_clamp_batch(output, 0, 255)
        optimizer.zero_grad()
        features_y = vgg(output)
        content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)

        style_loss = 0.
        for m in range(len(features_y)):
            gram_y = utils.gram_matrix(features_y[m])
            gram_s = Variable(gram_style[m].data, requires_grad=False)
            style_loss += args.style_weight * mse_loss(gram_y, gram_s)

        total_loss = content_loss + style_loss

        if (e + 1) % args.log_interval == 0:
            print(total_loss.data.cpu().numpy()[0])
        total_loss.backward()

        optimizer.step()
    # save the image
    output = utils.add_imagenet_mean_batch(output)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
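
The mean-shift helpers used above (utils.subtract_imagenet_mean_batch, utils.add_imagenet_mean_batch) are also external. A sketch of plausible bodies, assuming the Caffe-style BGR ImageNet means that match the "vgg16.weight" checkpoint loaded here; the repository's versions may instead shift the batch in place:

IMAGENET_MEAN_BGR = (103.939, 116.779, 123.680)  # assumed BGR channel means

def subtract_imagenet_mean_batch(batch):
    # batch: (N, 3, H, W), BGR order, 0-255 range
    mean = batch.new_tensor(IMAGENET_MEAN_BGR).view(1, 3, 1, 1)
    return batch - mean

def add_imagenet_mean_batch(batch):
    mean = batch.new_tensor(IMAGENET_MEAN_BGR).view(1, 3, 1, 1)
    return batch + mean
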
Example 3
def train_ofb(args):
    train_dataset = dataset.DAVISDataset(args.dataset, use_flow=True)
    train_loader = DataLoader(train_dataset, batch_size=1)

    transformer = transformer_net.TransformerNet(args.pad_type)
    transformer.train()
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))
    vgg.eval()

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        mse_loss.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))
    print("=> Pixel OFB loss weight: %f" % args.time_strength)

    style = utils.preprocess_batch(style)
    if args.cuda: style = style.cuda()
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    train_loader.dataset.reset()
    agg_content_loss = agg_style_loss = agg_pixelofb_loss = 0.
    iters = 0
    elapsed_time = 0
    for batch_id, (x, flow, conf) in enumerate(tqdm(train_loader)):
        x, flow, conf = x[0], flow[0], conf[0]
        iters += 1

        optimizer.zero_grad()
        x = utils.preprocess_batch(x)  # (N, 3, 256, 256)
        if args.cuda:
            x = x.cuda()
            flow = flow.cuda()
            conf = conf.cuda()
        y = transformer(x)  # (N, 3, 256, 256)

        begin_time = time.time()
        warped_y, warped_y_mask = warp(y[1:], flow)
        warped_y = warped_y.detach()
        warped_y_mask *= conf
        pixel_ofb_loss = args.time_strength * weighted_mse(
            y[:-1], warped_y, warped_y_mask)
        pixel_ofb_loss.backward()
        elapsed_time += time.time() - begin_time
        if batch_id > 1000: break
    print(elapsed_time / float(batch_id + 1))
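
train_ofb leans on two helpers that this listing does not define: warp (flow-based backward warping) and weighted_mse. A sketch under common conventions, assuming flow carries (N, 2, H, W) pixel displacements; the actual helpers may differ in flow convention and normalization:

import torch
import torch.nn.functional as F

def warp(x, flow):
    # x: (N, C, H, W); flow: (N, 2, H, W), channel 0 = dx, channel 1 = dy
    n, _, h, w = x.size()
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    base = torch.stack((xs, ys)).float().to(x.device)  # (2, H, W) pixel grid
    coords = base.unsqueeze(0) + flow                  # where to sample from
    gx = 2.0 * coords[:, 0] / max(w - 1, 1) - 1.0      # normalize to [-1, 1]
    gy = 2.0 * coords[:, 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((gx, gy), dim=3)                # (N, H, W, 2)
    warped = F.grid_sample(x, grid, align_corners=True)
    # valid where the sampling point stays inside the source frame
    mask = (grid.abs().max(dim=3)[0] <= 1).float().unsqueeze(1)
    return warped, mask

def weighted_mse(a, b, weight):
    # squared error averaged with per-pixel confidence weights
    return ((a - b) ** 2 * weight).sum() / weight.sum().clamp(min=1e-8)
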
Example 4
def vectorize(args):
    size = args.size
    # vectors = np.zeros((size, size, 2), dtype=np.float32)
    # for y in range(size):
    #     for x in range(size):
    #         xx = float(x - size / 2)
    #         yy = float(y - size / 2)
    #         rsq = xx ** 2 + yy ** 2
    #         if (rsq == 0):
    #             vectors[y, x, 0] = 1
    #             vectors[y, x, 1] = 1
    #         else:
    #             vectors[y, x, 0] = -yy / rsq
    #             vectors[y, x, 1] = xx / rsq
    # vectors = NormalizVectrs(vectors)

    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = Image.open(args.content_image).convert('L')
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    content_image = utils.subtract_imagenet_mean_batch(content_image)
    content_image = content_image.to(device)

    with torch.no_grad():
        vectorize_model = TransformerNet()
        state_dict = torch.load(args.saved_model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        vectorize_model.load_state_dict(state_dict)
        vectorize_model.to(device)
        output = vectorize_model(content_image)

    target = dataset.hdf5_loader(args.target_vector)
    target_transform = transforms.ToTensor()
    target = target_transform(target)
    target = target.unsqueeze(0).to(device)

    cosine_loss = torch.nn.CosineEmbeddingLoss()
    label = torch.ones(1, 1, args.size, args.size).to(device)
    loss = cosine_loss(output, target, label)
    print(loss.item())

    output = output.cpu().clone().numpy()[0].transpose(1, 2, 0)
    output = NormalizVectrs(output)
    lic(output, "output.jpg")

    target = target.cpu().clone().numpy()[0].transpose(1, 2, 0)
    lic(target, "target.jpg")
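
NormalizVectrs (spelled as in the source) is used without a definition. A sketch, assuming it rescales every 2-D vector of an (H, W, 2) field to unit length before line integral convolution:

import numpy as np

def NormalizVectrs(vectors):
    # vectors: (H, W, 2) -> same shape, each vector scaled to unit norm
    norm = np.sqrt((vectors ** 2).sum(axis=-1, keepdims=True))
    return vectors / np.maximum(norm, 1e-8)
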
Example 5
def optimize(args):
    style_image = utils.tensor_load_rgbimage(args.style_image,
                                             size=args.style_size)
    style_image = style_image.unsqueeze(0)
    style_image = Variable(utils.preprocess_batch(style_image),
                           requires_grad=False)
    style_image = utils.subtract_imagenet_mean_batch(style_image)

    # generate the vector field that we want to backward from
    size = args.content_size
    vectors = np.zeros((size, size, 2), dtype=np.float32)
    vortex_spacing = 0.5
    extra_factor = 2.

    a = np.array([1, 0]) * vortex_spacing
    b = np.array([np.cos(np.pi / 3), np.sin(np.pi / 3)]) * vortex_spacing
    rnv = int(2 * extra_factor / vortex_spacing)
    vortices = [
        n * a + m * b for n in range(-rnv, rnv) for m in range(-rnv, rnv)
    ]
    vortices = [(x, y) for (x, y) in vortices
                if -extra_factor < x < extra_factor
                and -extra_factor < y < extra_factor]

    xs = np.linspace(-1, 1, size).astype(np.float32)[None, :]
    ys = np.linspace(-1, 1, size).astype(np.float32)[:, None]

    for (x, y) in vortices:
        rsq = (xs - x)**2 + (ys - y)**2
        vectors[..., 0] += (ys - y) / rsq
        vectors[..., 1] += -(xs - x) / rsq
    # for y in range(size):
    #     for x in range(size):
    #         xx = float(x - size / 2)
    #         yy = float(y - size / 2)
    #         rsq = xx ** 2 + yy ** 2
    #         if rsq == 0:
    #             vectors[y, x, 0] = 1
    #             vectors[y, x, 1] = 1
    #         else:
    #             vectors[y, x, 0] = -yy / rsq
    #             vectors[y, x, 1] = xx / rsq
    #         # vectors[y, x, 0] = 1
    #         # vectors[y, x, 1] = -1
    vectors = utils.tensor_load_vector_field(vectors)

    # load the pre-trained vgg-16 and extract features
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, 'vgg16.weight')))
    if args.cuda:
        style_image = style_image.cuda()
        vgg.cuda()
    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # load the sobel network
    sobel = Sobel()
    if args.cuda:
        vectors = vectors.cuda()
        sobel.cuda()

    # init optimizer
    vectors_size = vectors.data.size()
    output_size = np.asarray(vectors_size)
    output_size[1] = 3
    output_size = torch.Size(output_size)
    device = "cuda" if args.cuda else "cpu"
    output = Variable(torch.randn(output_size, device=device) * 30,
                      requires_grad=True)
    optimizer = Adam([output], lr=args.lr)
    cosine_loss = CosineLoss()
    mse_loss = torch.nn.MSELoss()

    #optimize the images
    tbar = trange(args.iters)
    for e in tbar:
        utils.imagenet_clamp_batch(output, 0, 255)
        optimizer.zero_grad()
        sobel_input = utils.gray_bgr_batch(output)
        sobel_y = sobel(sobel_input)
        content_loss = args.content_weight * cosine_loss(vectors, sobel_y)

        vgg_input = output
        features_y = vgg(vgg_input)
        style_loss = 0
        for m in range(len(features_y)):
            gram_y = utils.gram_matrix(features_y[m])
            gram_s = Variable(gram_style[m].data, requires_grad=False)
            style_loss += args.style_weight * mse_loss(gram_y, gram_s)

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()
        if ((e + 1) % args.log_interval == 0):
            print("iter: %d content_loss: %f style_loss %f" %
                  (e, content_loss.item() / args.content_weight,
                   style_loss.item() / args.style_weight))
        tbar.set_description(str(total_loss.data.cpu().numpy().item()))

    # save the image
    output = utils.add_imagenet_mean_batch_device(output, args.cuda)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
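
CosineLoss here is a custom module, not a torch built-in. A plausible sketch, matching the call cosine_loss(vectors, sobel_y) above: penalize the angular mismatch between the target field and the Sobel response at every pixel:

import torch
import torch.nn.functional as F

class CosineLoss(torch.nn.Module):
    # a sketch; the real loss may weight or fold the similarity differently
    def forward(self, target, pred):
        cos = F.cosine_similarity(pred, target, dim=1, eps=1e-8)
        return (1.0 - cos).mean()
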
Example 6
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    target_transform = transforms.ToTensor()

    train_dataset = VFDataset(args.dataset, transform, target_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    if args.load_model is not None:
        transformer.load_state_dict(torch.load(args.load_model))
    optimizer = Adam(transformer.parameters(), args.lr)
    # mse_loss = torch.nn.MSELoss()
    cosine_loss = torch.nn.CosineEmbeddingLoss()
    label = torch.ones(args.batch_size, 1, args.image_size,
                       args.image_size).to(device)

    # log_file = open(args.log_file, "w")

    for e in range(args.epochs):
        transformer.train()
        agg_loss = 0.
        count = 0
        for batch_id, (x, vf) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = utils.subtract_imagenet_mean_batch(x)
            x = x.to(device)
            y = transformer(x)
            vf = vf.to(device)

            # loss = mse_loss(y, vf)
            loss = cosine_loss(y, vf, label)
            loss.backward()
            optimizer.step()

            agg_loss += loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_loss / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example 7
def train(args):
    check_paths(args)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    style_model = Net(ngf=args.ngf)
    if args.resume is not None:
        print('Resuming, initializing using weight from {}.'.format(
            args.resume))
        style_model.load_state_dict(torch.load(args.resume))
    print(style_model)
    optimizer = Adam(style_model.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        style_model.cuda()
        vgg.cuda()

    style_loader = utils.StyleLoader(args.style_folder, args.style_size)

    tbar = trange(args.epochs)
    for e in tbar:
        style_model.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()

            style_v = style_loader.get(batch_id)
            style_model.setTarget(style_v)

            style_v = utils.subtract_imagenet_mean_batch(style_v)
            features_style = vgg(style_v)
            gram_style = [utils.gram_matrix(y) for y in features_style]

            y = style_model(x)
            xc = Variable(x.data.clone())

            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            content_loss = args.content_weight * mse_loss(
                features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_y = utils.gram_matrix(features_y[m])
                gram_s = Variable(gram_style[m].data,
                                  requires_grad=False).repeat(
                                      args.batch_size, 1, 1, 1)
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                tbar.set_description(mesg)

            if (batch_id + 1) % (4 * args.log_interval) == 0:
                # save model
                style_model.eval()
                style_model.cpu()
                save_model_filename = "Epoch_" + str(e) + "iters_" + str(count) + "_" + \
                    str(time.ctime()).replace(' ', '_') + "_" + str(
                    args.content_weight) + "_" + str(args.style_weight) + ".model"
                save_model_path = os.path.join(args.save_model_dir,
                                               save_model_filename)
                torch.save(style_model.state_dict(), save_model_path)
                style_model.train()
                style_model.cuda()
                tbar.set_description("Checkpoint, trained model saved at " +
                                     save_model_path)

    # save model
    style_model.eval()
    style_model.cpu()
    save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + \
        str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(style_model.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example 8
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([transforms.Scale(args.image_size),
                                    transforms.CenterCrop(args.image_size),
                                    transforms.ToTensor(),
                                    transforms.Lambda(lambda x: x.mul(255))])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)

    transformer = TransformerNet()
    if args.premodel != "":
        transformer.load_state_dict(torch.load(args.premodel))
        print("load pretrained model: " + args.premodel)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style = style.repeat(args.batch_size, 1, 1, 1)
    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]


    hori = 0
    writer = SummaryWriter(args.logdir, comment=args.logdir)
    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        agg_cate_loss = 0.
        agg_cam_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()
            y = transformer(x)  
            xc = Variable(x.data.clone(), volatile=True)
            # print(y.size())  # (4L, 3L, 224L, 224L)

            # Calculate focus loss and category loss
            y_cam = utils.depreprocess_batch(y)
            y_cam = utils.subtract_mean_std_batch(y_cam) 
            
            xc_cam = utils.depreprocess_batch(xc)
            xc_cam = utils.subtract_mean_std_batch(xc_cam)
            

            del features_blobs[:]
            logit_x = net(xc_cam)
            logit_y = net(y_cam)
            
            label = []
            cam_loss = 0
            for i in range(len(xc_cam)):
                h_x = F.softmax(logit_x[i])
                probs_x, idx_x = h_x.data.sort(0, True)
                label.append(idx_x[0])
                
                h_y = F.softmax(logit_y[i])
                probs_y, idx_y = h_y.data.sort(0, True)
                
                x_cam = returnCAM(features_blobs[0][i], weight_softmax, idx_x[0])
                x_cam = Variable(x_cam.data, requires_grad=False)

                y_cam = returnCAM(features_blobs[1][i], weight_softmax, idx_y[0])

                cam_loss += mse_loss(y_cam, x_cam)
            
            # the focus loss
            cam_loss *= 80
            # the category loss
            label = Variable(torch.LongTensor(label), requires_grad=False).cuda()
            cate_loss = 10000 * torch.nn.CrossEntropyLoss()(logit_y, label)

            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            #f_xc_c = Variable(features_xc[1].data, requires_grad=False)
            #content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)


            f_xc_c = Variable(features_xc[2].data, requires_grad=False)
            content_loss = args.content_weight * mse_loss(features_y[2], f_xc_c)
            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(gram_y, gram_s[:n_batch, :, :])
            #add the total four loss and backward
            total_loss = style_loss + content_loss  + cam_loss + cate_loss
            total_loss.backward()
            optimizer.step()

            # aggregate losses for display
            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]
            agg_cate_loss += cate_loss.data[0]
            agg_cam_loss += cam_loss.data[0]
            
            writer.add_scalar("Loss_Cont", agg_content_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_Style", agg_style_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_CAM", agg_cam_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_Cate", agg_cate_loss / (batch_id + 1), hori)
            hori += 1
            
            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{} Epoch {}: [{}/{}] content:{:.2f} style:{:.2f} cate:{:.2f} cam:{:.2f} total:{:.2f}".format(
                    time.strftime("%a %H:%M:%S"), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    agg_cate_loss / (batch_id + 1),
                    agg_cam_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss + agg_cate_loss + agg_cam_loss) / (batch_id + 1))
                print(mesg)

            if (batch_id + 1) % 2500 == 0:
                transformer.eval()
                transformer.cpu()
                save_model_filename = "epoch_" + str(e + 1) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
                    args.content_weight) + "_" + str(args.style_weight) + ".model"
                save_model_path = os.path.join(args.save_model_dir, save_model_filename)
                torch.save(transformer.state_dict(), save_model_path)
                transformer.cuda()
                transformer.train()
                print("saved at", count)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    
    writer.close()
    print("\nDone, trained model saved at", save_model_path)
Example 9
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    kwargs = {'num_workers': 0, 'pin_memory': False}

    if args.model_type == "rnn":
        transformer = transformer_net.TransformerRNN(args.pad_type)
        seq_size = 4
    else:
        transformer = transformer_net.TransformerNet(args.pad_type)
        seq_size = 2

    train_dataset = dataset.DAVISDataset(args.dataset,
                                         "train",
                                         seq_size=seq_size,
                                         interval=args.interval,
                                         no_flow=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=1,
                              shuffle=True,
                              **kwargs)

    model_path = args.init_model
    print("=> Load from model file %s" % model_path)
    transformer.load_state_dict(torch.load(model_path))
    transformer.train()
    if args.model_type == "rnn":
        transformer.conv1 = transformer_net.ConvLayer(6,
                                                      32,
                                                      kernel_size=9,
                                                      stride=1,
                                                      pad_type=args.pad_type)
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    l1_loss = torch.nn.L1Loss()

    vgg = Vgg16()
    vgg.load_state_dict(torch.load(os.path.join(args.vgg_model)))
    vgg.eval()

    transformer.cuda()
    vgg.cuda()
    mse_loss.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))
    print("=> Pixel FDB loss weight: %f" % args.time_strength1)
    print("=> Feature FDB loss weight: %f" % args.time_strength2)

    style = utils.preprocess_batch(style).cuda()
    utils.tensor_save_bgrimage(
        style[0].detach(), os.path.join(args.save_model_dir,
                                        'train_style.jpg'), True)
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    for e in range(args.epochs):
        agg_content_loss = agg_style_loss = agg_pixelfdb_loss = agg_featurefdb_loss = 0.
        iters = 0
        for batch_id, (x, flow, occ, _) in enumerate(train_loader):
            x = x[0]
            iters += 1

            optimizer.zero_grad()
            x = utils.preprocess_batch(x).cuda()
            y = transformer(x)  # (N, 3, 256, 256)

            if (batch_id + 1) % 100 == 0:
                idx = (batch_id + 1) // 100
                for i in range(args.batch_size):
                    utils.tensor_save_bgrimage(
                        y.data[i],
                        os.path.join(args.save_model_dir,
                                     "out_%02d_%02d.png" % (idx, i)), True)
                    utils.tensor_save_bgrimage(
                        x.data[i],
                        os.path.join(args.save_model_dir,
                                     "in_%02d-%02d.png" % (idx, i)), True)

            #xc = center_crop(x.detach(), y.shape[2], y.shape[3])

            y = utils.subtract_imagenet_mean_batch(y)
            x = utils.subtract_imagenet_mean_batch(x)

            features_y = vgg(y)
            features_xc = vgg(x)

            #content target
            f_xc_c = features_xc[2].detach()
            # content
            f_c = features_y[2]

            content_loss = args.content_weight * mse_loss(f_c, f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                batch_style_loss = 0
                for n in range(gram_y.shape[0]):
                    batch_style_loss += args.style_weight * mse_loss(
                        gram_y[n], gram_s[0])
                style_loss += batch_style_loss / gram_y.shape[0]

            # FDB
            pixel_fdb_loss = args.time_strength1 * mse_loss(
                y[1:] - y[:-1], x[1:] - x[:-1])
            # temporal content: 16th
            feature_fdb_loss = args.time_strength2 * l1_loss(
                features_y[2][1:] - features_y[2][:-1],
                features_xc[2][1:] - features_xc[2][:-1])

            total_loss = content_loss + style_loss + pixel_fdb_loss + feature_fdb_loss

            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data
            agg_style_loss += style_loss.data
            agg_pixelfdb_loss += pixel_fdb_loss.data
            agg_featurefdb_loss += feature_fdb_loss.data

            agg_total = agg_content_loss + agg_style_loss + agg_pixelfdb_loss + agg_featurefdb_loss
            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\tpixel fdb: {:.6f}\tfeature fdb: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, batch_id + 1, len(train_loader),
                agg_content_loss / iters, agg_style_loss / iters,
                agg_pixelfdb_loss / iters, agg_featurefdb_loss / iters,
                agg_total / iters)
            print(mesg)
            agg_content_loss = agg_style_loss = agg_pixelfdb_loss = agg_featurefdb_loss = 0.0
            iters = 0

        # save model
        save_model_filename = "epoch_" + str(e) + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
        save_model_path = os.path.join(args.save_model_dir,
                                       save_model_filename)
        torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example 10
def train(args):
    np.random.seed(args.seed)
    if args.cuda:
        ctx = mx.gpu(0)
    else:
        ctx = mx.cpu(0)
    # dataloader
    transform = utils.Compose([utils.Scale(args.image_size),
                               utils.CenterCrop(args.image_size),
                               utils.ToTensor(ctx),
                               ])
    train_dataset = data.ImageFolder(args.dataset, transform)
    train_loader = gluon.data.DataLoader(train_dataset, batch_size=args.batch_size, last_batch='discard')
    style_loader = utils.StyleLoader(args.style_folder, args.style_size, ctx=ctx)
    print('len(style_loader):',style_loader.size())
    # models
    vgg = net.Vgg16()
    utils.init_vgg_params(vgg, 'models', ctx=ctx)
    style_model = net.Net(ngf=args.ngf)
    style_model.initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx)
    if args.resume is not None:
        print('Resuming, initializing using weight from {}.'.format(args.resume))
        style_model.collect_params().load(args.resume, ctx=ctx)
    print('style_model:',style_model)
    # optimizer and loss
    trainer = gluon.Trainer(style_model.collect_params(), 'adam',
                            {'learning_rate': args.lr})
    mse_loss = gluon.loss.L2Loss()

    for e in range(args.epochs):
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            # prepare data
            style_image = style_loader.get(batch_id)
            style_v = utils.subtract_imagenet_mean_preprocess_batch(style_image.copy())
            style_image = utils.preprocess_batch(style_image)

            features_style = vgg(style_v)
            gram_style = [net.gram_matrix(y) for y in features_style]

            xc = utils.subtract_imagenet_mean_preprocess_batch(x.copy())
            f_xc_c = vgg(xc)[1]
            with autograd.record():
                style_model.setTarget(style_image)
                y = style_model(x)

                y = utils.subtract_imagenet_mean_batch(y)
                features_y = vgg(y)

                content_loss = 2 * args.content_weight * mse_loss(features_y[1], f_xc_c)

                style_loss = 0.
                for m in range(len(features_y)):
                    gram_y = net.gram_matrix(features_y[m])
                    _, C, _ = gram_style[m].shape
                    gram_s = F.expand_dims(gram_style[m], 0).broadcast_to((args.batch_size, 1, C, C))
                    style_loss = style_loss + 2 * args.style_weight * mse_loss(gram_y, gram_s[:n_batch, :, :])

                total_loss = content_loss + style_loss
                total_loss.backward()
                
            trainer.step(args.batch_size)
            mx.nd.waitall()

            agg_content_loss += content_loss[0]
            agg_style_loss += style_loss[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.3f}\tstyle: {:.3f}\ttotal: {:.3f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                agg_content_loss.asnumpy()[0] / (batch_id + 1),
                                agg_style_loss.asnumpy()[0] / (batch_id + 1),
                                (agg_content_loss + agg_style_loss).asnumpy()[0] / (batch_id + 1)
                )
                print(mesg)

            
            if (batch_id + 1) % (4 * args.log_interval) == 0:
                # save model
                save_model_filename = "Epoch_" + str(e) + "iters_" + str(count) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
                    args.content_weight) + "_" + str(args.style_weight) + ".params"
                save_model_path = os.path.join(args.save_model_dir, save_model_filename)
                style_model.collect_params().save(save_model_path)
                print("\nCheckpoint, trained model saved at", save_model_path)

    # save model
    save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".params"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    style_model.collect_params().save(save_model_path)
    print("\nDone, trained model saved at", save_model_path)
Example 11
def train_fdb(args):
    transformer = transformer_net.TransformerNet(args.pad_type)
    train_dataset = dataset.DAVISDataset(args.dataset,
                                         seq_size=2,
                                         use_flow=args.flow)
    train_loader = DataLoader(train_dataset, batch_size=1)

    transformer.train()
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))
    vgg.eval()

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        mse_loss.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))

    style = utils.preprocess_batch(style)
    if args.cuda: style = style.cuda()
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    train_loader.dataset.reset()
    agg_content_loss = agg_style_loss = agg_pixelfdb_loss = agg_featurefdb_loss = 0.
    iters = 0
    elapsed_time = 0
    for batch_id, (x, flow, conf) in enumerate(tqdm(train_loader)):
        x = x[0]
        iters += 1

        optimizer.zero_grad()
        x = utils.preprocess_batch(x)  # (N, 3, 256, 256)
        if args.cuda: x = x.cuda()
        y = transformer(x)  # (N, 3, 256, 256)

        xc = center_crop(x.detach(), y.shape[2], y.shape[3])

        y = utils.subtract_imagenet_mean_batch(y)
        xc = utils.subtract_imagenet_mean_batch(xc)

        features_y = vgg(y)
        features_xc = vgg(xc)

        # FDB
        begin_time = time.time()
        pixel_fdb_loss = mse_loss(y[1:] - y[:-1], xc[1:] - xc[:-1])
        # temporal content: 16th
        feature_fdb_loss = mse_loss(features_y[2][1:] - features_y[2][:-1],
                                    features_xc[2][1:] - features_xc[2][:-1])
        pixel_fdb_loss.backward()
        elapsed_time += time.time() - begin_time

        if batch_id > 1000: break
    print(elapsed_time / float(batch_id + 1))
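
center_crop, used by this benchmark and again by the trainer in Example 14, is assumed to crop a (N, C, H, W) batch around its spatial center:

def center_crop(x, th, tw):
    # crop x to target height th and target width tw, centered
    h, w = x.shape[2], x.shape[3]
    i, j = (h - th) // 2, (w - tw) // 2
    return x[:, :, i:i + th, j:j + tw]
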
Example 12
def train():
    if mx.context.num_gpus() > 0:
        ctx = mx.gpu()
    else:
        raise RuntimeError('There is no GPU device!')

    # loading configs
    args = Options().parse()
    cfg = Configs(args.config_path)
    # set logging level
    logging.basicConfig(level=logging.INFO)
    # set random seed
    np.random.seed(cfg.seed)

    # build dataset and loader
    content_dataset = ImageFolder(cfg.content_dataset, cfg.img_size, ctx=ctx)
    style_dataset = StyleLoader(cfg.style_dataset, cfg.style_size, ctx=ctx)
    content_loader = gluon.data.DataLoader(content_dataset, batch_size=cfg.batch_size, \
                                            last_batch='discard')

    vgg = Vgg16()
    vgg._init_weights(fixed=True, pretrain_path=cfg.vgg_check_point, ctx=ctx)

    style_model = Net(ngf=cfg.ngf)
    if cfg.resume is not None:
        print("Resuming from {} ...".format(cfg.resume))
        style_model.collect_params().load(cfg.resume, ctx=ctx)
    else:
        style_model.initialize(mx.initializer.MSRAPrelu(), ctx=ctx)
    print("Style model:")
    print(style_model)

    # build trainer
    lr_sche = mx.lr_scheduler.FactorScheduler(
        step=170000,
        factor=0.1,
        base_lr=cfg.base_lr
        #warmup_begin_lr=cfg.base_lr/3.0,
        #warmup_steps=300,
    )
    opt = mx.optimizer.Optimizer.create_optimizer('adam', lr_scheduler=lr_sche)
    trainer = gluon.Trainer(style_model.collect_params(), optimizer=opt)

    loss_fn = gluon.loss.L2Loss()

    logging.info("Start training with total {} epoch".format(cfg.total_epoch))
    iteration = 0
    total_time = 0.0
    num_batch = content_loader.__len__() * cfg.total_epoch
    for epoch in range(cfg.total_epoch):
        sum_content_loss = 0.0
        sum_style_loss = 0.0
        for batch_id, content_imgs in enumerate(content_loader):
            iteration += 1
            s = time.time()
            style_image = style_dataset.get(batch_id)

            style_vgg_input = subtract_imagenet_mean_preprocess_batch(
                style_image.copy())
            style_image = preprocess_batch(style_image)
            style_features = vgg(style_vgg_input)
            style_features = [
                style_model.gram.gram_matrix(mx.nd, f) for f in style_features
            ]

            content_vgg_input = subtract_imagenet_mean_preprocess_batch(
                content_imgs.copy())
            content_features = vgg(content_vgg_input)[1]

            with autograd.record():
                y = style_model(content_imgs, style_image)
                y = subtract_imagenet_mean_batch(y)
                y_features = vgg(y)

                content_loss = 2 * cfg.content_weight * loss_fn(
                    y_features[1], content_features)
                style_loss = 0.0
                for m in range(len(y_features)):
                    gram_y = style_model.gram.gram_matrix(mx.nd, y_features[m])
                    _, C, _ = style_features[m].shape
                    gram_s = mx.nd.expand_dims(style_features[m],
                                               0).broadcast_to((
                                                   gram_y.shape[0],
                                                   1,
                                                   C,
                                                   C,
                                               ))
                    style_loss = style_loss + 2 * cfg.style_weight * loss_fn(
                        gram_y, gram_s)
                total_loss = content_loss + style_loss
                total_loss.backward()

            trainer.step(cfg.batch_size)
            mx.nd.waitall()
            e = time.time()
            total_time += e - s
            sum_content_loss += content_loss[0]
            sum_style_loss += style_loss[0]
            if iteration % cfg.log_interval == 0:
                itera_sec = total_time / iteration
                eta_str = str(
                    datetime.timedelta(seconds=int((num_batch - iteration) *
                                                   itera_sec)))
                mesg = "{} Epoch [{}]:\t[{}/{}]\tTime:{:.2f}s\tETA:{}\tlr:{:.4f}\tcontent: {:.3f}\tstyle: {:.3f}\ttotal: {:.3f}".format(
                    time.strftime("%H:%M:%S",
                                  time.localtime()), epoch + 1, batch_id + 1,
                    content_loader.__len__(), itera_sec, eta_str,
                    trainer.optimizer.learning_rate,
                    sum_content_loss.asnumpy()[0] / (batch_id + 1),
                    sum_style_loss.asnumpy()[0] / (batch_id + 1),
                    (sum_content_loss + sum_style_loss).asnumpy()[0] /
                    (batch_id + 1))
                logging.info(mesg)
                ctx.empty_cache()
        save_model_filename = "Epoch_" + str(epoch + 1) +  "_" + str(time.ctime()).replace(' ', '_') + \
                "_" + str(cfg.content_weight) + "_" + str(cfg.style_weight) + ".params"
        if not os.path.isdir(cfg.save_model_dir):
            os.mkdir(cfg.save_model_dir)
        save_model_path = os.path.join(cfg.save_model_dir, save_model_filename)
        logging.info("Saving parameters to {}".format(save_model_path))
        style_model.collect_params().save(save_model_path)
Example 13
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    training_set = np.loadtxt(args.dataset, dtype=np.float32)
    training_set_size = training_set.shape[1]
    num_batch = int(training_set_size / args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = np.loadtxt(args.style_image, dtype=np.float32)
    style = style.reshape((1, 1, args.style_size_x, args.style_size_y))
    style = torch.from_numpy(style)
    style = style.repeat(args.batch_size, 3, 1, 1)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # Hard data
    if args.hard_data:
        hard_data = np.loadtxt(args.hard_data_file)
        # if not isinstance(hard_data[0], list):
        #     hard_data = [hard_data]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        # for batch_id, (x, _) in enumerate(train_loader):
        for batch_id in range(num_batch):
            x = training_set[:, batch_id * args.batch_size:(batch_id + 1) *
                             args.batch_size]
            n_batch = x.shape[1]
            count += n_batch
            x = x.transpose()
            x = x.reshape((n_batch, 1, args.image_size_x, args.image_size_y))

            # plt.imshow(x[0,:,:,:].squeeze(0))
            # plt.show()
            x = torch.from_numpy(x).float()

            optimizer.zero_grad()

            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            if args.hard_data:
                hard_data_loss = 0
                num_hard_data = 0
                for hd in hard_data:
                    hard_data_loss += args.hard_data_weight * (
                        y[:, 0, hd[1], hd[0]] -
                        hd[2] * 255.0).norm()**2 / n_batch
                    num_hard_data += 1
                hard_data_loss /= num_hard_data

            y = y.repeat(1, 3, 1, 1)
            # x = Variable(utils.preprocess_batch(x))

            # xc = x.data.clone()
            # xc = xc.repeat(1, 3, 1, 1)
            # xc = Variable(xc, volatile=True)

            y = utils.subtract_imagenet_mean_batch(y)
            # xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            # features_xc = vgg(xc)

            # f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            # content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            # total_loss = content_loss + style_loss

            total_loss = style_loss

            if args.hard_data:
                total_loss += hard_data_loss

            total_loss.backward()
            optimizer.step()

            # agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                if args.hard_data:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\thard_data: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, num_batch,
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        hard_data_loss.data[0],
                        (agg_content_loss + agg_style_loss) / (batch_id + 1))
                else:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, num_batch,
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example 14
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    kwargs = {'num_workers': 0, 'pin_memory': False}

    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = dataset.CustomImageDataset(args.dataset,
                                               transform=transform,
                                               img_size=args.image_size)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    transformer = TransformerNet(args.pad_type)
    transformer = transformer.train()
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    #print(transformer)
    vgg = Vgg16()
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))
    vgg.eval()

    transformer = transformer.cuda()
    vgg = vgg.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))

    #(1, H, W, C)
    style = utils.preprocess_batch(style).cuda()
    utils.tensor_save_bgrimage(
        style[0].detach(), os.path.join(args.save_model_dir,
                                        'train_style.jpg'), True)
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    for e in range(args.epochs):
        train_loader.dataset.reset()
        agg_content_loss = 0.
        agg_style_loss = 0.
        iters = 0
        for batch_id, (x, _) in enumerate(train_loader):
            if x.size(0) != args.batch_size:
                print("=> Skip incomplete batch")
                continue
            iters += 1

            optimizer.zero_grad()
            x = utils.preprocess_batch(x).cuda()
            y = transformer(x)

            if (batch_id + 1) % 1000 == 0:
                idx = (batch_id + 1) // 1000
                utils.tensor_save_bgrimage(
                    y.data[0],
                    os.path.join(args.save_model_dir, "out_%d.png" % idx),
                    True)
                utils.tensor_save_bgrimage(
                    x.data[0],
                    os.path.join(args.save_model_dir, "in_%d.png" % idx), True)

            y = utils.subtract_imagenet_mean_batch(y)
            x = utils.subtract_imagenet_mean_batch(x)

            features_y = vgg(y)
            features_x = vgg(center_crop(x, y.size(2), y.size(3)))

            #content target
            f_x = features_x[2].detach()
            # content
            f_y = features_y[2]

            content_loss = args.content_weight * mse_loss(f_y, f_x)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                batch_style_loss = 0
                for n in range(gram_y.shape[0]):
                    batch_style_loss += args.style_weight * mse_loss(
                        gram_y[n], gram_s[0])
                style_loss += batch_style_loss / gram_y.shape[0]

            total_loss = content_loss + style_loss

            total_loss.backward()
            optimizer.step()
            agg_content_loss += content_loss.data
            agg_style_loss += style_loss.data

            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, batch_id + 1, len(train_loader),
                agg_content_loss / iters, agg_style_loss / iters,
                (agg_content_loss + agg_style_loss) / iters)
            print(mesg)
            agg_content_loss = agg_style_loss = 0.0
            iters = 0

        # save model
        save_model_filename = "epoch_" + str(e) + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
        save_model_path = os.path.join(args.save_model_dir,
                                       save_model_filename)
        torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example 15
def optimize(args):
    content_image = utils.tensor_load_grayimage(args.content_image, size=args.content_size)
    content_image = content_image.unsqueeze(0)
    content_image = Variable(content_image, requires_grad=False)
    content_image = utils.subtract_imagenet_mean_batch_gray(content_image)
    style_image = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style_image = style_image.unsqueeze(0)
    style_image = Variable(utils.preprocess_batch(style_image), requires_grad=False)
    style_image = utils.subtract_imagenet_mean_batch(style_image)

    # generate the vector field that we want to stylize
    # size = args.content_size
    # vectors = np.zeros((size, size, 2), dtype=np.float32)

    # vortex_spacing = 0.5
    # extra_factor = 2.
    #
    # a = np.array([1, 0]) * vortex_spacing
    # b = np.array([np.cos(np.pi / 3), np.sin(np.pi / 3)]) * vortex_spacing
    # rnv = int(2 * extra_factor / vortex_spacing)
    # vortices = [n * a + m * b for n in range(-rnv, rnv) for m in range(-rnv, rnv)]
    # vortices = [(x, y) for (x, y) in vortices if -extra_factor < x < extra_factor and -extra_factor < y < extra_factor]
    #
    # xs = np.linspace(-1, 1, size).astype(np.float32)[None, :]
    # ys = np.linspace(-1, 1, size).astype(np.float32)[:, None]
    #
    # for (x, y) in vortices:
    #     rsq = (xs - x) ** 2 + (ys - y) ** 2
    #     vectors[..., 0] += (ys - y) / rsq
    #     vectors[..., 1] += -(xs - x) / rsq
    #
    # for y in range(size):
    #     for x in range(size):
    #         angles[y, x] = math.atan(vectors[y, x, 1] / vectors[y, x, 0]) * 180 / math.pi

    # for y in range(size):
    #     for x in range(size):
    #         xx = float(x - size / 2)
    #         yy = float(y - size / 2)
    #         rsq = xx ** 2 + yy ** 2
    #         if (rsq == 0):
    #             vectors[y, x, 0] = 0
    #             vectors[y, x, 1] = 0
    #         else:
    #             vectors[y, x, 0] = -yy / rsq
    #             vectors[y, x, 1] = xx / rsq
    # f = h5py.File("../datasets/fake/vector_fields/cat_test3.h5", 'r')
    # a_group_key = list(f.keys())[0]
    # vectors = f[a_group_key][:]
    # vectors = utils.tensor_load_vector_field(vectors)
    # vectors = Variable(vectors, requires_grad=False)

    # load the pre-trained vgg-16 and extract features
    vgg = Vgg16()
    # utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(torch.load(os.path.join(args.vgg_model_dir, 'vgg16.weight')))
    if args.cuda:
        style_image = style_image.cuda()
        vgg.cuda()
    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # load the transformer net and extract features
    transformer_phi1 = TransformerNet()
    transformer_phi1.load_state_dict(torch.load(args.transformer_model_phi1_path))
    if args.cuda:
        # vectors = vectors.cuda()
        content_image = content_image.cuda()
        transformer_phi1.cuda()
    vectors = transformer_phi1(content_image)
    vectors = Variable(vectors.data, requires_grad=False)

    # init optimizer
    content_image_size = content_image.data.size()
    output_size = np.asarray(content_image_size)
    output_size[1] = 3
    output_size = torch.Size(output_size)
    output = Variable(torch.randn(output_size,
                                  device="cuda" if args.cuda else "cpu"),
                      requires_grad=True)
    optimizer = Adam([output], lr=args.lr)
    mse_loss = torch.nn.MSELoss()
    cosine_loss = torch.nn.CosineEmbeddingLoss()
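    # Note: CosineEmbeddingLoss(x1, x2, target) with target = +1 pushes the
    # cosine similarity between the two inputs toward 1, i.e. it aligns the
    # stylized vector field with the target field. The all-ones label below is
    # shaped to broadcast against the tensors being compared here (hard-coded
    # 1 x 128 x 128 x 128 for this content size).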
    # label = torch.ones(1, 1, args.content_size, args.content_size)
    label = torch.ones(1, 128, 128, 128)
    if args.cuda:
        label = label.cuda()

    # optimize the images
    transformer_phi2 = TransformerNet()
    transformer_phi2.load_state_dict(torch.load(args.transformer_model_phi2_path))
    if args.cuda:
        transformer_phi2.cuda()
    tbar = trange(args.iters)
    for e in tbar:
        utils.imagenet_clamp_batch(output, 0, 255)
        optimizer.zero_grad()
        transformer_input = utils.gray_bgr_batch(output)
        transformer_y = transformer_phi2(transformer_input)
        content_loss = args.content_weight * cosine_loss(vectors, transformer_y, label)
        # content_loss = args.content_weight * mse_loss(vectors, transformer_y)

        vgg_input = output
        features_y = vgg(vgg_input)
        style_loss = 0
        for m in range(len(features_y)):
            gram_y = utils.gram_matrix(features_y[m])
            gram_s = Variable(gram_style[m].data, requires_grad=False)
            style_loss += args.style_weight * mse_loss(gram_y, gram_s)

        total_loss = content_loss + style_loss
        # total_loss = content_loss
        total_loss.backward()
        optimizer.step()
        tbar.set_description(str(total_loss.data.cpu().numpy().item()))
        if (e + 1) % args.log_interval == 0:
            print("iter: %d content_loss: %f style_loss: %f" % (e + 1, content_loss.item(), style_loss.item()))

    # save the image
    output = utils.add_imagenet_mean_batch_device(output, args.cuda)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
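# Every example on this page builds its style targets with utils.gram_matrix.
# A minimal sketch of the standard Gram computation (hedged: the repo's own
# helper may normalize differently; the function name here is hypothetical):
def gram_matrix_sketch(y):
    # y: (B, C, H, W) feature maps from VGG
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, h * w)
    features_t = features.transpose(1, 2)
    # (B, C, C) Gram matrices, normalized by the number of feature elements
    return features.bmm(features_t) / (ch * h * w)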
Example 16
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style = style.repeat(args.batch_size, 1, 1, 1)
    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            xc = Variable(x.data.clone(), volatile=True)

            utils.subtract_imagenet_mean_batch(y)
            utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            content_loss = args.content_weight * mse_loss(
                features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example 17
def train():
    check_point_path = ''

    transform = transforms.Compose([transforms.Scale(IMAGE_SIZE),
                                    transforms.CenterCrop(IMAGE_SIZE),
                                    transforms.ToTensor(),
                                    transforms.Lambda(lambda x: x.mul(255))])

    train_dataset = datasets.ImageFolder(DATASET_FOLDER, transform)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

    style_model = Net(ngf=FILTER_CHANNEL, dv=device).to(device)
    if RESUME is not None:
        print('Resuming, initializing using weight from {}.'.format(RESUME))
        style_model.load_state_dict(torch.load(RESUME))
    print(style_model)
    optimizer = Adam(style_model.parameters(), LEARNING_RATE)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(VGG_DIR)
    vgg.load_state_dict(torch.load(os.path.join(VGG_DIR, "vgg16.weight")))
    vgg.to(device)

    style_loader = utils.StyleLoader(STYLE_FOLDER, IMAGE_SIZE, device)
    
    tbar = tqdm(range(EPOCHS))
    for e in tbar:
        style_model.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x)).to(device)

            style_v = style_loader.get(batch_id)
            style_model.setTarget(style_v)

            style_v = utils.subtract_imagenet_mean_batch(style_v, device)
            features_style = vgg(style_v)
            gram_style = [utils.gram_matrix(y) for y in features_style]

            y = style_model(x)
            xc = Variable(x.data.clone())

            y = utils.subtract_imagenet_mean_batch(y, device)
            xc = utils.subtract_imagenet_mean_batch(xc, device)

            features_y = vgg(y)
            features_xc = vgg(xc)

            f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            content_loss = CONT_WEIGHT * mse_loss(features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_y = utils.gram_matrix(features_y[m])
                gram_s = Variable(gram_style[m].data, requires_grad=False).repeat(BATCH_SIZE, 1, 1, 1)
                style_loss += STYLE_WEIGHT * mse_loss(gram_y.unsqueeze_(1), gram_s[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % 100 == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                tbar.set_description(mesg)

            
            if (batch_id + 1) % (4 * 100) == 0:
                # save model
                style_model.eval()
                style_model.cpu()
                save_model_filename = "Epoch_" + str(e) + "iters_" + str(count) + "_" +                     str(time.ctime()).replace(' ', '_') + "_" + str(
                    CONT_WEIGHT) + "_" + str(STYLE_WEIGHT) + ".model"
                save_model_path = os.path.join(SAVE_MODEL_DIR, save_model_filename)
                torch.save(style_model.state_dict(), save_model_path)
                if check_point_path:
                    os.remove(check_point_path)
                check_point_path = save_model_path
                style_model.train()
                style_model.to(device)
                tbar.set_description("Checkpoint, trained model saved at " + save_model_path)

    # save model
    style_model.eval()
    style_model.cpu()
    save_model_filename = "Final_epoch_" + str(EPOCHS) + "_" +         str(time.ctime()).replace(' ', '_') + "_" + str(
        CONT_WEIGHT) + "_" + str(STYLE_WEIGHT) + ".model"
    save_model_path = os.path.join(SAVE_MODEL_DIR, save_model_filename)
    torch.save(style_model.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example 18
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    if args.model_type == "rnn":
        transformer = transformer_net.TransformerRNN(args.pad_type)
        seq_size = 4
    else:
        transformer = transformer_net.TransformerNet(args.pad_type)
        seq_size = 2

    train_dataset = dataset.DAVISDataset(args.dataset,
                                         seq_size=seq_size,
                                         use_flow=args.flow)
    train_loader = DataLoader(train_dataset, batch_size=1, **kwargs)

    if args.model_type == "rnn":
        transformer = transformer_net.TransformerRNN(args.pad_type)
    else:
        transformer = transformer_net.TransformerNet(args.pad_type)
    model_path = args.init_model
    print("=> Load from model file %s" % model_path)
    transformer.load_state_dict(torch.load(model_path))
    transformer.train()
    if args.model_type == "rnn":
        transformer.conv1 = transformer_net.ConvLayer(6,
                                                      32,
                                                      kernel_size=9,
                                                      stride=1,
                                                      pad_type=args.pad_type)
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    l1_loss = torch.nn.SmoothL1Loss()

    vgg = Vgg16()
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model, "vgg16.weight")))
    vgg.eval()

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        mse_loss.cuda()
        l1_loss.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))
    print("=> Pixel OFB loss weight: %f" % args.time_strength)

    style = utils.preprocess_batch(style)
    if args.cuda: style = style.cuda()
    utils.tensor_save_bgrimage(
        style[0].detach(), os.path.join(args.save_model_dir,
                                        'train_style.jpg'), args.cuda)
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    for e in range(args.epochs):
        train_loader.dataset.reset()
        transformer.train()
        transformer.cuda()
        agg_content_loss = agg_style_loss = agg_pixelofb_loss = 0.
        iters = 0
        for batch_id, (x, flow, conf) in enumerate(train_loader):
            x, flow, conf = x[0], flow[0], conf[0]
            iters += 1

            optimizer.zero_grad()
            x = utils.preprocess_batch(x)  # (N, 3, 256, 256)
            if args.cuda:
                x = x.cuda()
                flow = flow.cuda()
                conf = conf.cuda()
            y = transformer(x)  # (N, 3, 256, 256)

            xc = center_crop(x.detach(), y.size(2), y.size(3))

            vgg_y = utils.subtract_imagenet_mean_batch(y)
            vgg_x = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(vgg_y)
            features_xc = vgg(vgg_x)

            #content target
            f_xc_c = features_xc[2].detach()
            # content
            f_c = features_y[2]

            #content_feature_target = center_crop(f_xc_c, f_c.size(2), f_c.size(3))
            content_loss = args.content_weight * mse_loss(f_c, f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                batch_style_loss = 0
                for n in range(gram_y.shape[0]):
                    batch_style_loss += args.style_weight * mse_loss(
                        gram_y[n], gram_s[0])
                style_loss += batch_style_loss / gram_y.shape[0]

            warped_y, warped_y_mask = warp(y[1:], flow)
            warped_y = warped_y.detach()
            warped_y_mask *= conf
            pixel_ofb_loss = args.time_strength * weighted_mse(
                y[:-1], warped_y, warped_y_mask)

            total_loss = content_loss + style_loss + pixel_ofb_loss

            total_loss.backward()
            optimizer.step()

            if (batch_id + 1) % 100 == 0:
                prefix = args.save_model_dir + "/"
                idx = (batch_id + 1) // 100
                flow_image = flow_to_color(
                    flow[0].detach().cpu().numpy().transpose(1, 2, 0))
                utils.save_image(prefix + "forward_flow_%d.png" % idx,
                                 flow_image)
                warped_x, warped_x_mask = warp(x[1:], flow)
                warped_x = warped_x.detach()
                warped_x_mask *= conf
                for i in range(2):
                    utils.tensor_save_bgrimage(
                        y.data[i], prefix + "out_%d-%d.png" % (idx, i),
                        args.cuda)
                    utils.tensor_save_bgrimage(
                        x.data[i], prefix + "in_%d-%d.png" % (idx, i),
                        args.cuda)
                    if i < warped_y.shape[0]:
                        utils.tensor_save_bgrimage(
                            warped_y.data[i],
                            prefix + "wout_%d-%d.png" % (idx, i), args.cuda)
                        utils.tensor_save_bgrimage(
                            warped_x.data[i],
                            prefix + "win_%d-%d.png" % (idx, i), args.cuda)
                        utils.tensor_save_image(
                            prefix + "conf_%d-%d.png" % (idx, i),
                            warped_x_mask.data[i])

            agg_content_loss += content_loss.data
            agg_style_loss += style_loss.data
            agg_pixelofb_loss += pixel_ofb_loss.data

            agg_total = agg_content_loss + agg_style_loss + agg_pixelofb_loss
            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\tpixel ofb: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, batch_id + 1, len(train_loader),
                agg_content_loss / iters, agg_style_loss / iters,
                agg_pixelofb_loss / iters, agg_total / iters)
            print(mesg)
            agg_content_loss = agg_style_loss = agg_pixelofb_loss = 0.0
            iters = 0

        # save model
        transformer.eval()
        transformer.cpu()
        save_model_filename = "epoch_" + str(e) + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
        save_model_path = os.path.join(args.save_model_dir,
                                       save_model_filename)
        torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example 19
def train(args):
    serialNumFile = "serialNum.txt"
    serial = 0
    if os.path.isfile(serialNumFile):
        with open(serialNumFile, "r") as t:
            serial = int(t.read())

    serial += 1
    with open(serialNumFile, "w") as t:
        t.write(str(serial))

    if args.mysql:
        cnx = mysql.connector.connect(user='******',
                                      database='midburn',
                                      password='******')
        cursor = cnx.cursor()
    location = args.dataset.split("/")
    if location[-1] == "":
        location = location[-2]
    else:
        location = location[-1]
    save_model_filename = str(serial) + "_" + extractName(
        args.style_image) + "_" + str(args.epochs) + "_" + str(
            int(args.content_weight)) + "_" + str(int(
                args.style_weight)) + "_size_" + str(
                    args.image_size) + "_dataset_" + str(location) + ".model"
    print(save_model_filename)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    m_epoch = 0
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        #kwargs = {'num_workers': 0, 'pin_memory': False}
        kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        kwargs = {}

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              **kwargs)

    transformer = TransformerNet()
    #transformer = ResNeXtNet()
    transformer_type = transformer.__class__.__name__
    optimizer = Adam(transformer.parameters(), args.lr)
    if args.l1:
        loss_criterion = torch.nn.L1Loss()
    else:
        loss_criterion = torch.nn.MSELoss()
    loss_type = loss_criterion.__class__.__name__

    if args.visdom:
        vis = VisdomLinePlotter("Style Transfer: " + transformer_type)
    else:
        vis = None

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    if args.model is not None:
        transformer.load_state_dict(torch.load(args.model))
        save_model_filename = save_model_filename + "@@@@@@" + str(
            int(getEpoch(args.model)) + int(args.epochs))
        m_epoch += int(getEpoch(args.model))
        print("loaded model\n")

    for param in vgg.parameters():
        param.requires_grad = False

    with torch.no_grad():
        style = utils.tensor_load_rgbimage(args.style_image,
                                           size=args.style_size)
        style = style.repeat(args.batch_size, 1, 1, 1)
        style = utils.preprocess_batch(style)
        if args.cuda:
            style = style.cuda()

        style = utils.subtract_imagenet_mean_batch(style)
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]
        del features_style
        del style

    # TODO: scheduler and style-loss criterion unused at the moment
    scheduler = StepLR(optimizer, step_size=15000 // args.batch_size)
    style_loss_criterion = torch.nn.CosineSimilarity()
    total_count = 0

    if args.mysql:
        q1 = ("REPLACE INTO `images`(`name`) VALUES ('" + args.style_image +
              "')")
        cursor.execute(q1)
        cnx.commit()
        imgId = cursor.lastrowid

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0

        for batch_id, (x, _) in enumerate(train_loader):

            n_batch = len(x)
            count += n_batch
            total_count += n_batch
            optimizer.zero_grad()
            x = utils.preprocess_batch(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(x)

            features_y = vgg(y)
            f_xc_c = vgg.content_features(xc)

            content_loss = args.content_weight * loss_criterion(
                features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += loss_criterion(gram_y, gram_s[:n_batch, :, :])
                #style_loss -= style_loss_criterion(gram_y, gram_s[:n_batch, :, :])

            style_loss *= args.style_weight
            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()
            # TODO: enable
            #scheduler.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                if args.mysql:
                    q1 = ("REPLACE INTO `statistics`"
                          "(`imgId`, `epoch`, `iteration_id`, `content_loss`, `style_loss`, `loss`) "
                          "VALUES (%s, %s, %s, %s, %s, %s)")
                    cursor.execute(q1, (imgId, int(e) + m_epoch, batch_id,
                                        agg_content_loss / (batch_id + 1),
                                        agg_style_loss / (batch_id + 1),
                                        (agg_content_loss + agg_style_loss) /
                                        (batch_id + 1)))
                    cnx.commit()
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}\n".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)
                sys.stdout.flush()
            if vis is not None:
                vis.plot(loss_type, "Content Loss", total_count,
                         content_loss.item())
                vis.plot(loss_type, "Style Loss", total_count,
                         style_loss.item())
                vis.plot(loss_type, "Total Loss", total_count,
                         total_loss.item())

    # save model
    transformer.eval()
    transformer.cpu()

    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)