Example #1
def main():
    test_dir = opt.test_dir
    feature_param_file = opt.feat
    class_param_file = opt.cls
    bsize = opt.b

    # models
    if 'vgg' == opt.i:
        feature = Vgg16()
    elif 'resnet' == opt.i:
        feature = resnet50()
    elif 'densenet' == opt.i:
        feature = densenet121()
    feature.cuda()
    # feature.load_state_dict(torch.load(feature_param_file))
    feature.eval()

    classifier = Classifier(opt.i)
    classifier.cuda()
    # classifier.load_state_dict(torch.load(class_param_file))
    classifier.eval()

    loader = torch.utils.data.DataLoader(MyClsTestData(test_dir,
                                                       transform=True),
                                         batch_size=bsize,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True)
    acc = eval_acc(feature, classifier, loader)
    print(acc)
Example #2
 def __init__(self, vgg_path):
     super().__init__()
     self.vgg = Vgg16(vgg_path)
     self.loss_mse = nn.MSELoss()
     self.style_weight = 1e5
     self.content_weight = 1e0
     self.tv_weight = 1e-7
Example #3
 def __init__(self, loss, gpu=0, p_layer=14):
     super(PerceptualLoss, self).__init__()
     self.criterion = loss
     self.device = torch.device('cuda') if gpu else torch.device('cpu')
     cnn = py_models.vgg19(pretrained=True).features
     cnn = cnn.cuda()
     model = nn.Sequential()
     model = model.cuda()
     for i, layer in enumerate(list(cnn)):
         model.add_module(str(i), layer)
         if i == p_layer:
             break
     self.contentFunc = model
     self.styleFunc = Vgg16(requires_grad=False).to(self.device)
Example #4
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # load vgg network
            dtype = torch.cuda.FloatTensor
            self.vgg = Vgg16().type(dtype)

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')
Example #5
 def __init__(self, args):
     ## Args
     self._train_dataset_path = args.train_dataset_path
     self._epochs = args.epochs
     self._batch_iterations = args.batch_iterations
     self._batch_size = args.batch_size
     self._learning_rate = args.learning_rate
     self._content_weight = args.content_weight
     self._style_weight = args.style_weight
     self._style_image_path = args.style_image_path
     self._log_interval = args.log_interval
     self._save_interval = args.save_interval
     self._trained_models_dir = args.trained_models_dir
     
     ## GPU if available, CPU otherwise
     self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     self._train_loader = self._get_train_loader()
     self._transfer_net = TransferNet()
     self._vgg = Vgg16().to(self._device)
     self._gram_style = self._compute_gram_from_style()
Example #6
def train(restore_path=None):
    epochs = 1
    lr = 1e-3
    content_weight = 1e5
    style_weight = 1e10
    batch_size = 4

    # not downloaded to Git because it is too big, download from COCO website
    data_path = 'COCO/'

    global mean
    global std
    mean = torch.Tensor(mean).reshape(1, -1, 1, 1).to(device)
    std = torch.Tensor(std).reshape(1, -1, 1, 1).to(device)

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop((256, 256)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = torchvision.datasets.ImageFolder(root=data_path,
                                                     transform=transform)

    # load data after transforming to correct size and pixel values
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=16,
                                               shuffle=True,
                                               pin_memory=True)

    gen = Generator().to(device)
    if restore_path is not None:
        gen.load_state_dict(torch.load(restore_path))
    optimizer = torch.optim.Adam(gen.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)

    # compute style features and gram matrix
    style_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: x.mul(255))
    ])

    # open style image and turn into batch-sized tensor
    style = Image.open("style_images/rickmorty-style.jpg")
    style = style_transform(style)
    style = style.to(device)
    style = normalize(style)
    style = style.repeat(batch_size, 1, 1, 1)
    features_style = vgg(style)

    # compute gram of vgg output of style image tensor
    gr_norm = [gram(ft) for ft in features_style]

    # start train loop
    for e in tqdm(range(epochs)):
        loss = 0
        agg_content_loss = 0
        agg_style_loss = 0
        avg_time = 0
        for idx, (example_data, _) in enumerate(train_loader):
            start = time.time()
            optimizer.zero_grad()
            example_data = example_data.to(device)

            # pass the output through the generator and normalize
            output = gen(example_data)
            output = normalize(output)

            # normalize the original data
            example_data = normalize(example_data)

            # pass the output and the original data through vgg
            features_output = vgg(output)
            features_content = vgg(example_data)

            # calculate content and style loss as described in the paper
            content_loss = content_weight * criterion(features_output.relu2_2,
                                                      features_content.relu2_2)
            style_loss = 0
            for ft, gr_style in zip(features_output, gr_norm):
                gr = gram(ft)
                style_loss += criterion(gr, gr_style)
            style_loss = style_weight * style_loss

            # propagate the loss
            loss = content_loss + style_loss
            loss.backward()
            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()
            if idx % 500 == 0:
                mesg = "Epoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    e + 1, (idx + 1) * 4, len(train_dataset),
                    agg_content_loss / (idx + 1), agg_style_loss / (idx + 1),
                    (agg_content_loss + agg_style_loss) / (idx + 1))
                tqdm.write(mesg)
            if idx % 2000 == 1999:
                torch.save(gen.state_dict(), f'models/gen_ep{e}_b{idx}.pt')
            optimizer.step()

    # save the final model
    torch.save(gen.state_dict(), f'gen.pt')
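
Note: the gram(...) helper called above is defined elsewhere in that project and is not shown here. A minimal sketch of the usual batched Gram-matrix computation, assuming feature maps of shape (B, C, H, W), looks like:

import torch

def gram(features):
    # features: (B, C, H, W) activations from one VGG layer
    b, c, h, w = features.size()
    flat = features.view(b, c, h * w)  # flatten the spatial dimensions
    # batched outer product of channel activations, normalized by element count
    return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)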
Example #7
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']
        # use_gan
        self.use_gan = opt.use_GAN
        self.w_vgg = opt.w_vgg
        self.w_tv = opt.w_tv
        self.w_gan = opt.w_gan
        self.use_condition = opt.use_condition
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type,
                                      self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if self.use_condition == 1:
                self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                              opt.ndf, opt.which_model_netD,
                                              opt.n_layers_D, opt.norm,
                                              use_sigmoid, opt.init_type,
                                              self.gpu_ids)
            else:
                self.netD = networks.define_D(opt.input_nc, opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D, opt.norm,
                                              use_sigmoid, opt.init_type,
                                              self.gpu_ids)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            if opt.which_model_netD == 'multi':
                self.criterionGAN = networks.GANLoss_multi(
                    use_lsgan=not opt.no_lsgan).to(self.device)
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # load vgg network
            self.vgg = Vgg16().type(torch.cuda.FloatTensor)

            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
Example #8
	def __init__(self, classi = False):
		super(VGGFeature, self).__init__()
		self.add_module('vgg', Vgg16(classi = classi))
Example #9
def train_GAN(use_cuda=False, numb_style_images=100):
    path = "/data/" if use_cuda else "/home/dobosevych/Documents/Cats/"
    train_loader = load_data(path, upper_bound=21000)
    test_loader = load_data(path, lower_bound=21000, upper_bound=22000)

    lr = 0.0002
    betas = (0.5, 0.999)
    discriminator = Discriminator()
    generator = Generator()
    vgg = Vgg16(requires_grad=False)
    if use_cuda:
        vgg.cuda()
    styles = get_gram_matrices(next(iter(train_loader)))

    if use_cuda:
        discriminator = discriminator.cuda()
        generator = generator.cuda()

    d_optimizer = Adam(discriminator.parameters(), lr=lr, betas=betas)
    g_optimizer = Adam(generator.parameters(), lr=lr, betas=betas)
    criterion_BCE = nn.BCELoss()
    criterion_MSE = nn.MSELoss()

    num_epochs = 20
    num_of_samples = 100

    for epoch in range(num_epochs):
        for i, (color_images, b_and_w_images) in enumerate(train_loader):
            minibatch = color_images.size(0)

            # damaged = make_damaged(images)
            # damaged = Variable(damaged)
            color_images = Variable(color_images)
            b_and_w_images = Variable(b_and_w_images)
            labels_1 = Variable(torch.ones(minibatch))
            labels_0 = Variable(torch.zeros(minibatch))

            if use_cuda:
                color_images, b_and_w_images, labels_0, labels_1 = color_images.cuda(
                ), b_and_w_images.cuda(), labels_0.cuda(), labels_1.cuda(
                )  #, damaged.cuda()

            # Generator training
            generator.zero_grad()
            generated_images = generator(b_and_w_images)
            out = discriminator(generated_images)

            styleloss = 0

            for style_img in styles:
                styleloss += style_loss(style_img, generated_images, vgg,
                                        minibatch)

            # loss_img = criterion_MSE(generated_images, color_images)
            loss_1 = criterion_BCE(out, labels_1)
            # g_loss = 100 * loss_img + loss_1
            g_loss = 100 * styleloss + loss_1
            g_loss.backward()
            g_optimizer.step()

            # Discriminator training
            generated_images = generator(b_and_w_images)
            discriminator.zero_grad()
            out_0 = discriminator(generated_images)
            loss_0 = criterion_BCE(out_0, labels_0)

            out_1 = discriminator(color_images)
            loss_1 = criterion_BCE(out_1, labels_1)

            d_loss = loss_0 + loss_1
            d_loss.backward()
            d_optimizer.step()

            print("Epoch: [{}/{}], Step: [{}/{}]".format(
                epoch + 1, num_epochs, i + 1, len(train_loader)))

        test_images_color, test_images_bw = next(iter(test_loader))
        test_images_bw = Variable(test_images_bw)

        if use_cuda:
            test_images_bw = test_images_bw.cuda()

        test_images_colored = generator(test_images_bw)
        test_images_colored = test_images_colored.view(num_of_samples, 3, 128,
                                                       128).data.cpu().numpy()
        filename_colored = "/output/epoch_{}/colored/sample" if use_cuda else "samples/epoch_{}/colored/sample"
        filename_bw = "/output/epoch_{}/black_and_white/sample" if use_cuda else "samples/epoch_{}/black_and_white/sample"
        filename_color = "/output/epoch_{}/incolor/sample" if use_cuda else "samples/epoch_{}/incolor/sample"

        save_images(test_images_colored,
                    filename=filename_colored.format(epoch + 1),
                    width=10,
                    size=(3, 128, 128))
        save_images(test_images_bw,
                    filename=filename_bw.format(epoch + 1),
                    width=10,
                    size=(3, 128, 128))
        save_images(test_images_color,
                    filename=filename_color.format(epoch + 1),
                    width=10,
                    size=(3, 128, 128))
Example #10
def train(args):
    # GPU enabling (fall back to CPU tensors when no GPU index is given)
    use_cuda = False
    dtype = torch.FloatTensor
    if args.gpu is not None:
        use_cuda = True
        dtype = torch.cuda.FloatTensor
        torch.cuda.set_device(args.gpu)
        print("Current device: %d" % torch.cuda.current_device())

    # visualization of training controlled by flag
    visualize = (args.visualize != None)
    if (visualize):
        img_transform_512 = transforms.Compose([
            # scale shortest side to image_size
            transforms.Scale(512),
            # crop center image_size out
            transforms.CenterCrop(512),
            # turn image from [0-255] to [0-1]
            transforms.ToTensor(),
            utils.normalize_tensor_transform(
            )  # normalize with ImageNet values
        ])

        testImage_amber = utils.load_image("content_imgs/amber.jpg")
        testImage_amber = img_transform_512(testImage_amber)
        testImage_amber = Variable(testImage_amber.repeat(1, 1, 1, 1),
                                   requires_grad=False).type(dtype)

        testImage_dan = utils.load_image("content_imgs/dan.jpg")
        testImage_dan = img_transform_512(testImage_dan)
        testImage_dan = Variable(testImage_dan.repeat(1, 1, 1, 1),
                                 requires_grad=False).type(dtype)

        testImage_maine = utils.load_image("content_imgs/maine.jpg")
        testImage_maine = img_transform_512(testImage_maine)
        testImage_maine = Variable(testImage_maine.repeat(1, 1, 1, 1),
                                   requires_grad=False).type(dtype)

    # define network
    image_transformer = ImageTransformNet().type(dtype)
    optimizer = Adam(image_transformer.parameters(), LEARNING_RATE)

    loss_mse = torch.nn.MSELoss()

    # load vgg network
    vgg = Vgg16().type(dtype)

    # get training dataset
    dataset_transform = transforms.Compose([
        # scale shortest side to image_size
        transforms.Scale(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),  # crop center image_size out
        # turn image from [0-255] to [0-1]
        transforms.ToTensor(),
        utils.normalize_tensor_transform()  # normalize with ImageNet values
    ])
    train_dataset = datasets.ImageFolder(args.dataset, dataset_transform)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

    # style image
    style_transform = transforms.Compose([
        # turn image from [0-255] to [0-1]
        transforms.ToTensor(),
        utils.normalize_tensor_transform()  # normalize with ImageNet values
    ])
    style = utils.load_image(args.style_image)
    style = style_transform(style)
    style = Variable(style.repeat(BATCH_SIZE, 1, 1, 1)).type(dtype)
    style_name = os.path.split(args.style_image)[-1].split('.')[0]

    # calculate gram matrices for style feature layer maps we care about
    style_features = vgg(style)
    style_gram = [utils.gram(fmap) for fmap in style_features]

    for e in range(EPOCHS):

        # track values for...
        img_count = 0
        aggregate_style_loss = 0.0
        aggregate_content_loss = 0.0
        aggregate_tv_loss = 0.0

        # train network
        image_transformer.train()
        for batch_num, (x, label) in enumerate(train_loader):
            img_batch_read = len(x)
            img_count += img_batch_read

            # zero out gradients
            optimizer.zero_grad()

            # input batch to transformer network
            x = Variable(x).type(dtype)
            y_hat = image_transformer(x)

            # get vgg features
            y_c_features = vgg(x)
            y_hat_features = vgg(y_hat)

            # calculate style loss
            y_hat_gram = [utils.gram(fmap) for fmap in y_hat_features]
            style_loss = 0.0
            for j in range(4):
                style_loss += loss_mse(y_hat_gram[j],
                                       style_gram[j][:img_batch_read])
            style_loss = STYLE_WEIGHT * style_loss
            aggregate_style_loss += style_loss.item()

            # calculate content loss (h_relu_2_2)
            recon = y_c_features[1]
            recon_hat = y_hat_features[1]
            content_loss = CONTENT_WEIGHT * loss_mse(recon_hat, recon)
            aggregate_content_loss += content_loss.item()

            # calculate total variation regularization (anisotropic version)
            # https://www.wikiwand.com/en/Total_variation_denoising
            diff_i = torch.sum(
                torch.abs(y_hat[:, :, :, 1:] - y_hat[:, :, :, :-1]))
            diff_j = torch.sum(
                torch.abs(y_hat[:, :, 1:, :] - y_hat[:, :, :-1, :]))
            tv_loss = TV_WEIGHT * (diff_i + diff_j)
            aggregate_tv_loss += tv_loss.item()

            # total loss
            total_loss = style_loss + content_loss + tv_loss

            # backprop
            total_loss.backward()
            optimizer.step()

            # print out status message
            if ((batch_num + 1) % 100 == 0):
                status = "{}  Epoch {}:  [{}/{}]  Batch:[{}]  agg_style: {:.6f}  agg_content: {:.6f}  agg_tv: {:.6f}  style: {:.6f}  content: {:.6f}  tv: {:.6f} ".format(
                    time.ctime(), e + 1, img_count, len(train_dataset),
                    batch_num + 1, aggregate_style_loss / (batch_num + 1.0),
                    aggregate_content_loss / (batch_num + 1.0),
                    aggregate_tv_loss / (batch_num + 1.0), style_loss.item(),
                    content_loss.item(), tv_loss.item())
                print(status)

            if ((batch_num + 1) % 1000 == 0) and (visualize):
                image_transformer.eval()

                if not os.path.exists("visualization"):
                    os.makedirs("visualization")
                if not os.path.exists("visualization/%s" % style_name):
                    os.makedirs("visualization/%s" % style_name)

                outputTestImage_amber = image_transformer(
                    testImage_amber).cpu()

                amber_path = "visualization/%s/amber_%d_%05d.jpg" % (
                    style_name, e + 1, batch_num + 1)
                utils.save_image(amber_path, outputTestImage_amber.data[0])

                outputTestImage_dan = image_transformer(testImage_dan).cpu()
                dan_path = "visualization/%s/dan_%d_%05d.jpg" % (
                    style_name, e + 1, batch_num + 1)
                utils.save_image(dan_path, outputTestImage_dan.data[0])

                outputTestImage_maine = image_transformer(
                    testImage_maine).cpu()
                maine_path = "visualization/%s/maine_%d_%05d.jpg" % (
                    style_name, e + 1, batch_num + 1)
                utils.save_image(maine_path, outputTestImage_maine.data[0])

                print("images saved")
                image_transformer.train()

    # save model
    image_transformer.eval()

    if use_cuda:
        image_transformer.cpu()

    if not os.path.exists("models"):
        os.makedirs("models")
    filename = "models/" + str(style_name) + "_" + \
        str(time.ctime()).replace(' ', '_') + ".model"
    torch.save(image_transformer.state_dict(), filename)

    if use_cuda:
        image_transformer.cuda()
Example #11
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(
            args.image_size),  # the shorter side is resize to match image_size
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),  # to tensor [0,1]
        transforms.Lambda(lambda x: x.mul(255))  # convert back to [0, 255]
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)  # to provide a batch loader

    style_image = [f for f in os.listdir(args.style_image)]
    style_num = len(style_image)
    print(style_num)

    transformer = TransformerNet(style_num=style_num).to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.Resize(args.style_size),
        transforms.CenterCrop(args.style_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    style_batch = []

    for i in range(style_num):
        style = utils.load_image(args.style_image + style_image[i],
                                 size=args.style_size)
        style = style_transform(style)
        style_batch.append(style)

    style = torch.stack(style_batch).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)

            if n_batch < args.batch_size:
                break  # skip to the next epoch when there are not enough images left in the last batch of the current epoch

            count += n_batch
            optimizer.zero_grad()  # initialize with zero gradients

            batch_style_id = [
                i % style_num for i in range(count - n_batch, count)
            ]
            y = transformer(x.to(device), style_id=batch_style_id)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y.to(device))
            features_x = vgg(x.to(device))
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[batch_style_id, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(
        args.epochs) + "_" + str(time.ctime()).replace(' ', '_').replace(
            ':', '') + "_" + str(int(args.content_weight)) + "_" + str(
                int(args.style_weight)) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #12
    def __init__(self, args):
        self.setup_environ(args)
        self.log = True
        self.distributed = False
        self.global_rank = 0
        if args.ngpus > 1:
            self.setup_distributed(args)

#         dataset = ImageFolder('/datasets01/CelebA/CelebA/072017/')
        dataset = torch.load('celeba_dset.pth')
        nimages = len(dataset)
        nval = 1000
        inds = list(range(nimages))
        train_inds = inds[1000:]
        val_inds = inds[:1000]

        # TODO: add jitter augmentation ?
        self.train_transform = T.Compose([
            T.Resize(args.iSz),
            T.RandomHorizontalFlip(0.5),
            T.CenterCrop(args.iSz),
            T.ToTensor()
        ])
        self.val_transform = T.Compose(
            [T.Resize(args.iSz),
             T.CenterCrop(args.iSz),
             T.ToTensor()])

        self.train_dset = Subset(dataset,
                                 train_inds,
                                 transform=self.train_transform)
        self.val_dset = Subset(dataset, val_inds, transform=self.val_transform)

        if self.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                self.train_dset)
        else:
            train_sampler = None

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dset,
            batch_size=args.bSz,
            shuffle=(train_sampler is None),
            num_workers=10,
            pin_memory=True,
            sampler=train_sampler)

        self.val_loader = torch.utils.data.DataLoader(
            self.val_dset, batch_size=50, num_workers=10)  # 1000 samples

        self.model = pix2vec(args)
        # change normalization
        if args.normalization != 'None':
            kwargs = {}
            if args.normalization == 'GN':
                kwargs['nGroups'] = 32
            change_norm(self.model,
                        normType=args.normalization,
                        verbose=0,
                        **kwargs)

        self.model = self.model.cuda()

        if self.distributed:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=True)


#         if self.log:
#             print(f'| {self.model}')
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.wd,
                                          betas=(0.9, 0.999),
                                          amsgrad=False)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=args.milestones, gamma=args.gamma)

        self.useVgg = args.contentW > 0 or args.styleW > 0
        if self.useVgg:
            self.vgg = Vgg16(
                requires_grad=False).to('cuda').eval()  # TODO: eval ???
            self.mse_loss = torch.nn.MSELoss()

        self.iteration = 0
        self.bestPSNR = 0

        if self.log:
            self.writer = SummaryWriter(args.logdir)
            self.writer.add_text('args', str(args), 0)

        self.args = args

        if args.ckptPath:
            self.load(args.ckptPath)

        print(' | Done init trainer')
Example #13
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    train_loader = check_dataset(args)
    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    running_avgs = OrderedDict()

    def step(engine, batch):

        x, _ = batch
        x = x.to(device)

        n_batch = len(x)

        optimizer.zero_grad()

        y = transformer(x)

        x = utils.normalize_batch(x)
        y = utils.normalize_batch(y)

        features_x = vgg(x)
        features_y = vgg(y)

        content_loss = args.content_weight * mse_loss(features_y.relu2_2,
                                                      features_x.relu2_2)

        style_loss = 0.0
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
        style_loss *= args.style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        return {
            "content_loss": content_loss.item(),
            "style_loss": style_loss.item(),
            "total_loss": total_loss.item()
        }

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(args.checkpoint_model_dir,
                                         "checkpoint",
                                         n_saved=10,
                                         require_empty=False,
                                         create_dir=True)
    progress_bar = Progbar(loader=train_loader, metrics=running_avgs)

    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED(every=args.checkpoint_interval),
        handler=checkpoint_handler,
        to_save={"net": transformer},
    )
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                              handler=progress_bar)
    trainer.run(train_loader, max_epochs=args.epochs)
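
Note: most of these examples index the VGG output by attribute name (features_y.relu2_2, features_y.relu3_3, and so on). A minimal sketch of a Vgg16 wrapper that returns those activations as a namedtuple, assuming the standard torchvision VGG-16 layer ordering, is:

from collections import namedtuple
import torch.nn as nn
from torchvision import models

class Vgg16(nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        features = models.vgg16(pretrained=True).features
        # slice the pretrained feature stack at the relu layers of interest
        self.slice1 = nn.Sequential(*features[:4])     # up to relu1_2
        self.slice2 = nn.Sequential(*features[4:9])    # up to relu2_2
        self.slice3 = nn.Sequential(*features[9:16])   # up to relu3_3
        self.slice4 = nn.Sequential(*features[16:23])  # up to relu4_3
        if not requires_grad:
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, x):
        h1 = self.slice1(x)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        h4 = self.slice4(h3)
        VggOutputs = namedtuple("VggOutputs",
                                ["relu1_2", "relu2_2", "relu3_3", "relu4_3"])
        return VggOutputs(h1, h2, h3, h4)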
Example #14
def fast_train(args):
    """Fast training"""

    device = torch.device("cuda" if args.cuda else "cpu")

    transformer = TransformerNet().to(device)
    if args.model:
        transformer.load_state_dict(torch.load(args.model))
    vgg = Vgg16(requires_grad=False).to(device)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    content_transform = transforms.Compose([
        transforms.Resize(args.content_size),
        transforms.CenterCrop(args.content_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    content_dataset = datasets.ImageFolder(args.content_dataset, content_transform)
    content_loader = DataLoader(content_dataset, 
                                batch_size=args.iter_batch_size, 
                                sampler=InfiniteSamplerWrapper(content_dataset),
                                num_workers=args.n_workers)
    content_loader = iter(content_loader)
    style_transform = transforms.Compose([
            transforms.Resize((args.style_size, args.style_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))])

    style_image = utils.load_image(args.style_image)
    style_image = style_transform(style_image)
    style_image = style_image.unsqueeze(0).to(device)
    features_style = vgg(utils.normalize_batch(style_image.repeat(args.iter_batch_size, 1, 1, 1)))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    if args.only_in:
        optimizer = Adam([param for (name, param) in transformer.named_parameters() if "in" in name], lr=lr)
    else:
        optimizer = Adam(transformer.parameters(), lr=lr)

    for i in trange(args.update_step):
        contents = content_loader.next()[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))

        transformed = transformer(contents)
        features_transformed = vgg(utils.standardize_batch(transformed))
        loss, c_loss, s_loss = loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # save model
    transformer.eval().cpu()
    style_name = os.path.basename(args.style_image).split(".")[0]
    save_model_filename = style_name + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
Example #15
label_weight = [9.81, 3.98]
std = [.229, .224, .225]
mean = [.485, .456, .406]

os.system('rm -rf ./runs2/*')
writer = SummaryWriter('./runs2/' + datetime.now().strftime('%B%d  %H:%M:%S'))

if not os.path.exists('./runs2'):
    os.mkdir('./runs2')

if not os.path.exists(check_dir):
    os.mkdir(check_dir)

# models
if 'vgg' == opt.i:
    feature = Vgg16(pretrained=True)
elif 'resnet' == opt.i:
    feature = resnet50(pretrained=True)
elif 'densenet' == opt.i:
    feature = densenet121(pretrained=True)
feature.cuda()

classifier = Classifier(opt.i)
classifier.cuda()

if resume_ep >= 0:
    feature_param_file = glob.glob('%s/feature-epoch-%d*.pth' %
                                   (check_dir, resume_ep))
    classifier_param_file = glob.glob('%s/classifier-epoch-%d*.pth' %
                                      (check_dir, resume_ep))
    feature.load_state_dict(torch.load(feature_param_file[0]))
Example #16
def train(args):
    """Meta train the model"""

    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # first move parameters to GPU
    transformer = TransformerNet().to(device)
    vgg = Vgg16(requires_grad=False).to(device)
    global optimizer
    optimizer = Adam(transformer.parameters(), args.meta_lr)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_loader, style_loader, query_loader = get_data_loader(args)

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    writer = SummaryWriter(args.log_dir)

    for iteration in trange(args.max_iter):
        transformer.train()
        
        # bookkeeping
        # using state_dict causes problems, use named_parameters instead
        all_meta_grads = []
        avg_train_c_loss = 0.0
        avg_train_s_loss = 0.0
        avg_train_loss = 0.0
        avg_eval_c_loss = 0.0
        avg_eval_s_loss = 0.0
        avg_eval_loss = 0.0

        contents = content_loader.next()[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))
        querys = query_loader.next()[0].to(device)
        features_querys = vgg(utils.normalize_batch(querys))

        # learning rate scheduling
        lr = args.lr / (1.0 + iteration * 2.5e-5)
        meta_lr = args.meta_lr / (1.0 + iteration * 2.5e-5)
        for param_group in optimizer.param_groups:
            param_group['lr'] = meta_lr

        for i in range(args.meta_batch_size):
            # sample a style
            style = style_loader.next()[0].to(device)
            style = style.repeat(args.iter_batch_size, 1, 1, 1)
            features_style = vgg(utils.normalize_batch(style))
            gram_style = [utils.gram_matrix(y) for y in features_style]

            fast_weights = OrderedDict((name, param) for (name, param) in transformer.named_parameters() if re.search(r'in\d+\.', name))
            for j in range(args.meta_step):
                # run forward transformation on contents
                transformed = transformer(contents, fast_weights)

                # compute loss
                features_transformed = vgg(utils.standardize_batch(transformed))
                loss, c_loss, s_loss = loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight)

                # compute grad
                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)

                # update fast weights
                fast_weights = OrderedDict((name, param - lr * grad) for ((name, param), grad) in zip(fast_weights.items(), grads))
            
            avg_train_c_loss += c_loss.item()
            avg_train_s_loss += s_loss.item()
            avg_train_loss += loss.item()

            # run forward transformation on querys
            transformed = transformer(querys, fast_weights)
            
            # compute loss
            features_transformed = vgg(utils.standardize_batch(transformed))
            loss, c_loss, s_loss = loss_fn(features_transformed, features_querys, gram_style, content_weight, style_weight)
            
            grads = torch.autograd.grad(loss / args.meta_batch_size, transformer.parameters())
            all_meta_grads.append({name: g for ((name, _), g) in zip(transformer.named_parameters(), grads)})

            avg_eval_c_loss += c_loss.item()
            avg_eval_s_loss += s_loss.item()
            avg_eval_loss += loss.item()
        
        writer.add_scalar("Avg_Train_C_Loss", avg_train_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_S_Loss", avg_train_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_Loss", avg_train_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_C_Loss", avg_eval_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_S_Loss", avg_eval_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_Loss", avg_eval_loss / args.meta_batch_size, iteration + 1)

        # compute dummy loss to refresh buffer
        transformed = transformer(querys)
        features_transformed = vgg(utils.standardize_batch(transformed))
        dummy_loss, _, _ = loss_fn(features_transformed, features_querys, gram_style, content_weight, style_weight)

        meta_updates(transformer, dummy_loss, all_meta_grads)

        if args.checkpoint_model_dir is not None and (iteration + 1) % args.checkpoint_interval == 0:
            transformer.eval().cpu()
            ckpt_model_filename = "iter_" + str(iteration + 1) + ".pth"
            ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
            torch.save(transformer.state_dict(), ckpt_model_path)
            transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "Final_iter_" + str(args.max_iter) + "_" + \
                          str(args.content_weight) + "_" + \
                          str(args.style_weight) + "_" + \
                          str(args.lr) + "_" + \
                          str(args.meta_lr) + "_" + \
                          str(args.meta_batch_size) + "_" + \
                          str(args.meta_step) + "_" + \
                          time.ctime() + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print "Done, trained model saved at {}".format(save_model_path)
Example #17
LEARNING_RATE = 0.01
CONTENT_WEIGHT = 1
STYLE_WEIGHT = 1e10
TV_WEIGHT = 0.0001

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

style_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

vgg = Vgg16(requires_grad=False).cuda()  # vgg16 model

I = utils.scale_image(filename=STYLE_IMG_PATH, size=128, scale=512)

I = np.array(I)
plt.imshow(I)
plt.show()

style_img = utils.load_image(filename=STYLE_IMG_PATH, size=IMAGE_SIZE)
# style_img = utils.image_compose(IMG=style_img, IMAGE_ROW=4, IMAGE_COLUMN=4, IMAGE_SIZE=128)
content_img = utils.load_image(filename=CONTENT_IMG_PATH, size=IMAGE_SIZE)

style_img = style_transform(style_img)
content_img = transform(content_img)

style_img = style_img.repeat(BATCH_SIZE, 1, 1, 1).cuda()  # make fake batch
Example #18
def train():
    device = torch.device("cuda")

    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])

    train_dataset = datasets.ImageFolder(dataset_path, transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), lr)
    mse_loss = torch.nn.MSELoss()

    if resume_TransformerNet_from_file:
        if os.path.isfile(TransformerNet_path):
            print("=> loading checkpoint '{}'".format(TransformerNet_path))
            TransformerNet_par = torch.load(TransformerNet_path)
            for k in list(TransformerNet_par.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del TransformerNet_par[k]
            transformer.load_state_dict(TransformerNet_par)
            print("=> loaded checkpoint '{}'".format(TransformerNet_path))
        else:
            print("=> no checkpoint found at '{}'".format(TransformerNet_path))

    vgg = Vgg16(requires_grad=False).to(device)
    style = Image.open(style_image_path)
    style = transform(style)
    style = style.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    model_fcrn = FCRN_for_transfer(batch_size=batch_size,
                                   requires_grad=False).to(device)
    model_fcrn_par = torch.load(FCRN_path)
    #start_epoch = model_fcrn_par['epoch']
    model_fcrn.load_state_dict(model_fcrn_par['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        FCRN_path, model_fcrn_par['epoch']))

    for e in range(epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_depth_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            depth_y = model_fcrn(y)
            depth_x = model_fcrn(x)

            content_loss = content_weight * mse_loss(features_y.relu2_2,
                                                     features_x.relu2_2)
            depth_loss = depth_weight * mse_loss(depth_y, depth_x)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= style_weight

            total_loss = content_loss + depth_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_depth_loss += depth_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tdepth: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_depth_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_depth_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if checkpoint_model_dir is not None and (
                    batch_id + 1) % checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            content_weight) + "_" + str(style_weight) + ".model"
    save_model_path = os.path.join(save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #19
parser.add_argument('--i', default='vgg')  # feature extractor: vgg / resnet / densenet
parser.add_argument('--test_dir', default='/home/zeng/data/datasets/clshand/val')  # dataset
parser.add_argument('--feat', default='/home/zeng/handseg/parameters_cls/feature-epoch-19-step-365.pth')
parser.add_argument('--cls', default='/home/zeng/handseg/parameters_cls/classifier-epoch-19-step-365.pth')
parser.add_argument('--b', type=int, default=16)  # batch size
opt = parser.parse_args()
print(opt)

test_dir = opt.test_dir
feature_param_file = opt.feat
class_param_file = opt.cls
bsize = opt.b

# models
if 'vgg' == opt.i:
    feature = Vgg16()
elif 'resnet' == opt.i:
    feature = resnet50()
elif 'densenet' == opt.i:
    feature = densenet121()
feature.cuda()
feature.load_state_dict(torch.load(feature_param_file))

classifier = Classifier(opt.i)
classifier.cuda()
classifier.load_state_dict(torch.load(class_param_file))

loader = torch.utils.data.DataLoader(
    MyClsTestData(test_dir, transform=True),
    batch_size=bsize, shuffle=True, num_workers=4, pin_memory=True)
Example #20
def train(args):
    if torch.cuda.is_available():
        print('CUDA available, using GPU.')
        device = torch.device('cuda')
    else:
        print('GPU training unavailable... using CPU.')
        device = torch.device('cpu')

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # Image transformation network.
    transformer = TransformerNet()

    if args.model:
        state_dict = torch.load(args.model)
        transformer.load_state_dict(state_dict)

    transformer.to(device)

    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # Loss Network: VGG16
    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            # CUDA if available
            x = x.to(device)

            # Transform image
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            # Feature map of original image
            features_x = vgg(x)
            # Feature Map of transformed image
            features_y = vgg(y)

            # Difference between transformed image, original image.
            # Changed to pull from features_.relu3_3 vs .relu2_2
            content_loss = args.content_weight * mse_loss(features_y.relu3_3, features_x.relu3_3)

            # Compute gram matrix (dot product across each dimension G(4,3) = F4 * F3)
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if True: #(batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #21
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet(args.alpha).to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16().to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward(retain_graph=True)
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                (batch_id + 1) %
                    int(args.checkpoint_interval / args.batch_size)) == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "checkpoint_" + str(
                    batch_id + 1
                ) + ".pth"  #"ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                os.system("python neural_style/neural_style.py eval \
                            --model ~/Documents/data/models/pytorch-checkpoints/"
                          + ckpt_model_filename + " \
                            --content-image ~/Documents/data/images/test.jpg \
                            --output-image ~/Documents/data/images/pytorch/stylized-test_"
                          + str(batch_id + 1) + ".jpg \
                            --alpha " + str(args.alpha) + " \
                            --cuda 0")
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #22
0
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0. 
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(x)
            if args.cuda:
                x = x.cuda()
            
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)
Example #23
0
style.size()
# Out[20]: torch.Size([4, 3, 391, 391])


def normalize_batch(batch):
    # normalize using imagenet mean and std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch.div_(255.0)
    return (batch - mean) / std


import vgg
from vgg import Vgg16

vgg = Vgg16(requires_grad=False).to(device)
# Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /Users/chenqy/.cache/torch/checkpoints/vgg16-397923af.pth
# 100.0%

batch = style
mean = batch.new_tensor([0.485, 0.456,
                         0.406]).view(-1, 1, 1)  # mean has shape torch.Size([3, 1, 1])
std = batch.new_tensor([0.229, 0.224,
                        0.225]).view(-1, 1, 1)  # std has shape torch.Size([3, 1, 1])
batch = batch.div_(255.0)  # scale pixel values to [0, 1]

normalizeresult = (batch - mean) / std
# tensor([[[[-1.3644, -0.9705,  0.3823,  ..., -0.7308, -0.7650, -0.7308],
#           [ 0.0741, -1.1760, -0.8335,  ...,  0.6392,  0.3309,  0.2111],
#           [ 0.5707,  0.4851, -0.9877,  ...,  0.8276,  0.4679, -1.3644],
#           ...,
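Every style loss in these examples compares Gram matrices produced by utils.gram_matrix. A minimal sketch of that helper, assuming the usual per-channel normalization; the source's own implementation may differ in detail:

import torch

def gram_matrix(y):
    # y: (b, ch, h, w) feature maps -> (b, ch, ch) Gram matrices,
    # normalized by the number of elements per channel.
    b, ch, h, w = y.size()
    features = y.view(b, ch, h * w)
    gram = features.bmm(features.transpose(1, 2)) / (ch * h * w)
    return gram

The style loss is then the MSE between the Gram matrix of the generated features and the (batch-sliced) Gram matrix of the style image's features.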
Example #24
0
def train(style_image, dataset_path):
    print('Training function started...')
    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image_size = 256
    style_weight = 1e10
    content_weight = 1e5
    lr = 1e-3
    batch_size = 3
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    train_dataset = datasets.ImageFolder(dataset_path, transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    transformer = TransformerNet().to(device)

    optimizer = Adam(transformer.parameters(), lr=lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16().to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = load_image(style_image)
    style = style_transform(style)
    style = style.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(normalize_batch(style))
    gram_style = [gram_matrix(y) for y in features_style]
    epochs = 2
    print('Starting epochs...')
    for e in range(epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = normalize_batch(y)
            x = normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = content_weight * mse_loss(features_y.relu2_2,
                                                     features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])

            style_loss *= style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()
            log_interval = 2000
            if (batch_id + 1) % log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval().cpu()
    save_model_path = 'models/outpost.pth'
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #25
0
import torch
from torch.autograd import Variable
import torch.nn.functional as F

import time
import numpy as np
from PIL import Image
from vgg import Vgg16
from model import ActionClassifier
from database.database import Dataset
import os

# Initialize Networks
vgg = Vgg16(requires_grad=False)
vgg = vgg.cuda()
classifier = ActionClassifier()
classifier = classifier.cuda()


# Load dataset
dataset = Dataset()
dataset.initialize()

num_epoch = 500
num_train_data = len(dataset)
batch_size = 64

optimizer = torch.optim.Adam(classifier.parameters(), lr=0.0002, betas=(0.5, 0.999))

tt = time.time()
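The snippet above stops after setup; a single training step would pass frames through the frozen VGG features and then the classifier. The sketch below is hypothetical: the batch layout, the use of relu4_3 features, and the classifier's input are all assumptions, since the ActionClassifier and Dataset interfaces are not shown here.

import torch
import torch.nn.functional as F

def train_step(vgg, classifier, optimizer, frames, labels):
    # frames: (batch, 3, H, W) float tensor, labels: (batch,) long tensor (assumed layout)
    frames, labels = frames.cuda(), labels.cuda()
    with torch.no_grad():
        feats = vgg(frames).relu4_3        # frozen VGG features (assumed field name)
    logits = classifier(feats)             # assumed: classifier consumes a feature map
    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()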
Example #26
0
File: WDNet.py Project: MRUIL/WDNet
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        #self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
        #if self.gpu_mode:
        #self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
        vgg = Vgg16().type(torch.cuda.FloatTensor)
        self.D.train()
        print('training start!!')
        start_time = time.time()
        writer = SummaryWriter(log_dir='log/ex_WDNet')
        lenth = self.data_loader.dataset.__len__()
        iter_all = 0
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter, (x_, y_, mask, balance, alpha,
                       w) in enumerate(self.data_loader):
                iter_all += 1  #iter+epoch*(lenth//self.batch_size)
                if iter == lenth // self.batch_size:
                    break
                #y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
                #y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size)
                if self.gpu_mode:
                    x_, y_, mask, balance, alpha, w = x_.cuda(), y_.cuda(
                    ), mask.cuda(), balance.cuda(), alpha.cuda(), w.cuda()
                    #x_, z_, y_vec_, y_fill_ = x_.cuda(), z_.cuda(), y_vec_.cuda(), y_fill_.cuda()

                # update D network
                if ((iter + 1) % 3) == 0:
                    self.D_optimizer.zero_grad()

                    D_real = self.D(x_, y_)
                    D_real_loss = self.BCE_loss(D_real,
                                                torch.ones_like(D_real))

                    G_, g_mask, g_alpha, g_w, I_watermark = self.G(x_)
                    D_fake = self.D(x_, G_)
                    D_fake_loss = self.BCE_loss(D_fake,
                                                torch.zeros_like(D_fake))

                    D_loss = 0.5 * D_real_loss + 0.5 * D_fake_loss
                    #self.train_hist['D_loss'].append(D_loss.item())
                    D_writer = D_loss.item()
                    D_loss.backward()
                    self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_, g_mask, g_alpha, g_w, I_watermark = self.G(x_)
                D_fake = self.D(x_, G_)
                G_loss = self.BCE_loss(D_fake, torch.ones_like(D_fake))
                feature_G = vgg(G_)
                feature_real = vgg(y_)
                vgg_loss = 0.0
                for j in range(3):
                    vgg_loss += self.loss_mse(feature_G[j], feature_real[j])
                #self.train_hist['G_loss'].append(G_loss.item())
                mask_loss = self.l1loss(
                    g_mask * balance,
                    mask * balance) * balance.size(0) * balance.size(
                        1) * balance.size(2) * balance.size(3) / balance.sum()
                w_loss = self.l1loss(
                    g_w * mask, w * mask) * mask.size(0) * mask.size(
                        1) * mask.size(2) * mask.size(3) / mask.sum()
                alpha_loss = self.l1loss(
                    g_alpha * mask, alpha * mask) * mask.size(0) * mask.size(
                        1) * mask.size(2) * mask.size(3) / mask.sum()
                I_watermark_loss = self.l1loss(
                    I_watermark * mask, y_ * mask) * mask.size(0) * mask.size(
                        1) * mask.size(2) * mask.size(3) / mask.sum()
                I_watermark2_loss = self.l1loss(
                    G_ * mask, y_ * mask) * mask.size(0) * mask.size(
                        1) * mask.size(2) * mask.size(3) / mask.sum()
                G_writer = G_loss.data
                G_loss = G_loss + 10.0 * mask_loss + 10.0 * w_loss + 10.0 * alpha_loss + 50.0 * (
                    0.7 * I_watermark2_loss +
                    0.3 * I_watermark_loss) + 1e-2 * vgg_loss
                G_loss.backward()
                self.G_optimizer.step()
                if ((iter + 1) % 100) == 0:
                    writer.add_scalar('G_Loss', G_writer, iter_all)
                    writer.add_scalar('D_Loss', D_loss.item(), iter_all)
                    writer.add_scalar('W_Loss', w_loss, iter_all)
                    writer.add_scalar('alpha_Loss', alpha_loss, iter_all)
                    writer.add_scalar('mask_Loss', mask_loss, iter_all)
                    writer.add_scalar('I_watermark_Loss', I_watermark_loss,
                                      iter_all)
                    writer.add_scalar('I_watermark2_Loss', I_watermark2_loss,
                                      iter_all)
                    writer.add_scalar('vgg_Loss', vgg_loss, iter_all)
                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1),
                           (iter + 1), self.data_loader.dataset.__len__() //
                           self.batch_size, D_loss.item(), G_writer))
            self.save()
        print("Training finish!... save training results")

        self.save()
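The vgg_loss above sums MSE over the first three VGG feature maps; the same idea can be packaged as a reusable module. A sketch, under the assumption that vgg is the indexable, frozen Vgg16 wrapper used in this file and loss_mse is an nn.MSELoss instance:

import torch.nn as nn

class VggPerceptualLoss(nn.Module):
    def __init__(self, vgg, num_layers=3):
        super(VggPerceptualLoss, self).__init__()
        self.vgg = vgg                # frozen Vgg16 feature extractor
        self.num_layers = num_layers  # compare the first N feature maps
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        f_pred = self.vgg(pred)
        f_target = self.vgg(target)
        loss = 0.0
        for j in range(self.num_layers):
            loss = loss + self.mse(f_pred[j], f_target[j])
        return loss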
Example #27
0
def train(args):
    # Define the device
    device = torch.device("cuda" if args.cuda else "cpu")

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # transform and dataset
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # model, optimizer, and mse-loss
    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # feature and style loss
    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    # load and preprocess the style reference image
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # The features and gram matrix of referenced style image
    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #28
0
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)
                niter = e * len(train_dataset) + batch_id
                writer.add_scalar('content loss',
                                  agg_content_loss / (batch_id + 1), niter)
                writer.add_scalar('style loss',
                                  agg_style_loss / (batch_id + 1), niter)
                writer.add_scalar(
                    'total loss',
                    (agg_content_loss + agg_style_loss) / (batch_id + 1),
                    niter)
            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval()
                if args.cuda:
                    transformer.cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                if args.cuda:
                    transformer.cuda()
                transformer.train()

    # save model
    transformer.eval()
    if args.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #29
0
def train(start_epoch=0):
    np.random.seed(enums.seed)
    torch.manual_seed(enums.seed)

    if enums.cuda:
        torch.cuda.manual_seed(enums.seed)

    transform = transforms.Compose([
        transforms.Resize(enums.image_size),
        transforms.CenterCrop(enums.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(enums.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=enums.batch_size)

    transformer = TransformerNet()
    #transformer = torch.nn.DataParallel(transformer)
    optimizer = Adam(transformer.parameters(), enums.lr)
    if enums.subcommand == 'resume':
        ckpt_state = torch.load(enums.checkpoint_model)
        transformer.load_state_dict(ckpt_state['state_dict'])
        start_epoch = ckpt_state['epoch']
        optimizer.load_state_dict(ckpt_state['optimizer'])

    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    if enums.cuda:
        transformer.cuda()
        vgg.cuda()

    all_style_img_paths = [
        os.path.join(enums.style_image_dir, f)
        for f in os.listdir(enums.style_image_dir)
    ]
    all_style_grams = {}
    for i, style_img in enumerate(all_style_img_paths):
        style = utils.load_image(style_img, size=enums.style_size)
        style = style_transform(style)
        style = style.repeat(
            enums.batch_size, 1, 1,
            1)  # can try with expand but unsure of backprop effects
        if enums.cuda:
            style = style.cuda()
        style_v = Variable(style)
        style_v = utils.normalize_batch(style_v)
        features_style = vgg(style_v)
        gram_style = [utils.gram_matrix(y) for y in features_style]
        all_style_grams[i] = gram_style

    for e in range(start_epoch, enums.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            idx = random.randint(0, enums.num_styles - 1)  # 0 to num_styles-1
            # S = torch.zeros(enums.num_styles, 1) # s,1 vector
            # S[idx] = 1 # one-hot vec for rand chosen style

            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(x)
            if enums.cuda:
                #S = S.cuda()
                x = x.cuda()

            y = transformer(x, idx)
            #print e, batch_id

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)
            gram_style = all_style_grams[idx]

            content_loss = enums.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= enums.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % enums.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)
        # del content_loss, style_loss, S, x, y, style, features_x, features_y

        if enums.checkpoint_model_dir is not None and (
                e + 1) % enums.checkpoint_interval == 0:
            # transformer.eval()
            if enums.cuda:
                transformer.cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e + 1) + ".pth"
            ckpt_model_path = os.path.join(enums.checkpoint_model_dir,
                                           ckpt_model_filename)
            save_checkpoint(
                {
                    'epoch': e + 1,
                    'state_dict': transformer.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, ckpt_model_path)
            if enums.cuda:
                transformer.cuda()
            # transformer.train()

    # save model
    # transformer.eval()
    if enums.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(enums.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            enums.content_weight) + "_" + str(enums.style_weight) + ".model"
    save_model_path = os.path.join(enums.save_model_dir, save_model_filename)
    save_checkpoint(
        {
            'epoch': e + 1,
            'state_dict': transformer.state_dict(),
            'optimizer': optimizer.state_dict()
        }, save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #30
0
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    test_image = utils.load_image(args.test_image)
    test_image = style_transform(test_image)
    test_image = test_image.unsqueeze(0).to(device)

    running_avgs = OrderedDict()
    output_stream = sys.stdout
    alpha = 0.98

    def step(engine, batch):

        x, _ = batch
        x = x.to(device)

        n_batch = len(x)

        transformer.zero_grad()

        y = transformer(x)

        x = utils.normalize_batch(x)
        y = utils.normalize_batch(y)

        features_x = vgg(x)
        features_y = vgg(y)

        content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

        style_loss = 0.
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
        style_loss *= args.style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        return {
            'content_loss': content_loss.item(),
            'style_loss': style_loss.item(),
            'total_loss': total_loss.item()
        }

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(args.checkpoint_model_dir, 'ckpt_epoch_',
                                         save_interval=args.checkpoint_interval,
                                         n_saved=10, require_empty=False, create_dir=True)

    @trainer.on(Events.ITERATION_COMPLETED)
    def update_logs(engine):
        for k, v in engine.state.output.items():
            old_v = running_avgs.get(k, v)
            new_v = alpha * old_v + (1 - alpha) * v
            running_avgs[k] = new_v

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):

        num_seen = engine.state.iteration - len(train_loader) * (engine.state.epoch - 1)

        percent_seen = 100 * (float(num_seen / len(train_loader)))
        percentages = list(range(0, 110, 10))

        if int(percent_seen) == 100:
            progress = 0
            equal_to = 10
            sub = 0
        else:
            sub = 1
            progress = 1
            equal_to = np.max(np.where([percent < percent_seen for percent in percentages])[0])

        bar = '[' + '=' * equal_to + '>' * progress + ' ' * (10 - equal_to - sub) + ']'

        message = 'Epoch {epoch} | {percent_seen:.2f}% | {bar}'.format(epoch=engine.state.epoch,
                                                                       percent_seen=percent_seen,
                                                                       bar=bar)
        for key, value in running_avgs.items():
            message += ' | {name}: {value:.2e}'.format(name=key, value=value)

        message += '\r'

        output_stream.write(message)
        output_stream.flush()

    @trainer.on(Events.EPOCH_COMPLETED)
    def complete_progress(engine):
        output_stream.write('\n')

    @trainer.on(Events.EPOCH_COMPLETED)
    def stylize_image(engine):
        path = os.path.join(args.stylized_test_dir, STYLIZED_IMG_FNAME.format(engine.state.epoch))
        content_image = utils.load_image(args.test_image, scale=None)
        content_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
        content_image = content_transform(content_image)
        content_image = content_image.unsqueeze(0).to(device)

        with torch.no_grad():
            output = transformer(content_image).cpu()

        utils.save_image(path, output[0])

    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                              to_save={'net': transformer})
    trainer.run(train_loader, max_epochs=args.epochs)
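All of these train(args) variants expect an argparse-style namespace. A minimal sketch of the wiring; the flag names and defaults below are assumptions for illustration rather than the source project's actual CLI:

import argparse

def build_parser():
    p = argparse.ArgumentParser(description="fast neural style training (sketch)")
    p.add_argument("--dataset", required=True, help="path to an ImageFolder-style dataset")
    p.add_argument("--style-image", required=True)
    p.add_argument("--save-model-dir", required=True)
    p.add_argument("--checkpoint-model-dir", default=None)
    p.add_argument("--test-image", default=None)
    p.add_argument("--stylized-test-dir", default=None)
    p.add_argument("--epochs", type=int, default=2)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--image-size", type=int, default=256)
    p.add_argument("--style-size", type=int, default=None)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--content-weight", type=float, default=1e5)
    p.add_argument("--style-weight", type=float, default=1e10)
    p.add_argument("--log-interval", type=int, default=500)
    p.add_argument("--checkpoint-interval", type=int, default=2000)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--cuda", type=int, default=1)
    return p

if __name__ == "__main__":
    args = build_parser().parse_args()   # argparse maps --style-image to args.style_image, etc.
    train(args)                          # any of the train(args) variants above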