示例#1
0
def get_GAN_AB_model(folder_model, model_name, device):
    """Build a ``GeneratorResNet`` and load trained G_AB weights into it.

    Args:
        folder_model: Path prefix of the checkpoint. It is concatenated
            directly with ``model_name``; no path separator is inserted.
        model_name: Checkpoint file name appended to ``folder_model``.
        device: Device used as ``map_location`` when reading the checkpoint
            and as the destination of the returned model.

    Returns:
        The generator with the checkpoint's state dict loaded, placed on
        ``device``.
    """
    # This should be the same value used when training the G_AB model.
    n_residual_blocks = 9
    G_AB = GeneratorResNet(input_shape=(3, 0),
                           num_residual_blocks=n_residual_blocks)
    state_dict = torch.load(folder_model + model_name, map_location=device)
    G_AB.load_state_dict(state_dict)

    # Move to the requested device unconditionally instead of consulting a
    # module-level ``cuda`` flag: the caller already states where the model
    # must live, and ``.to`` is a no-op when it is already there.
    return G_AB.to(device)
示例#2
0
def get_generator_model():
    """Return the generator with weights from ``PATH_G``, in eval mode.

    The network is built from the module-level ``img_shape``,
    ``residual_blocks`` and ``c_dim`` settings, and the checkpoint is mapped
    onto the CPU while loading.
    """
    net = GeneratorResNet(
        img_shape=img_shape,
        res_blocks=residual_blocks,
        c_dim=c_dim,
    )
    cpu_device = torch.device('cpu')
    weights = torch.load(PATH_G, map_location=cpu_device)
    net.load_state_dict(weights)
    net.eval()
    return net
示例#3
0
    def _prepare_generating(self):
        """Build the TensorFlow graph used for generation and restore weights.

        The graph takes two sets of style ids and two sets of char ids, each
        with an interpolation weight (``*_alpha``), blends the corresponding
        latent vectors linearly, and feeds the concatenated style/char vector
        to the generator with ``is_train=False``. The result is stored in
        ``self.generated_imgs``.

        A session is then created in ``self.sess``, all global variables are
        initialised, and the latest checkpoint found in ``self.src_log`` is
        restored. When ``FLAGS.generate_walk`` is set, variables whose name
        contains ``'embedding'`` are excluded from the restore, so the freshly
        sampled style embedding is kept.

        Raises:
            AssertionError: If no checkpoint state is found in
                ``self.src_log``.
        """
        # Latent vector = style part followed by one-hot character part.
        self.z_size = self.style_z_size + self.char_embedding_n

        # NOTE(review): `generator` stays unbound when self.arch is neither
        # 'DCGAN' nor 'ResNet'; the generator call below would then fail.
        if self.arch == 'DCGAN':
            generator = GeneratorDCGAN(img_size=(self.img_width,
                                                 self.img_height),
                                       img_dim=self.img_dim,
                                       z_size=self.z_size,
                                       layer_n=4,
                                       k_size=3,
                                       smallest_hidden_unit_n=64,
                                       is_bn=False)
        elif self.arch == 'ResNet':
            generator = GeneratorResNet(k_size=3, smallest_unit_n=64)

        # Initial values of the style embedding table, drawn uniformly from
        # [-1, 1). The number of rows depends on whether a walk is generated.
        if FLAGS.generate_walk:
            style_embedding_np = np.random.uniform(
                -1, 1, (FLAGS.char_img_n // self.walk_step,
                        self.style_z_size)).astype(np.float32)
        else:
            style_embedding_np = np.random.uniform(
                -1, 1,
                (self.style_ids_n, self.style_z_size)).astype(np.float32)

        with tf.variable_scope('embeddings'):
            style_embedding = tf.Variable(style_embedding_np,
                                          name='style_embedding')
        # Inputs: two endpoints (x, y) and a per-sample blend weight (alpha)
        # for both the style and the character.
        self.style_ids_x = tf.placeholder(tf.int32, (self.batch_size, ),
                                          name='style_ids_x')
        self.style_ids_y = tf.placeholder(tf.int32, (self.batch_size, ),
                                          name='style_ids_y')
        self.style_ids_alpha = tf.placeholder(tf.float32, (self.batch_size, ),
                                              name='style_ids_alpha')
        self.char_ids_x = tf.placeholder(tf.int32, (self.batch_size, ),
                                         name='char_ids_x')
        self.char_ids_y = tf.placeholder(tf.int32, (self.batch_size, ),
                                         name='char_ids_y')
        self.char_ids_alpha = tf.placeholder(tf.float32, (self.batch_size, ),
                                             name='char_ids_alpha')

        # If the sum of the style ids is negative, the style z is sampled from
        # a uniform distribution instead of being looked up in the embedding.
        style_z_x = tf.cond(
            tf.less(tf.reduce_sum(self.style_ids_x),
                    0), lambda: tf.random_uniform(
                        (self.batch_size, self.style_z_size), -1, 1),
            lambda: tf.nn.embedding_lookup(style_embedding, self.style_ids_x))
        style_z_y = tf.cond(
            tf.less(tf.reduce_sum(self.style_ids_y),
                    0), lambda: tf.random_uniform(
                        (self.batch_size, self.style_z_size), -1, 1),
            lambda: tf.nn.embedding_lookup(style_embedding, self.style_ids_y))
        # Linear interpolation between the x and y endpoints: (1 - a) * x + a * y.
        style_z = style_z_x * tf.expand_dims(1. - self.style_ids_alpha, 1) \
            + style_z_y * tf.expand_dims(self.style_ids_alpha, 1)
        char_z_x = tf.one_hot(self.char_ids_x, self.char_embedding_n)
        char_z_y = tf.one_hot(self.char_ids_y, self.char_embedding_n)
        char_z = char_z_x * tf.expand_dims(1. - self.char_ids_alpha, 1) \
            + char_z_y * tf.expand_dims(self.char_ids_alpha, 1)

        z = tf.concat([style_z, char_z], axis=1)

        self.generated_imgs = generator(z, is_train=False)

        # An empty gpu_ids flag disables GPUs for this session; otherwise only
        # the listed devices are made visible.
        if FLAGS.gpu_ids == "":
            sess_config = tf.ConfigProto(device_count={"GPU": 0},
                                         log_device_placement=True)
        else:
            sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
                visible_device_list=FLAGS.gpu_ids))
        self.sess = tf.Session(config=sess_config)
        self.sess.run(tf.global_variables_initializer())

        # Choose which variables are restored from the checkpoint. In walk
        # mode the embedding variables keep their random initialisation.
        if FLAGS.generate_walk:
            var_list = [
                var for var in tf.global_variables()
                if 'embedding' not in var.name
            ]
        else:
            var_list = [var for var in tf.global_variables()]
        pretrained_saver = tf.train.Saver(var_list=var_list)
        checkpoint = tf.train.get_checkpoint_state(self.src_log)
        assert checkpoint, 'cannot get checkpoint: {}'.format(self.src_log)
        pretrained_saver.restore(self.sess, checkpoint.model_checkpoint_path)
示例#4
0
    def _prepare_training(self):
        """Build the TensorFlow training graph and set up the session.

        The mini-batch is split into ``len(FLAGS.gpu_ids.split(','))`` equal
        slices, one per GPU. Each GPU tower builds the generator and the
        discriminator (variables are reused after the first tower), computes
        the critic and generator losses plus a gradient penalty, and produces
        per-tower gradients. The gradients are averaged with
        ``average_gradients`` and applied by two Adam optimizers.

        Training can be resumed: if a checkpoint state is found in
        ``self.dst_log`` it is restored and ``self.epoch_start`` is parsed from
        the checkpoint path; otherwise all variables are initialised and
        ``self.epoch_start`` is 0.

        Raises:
            AssertionError: If ``FLAGS.batch_size`` is smaller than
                ``FLAGS.style_ids_n`` or if the character set is empty.
        """
        assert FLAGS.batch_size >= FLAGS.style_ids_n, 'batch_size must be greater equal than style_ids_n'
        self.gpu_n = len(FLAGS.gpu_ids.split(','))
        self.embedding_chars = set_chars_type(FLAGS.chars_type)
        assert self.embedding_chars != [], 'embedding_chars is empty'
        self.char_embedding_n = len(self.embedding_chars)
        # Latent vector = style part followed by one-hot character part.
        self.z_size = FLAGS.style_z_size + self.char_embedding_n

        with tf.device('/cpu:0'):
            # Set embeddings from uniform distribution
            style_embedding_np = np.random.uniform(
                -1, 1,
                (FLAGS.style_ids_n, FLAGS.style_z_size)).astype(np.float32)
            with tf.variable_scope('embeddings'):
                self.style_embedding = tf.Variable(style_embedding_np,
                                                   name='style_embedding')

            # Graph inputs; every placeholder carries the full mini-batch and
            # is sliced per GPU below.
            self.style_ids = tf.placeholder(tf.int32, (FLAGS.batch_size, ),
                                            name='style_ids')
            self.char_ids = tf.placeholder(tf.int32, (FLAGS.batch_size, ),
                                           name='char_ids')
            self.is_train = tf.placeholder(tf.bool, name='is_train')
            self.real_imgs = tf.placeholder(tf.float32,
                                            (FLAGS.batch_size, FLAGS.img_width,
                                             FLAGS.img_height, FLAGS.img_dim),
                                            name='real_imgs')
            # NOTE(review): self.labels is created here but not used anywhere
            # else in this method.
            self.labels = tf.placeholder(
                tf.float32, (FLAGS.batch_size, self.char_embedding_n),
                name='labels')

            d_opt = tf.train.AdamOptimizer(learning_rate=0.0001,
                                           beta1=0.,
                                           beta2=0.9)
            g_opt = tf.train.AdamOptimizer(learning_rate=0.0001,
                                           beta1=0.,
                                           beta2=0.9)

        # Initialize lists for multi gpu
        fake_imgs = [0] * self.gpu_n
        d_loss = [0] * self.gpu_n
        g_loss = [0] * self.gpu_n

        d_grads = [0] * self.gpu_n
        g_grads = [0] * self.gpu_n

        # Integer division: any remainder of the batch is not assigned to a GPU.
        divided_batch_size = FLAGS.batch_size // self.gpu_n
        # False only for the first tower, which creates the variables; later
        # towers pass True so the networks reuse them.
        is_not_first = False

        # Build graph
        for i in range(self.gpu_n):
            batch_start = i * divided_batch_size
            batch_end = (i + 1) * divided_batch_size
            with tf.device('/gpu:{}'.format(i)):
                # NOTE(review): `generator`/`discriminator` stay unbound when
                # FLAGS.arch is neither 'DCGAN' nor 'ResNet'.
                if FLAGS.arch == 'DCGAN':
                    generator = GeneratorDCGAN(img_size=(FLAGS.img_width,
                                                         FLAGS.img_height),
                                               img_dim=FLAGS.img_dim,
                                               z_size=self.z_size,
                                               layer_n=4,
                                               k_size=3,
                                               smallest_hidden_unit_n=64,
                                               is_bn=False)
                    discriminator = DiscriminatorDCGAN(
                        img_size=(FLAGS.img_width, FLAGS.img_height),
                        img_dim=FLAGS.img_dim,
                        layer_n=4,
                        k_size=3,
                        smallest_hidden_unit_n=64,
                        is_bn=False)
                elif FLAGS.arch == 'ResNet':
                    generator = GeneratorResNet(k_size=3, smallest_unit_n=64)
                    discriminator = DiscriminatorResNet(k_size=3,
                                                        smallest_unit_n=64)

                # If the sum of this slice's style ids is negative, the style
                # z is sampled from a uniform distribution instead of being
                # looked up in the embedding.
                style_z = tf.cond(
                    tf.less(
                        tf.reduce_sum(self.style_ids[batch_start:batch_end]),
                        0), lambda: tf.random_uniform(
                            (divided_batch_size, FLAGS.style_z_size), -1, 1),
                    lambda: tf.nn.embedding_lookup(
                        self.style_embedding, self.style_ids[batch_start:
                                                             batch_end]))
                char_z = tf.one_hot(self.char_ids[batch_start:batch_end],
                                    self.char_embedding_n)
                z = tf.concat([style_z, char_z], axis=1)

                # Generate fake images
                fake_imgs[i] = generator(z,
                                         is_reuse=is_not_first,
                                         is_train=self.is_train)

                # Calculate loss
                d_real = discriminator(self.real_imgs[batch_start:batch_end],
                                       is_reuse=is_not_first,
                                       is_train=self.is_train)
                # The discriminator variables already exist at this point, so
                # every further call reuses them.
                d_fake = discriminator(fake_imgs[i],
                                       is_reuse=True,
                                       is_train=self.is_train)
                d_loss[i] = -(tf.reduce_mean(d_real) - tf.reduce_mean(d_fake))
                g_loss[i] = -tf.reduce_mean(d_fake)

                # Calculate gradient Penalty
                # Random point on the line between each real and fake image.
                epsilon = tf.random_uniform((divided_batch_size, 1, 1, 1),
                                            minval=0.,
                                            maxval=1.)
                interp = self.real_imgs[batch_start:batch_end] + epsilon * (
                    fake_imgs[i] - self.real_imgs[batch_start:batch_end])
                d_interp = discriminator(interp,
                                         is_reuse=True,
                                         is_train=self.is_train)
                grads = tf.gradients(d_interp, [interp])[0]
                # NOTE(review): only the last axis is reduced here, so the
                # norm is taken per pixel rather than per sample - confirm
                # this is intended.
                slopes = tf.sqrt(
                    tf.reduce_sum(tf.square(grads), reduction_indices=[-1]))
                grad_penalty = tf.reduce_mean((slopes - 1.)**2)
                # Penalty weight of 10 is hard-coded.
                d_loss[i] += 10 * grad_penalty

                # Get trainable variables
                # Selection is by substring match on the variable name.
                d_vars = [
                    var for var in tf.trainable_variables()
                    if 'discriminator' in var.name
                ]
                g_vars = [
                    var for var in tf.trainable_variables()
                    if 'generator' in var.name
                ]

                d_grads[i] = d_opt.compute_gradients(d_loss[i],
                                                     var_list=d_vars)
                g_grads[i] = g_opt.compute_gradients(g_loss[i],
                                                     var_list=g_vars)

            is_not_first = True

        with tf.device('/cpu:0'):
            # Stitch the per-GPU outputs back into one batch and combine the
            # per-GPU gradients before applying them.
            self.fake_imgs = tf.concat(fake_imgs, axis=0)
            avg_d_grads = average_gradients(d_grads)
            avg_g_grads = average_gradients(g_grads)
            self.d_train = d_opt.apply_gradients(avg_d_grads)
            self.g_train = g_opt.apply_gradients(avg_g_grads)

        # Calculate summary for tensorboard
        # The logged values are the negated mean of the per-GPU losses.
        tf.summary.scalar('d_loss', -(sum(d_loss) / len(d_loss)))
        tf.summary.scalar('g_loss', -(sum(g_loss) / len(g_loss)))
        self.summary = tf.summary.merge_all()

        # Setup session
        sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            visible_device_list=FLAGS.gpu_ids))
        self.sess = tf.Session(config=sess_config)
        self.saver = tf.train.Saver(max_to_keep=5)

        # If checkpoint is found, restart training
        checkpoint = tf.train.get_checkpoint_state(self.dst_log)
        if checkpoint:
            saver_resume = tf.train.Saver()
            saver_resume.restore(self.sess, checkpoint.model_checkpoint_path)
            # The start epoch is the integer after the last '-' in the
            # checkpoint path.
            self.epoch_start = int(
                checkpoint.model_checkpoint_path.split('-')[-1])
            print('restore ckpt')
        else:
            self.sess.run(tf.global_variables_initializer())
            self.epoch_start = 0

        # Setup writer for tensorboard
        self.writer = tf.summary.FileWriter(self.dst_log)
parser.add_argument(
    "--n_cpu",
    type=int,
    default=8,
    help="number of cpu threads to use during batch generation")
opt = parser.parse_args()

SCALE_FACTOR = opt.scale_factor
MODEL_NAME = opt.model_name
hr_shape = (opt.hr_height, opt.hr_width)

# Per-image metrics collected during the test run.
results = {'Test': {'psnr': [], 'ssim': []}}

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The generator is wrapped in DataParallel before the weights are loaded, so
# the checkpoint keys must match the wrapped module's state dict.
generator = GeneratorResNet()
generator = nn.DataParallel(generator, device_ids=[0, 1, 2])
generator.to(device)

# generator.load_state_dict(torch.load("saved_models/generator_%d_%d.pth" % (4,99)))
# map_location makes the checkpoint loadable on the CPU fallback too: without
# it torch.load restores tensors onto the device they were saved from and
# fails when that device does not exist on this machine.
generator.load_state_dict(
    torch.load("saved_models/" + MODEL_NAME, map_location=device))
generator.eval()

test_dataloader = DataLoader(
    TestImageDataset("../My_dataset/single_channel_100000/%s" %
                     opt.test_dataset_name,
                     hr_shape=hr_shape,
                     scale_factor=opt.scale_factor),  # change
    batch_size=1,
    shuffle=False,
    num_workers=opt.n_cpu,
示例#6
0
def print_network(model, name):
    """Print a model, its label, and its total parameter count.

    The count is the sum of ``numel()`` over ``model.parameters()``.
    Adapted from https://github.com/yunjey/stargan/blob/master/solver.py

    Args:
        model: Object exposing ``parameters()``; each item exposes ``numel()``.
        name: Label printed on the line after the model itself.
    """
    total = sum(param.numel() for param in model.parameters())
    for line in (model, name, "The number of parameters: {}".format(total)):
        print(line)


# Build the generator and the discriminator from the configured image shape
# and number of domain labels.
generator = GeneratorResNet(
    input_shape=img_shape,
    residual_blocks=opt.residual_blocks,
    c_dim=c_dim,
)
discriminator = Discriminator(
    input_shape=img_shape,
    c_dim=c_dim,
)

# Move both networks and the cycle criterion to the GPU when requested.
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_cycle.cuda()

# Optionally dump both architectures together with their parameter counts.
if opt.is_print:
    print_network(generator, 'Generator')
    print_network(discriminator, 'Discriminator')

# Tensor constructor matching the selected device.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

if opt.epoch != 0:
    def __init__(self, opt):
        """Create networks, losses, optimizers, schedulers and data loaders.

        Args:
            opt: Parsed options. This method reads ``model_name``, ``gpu_id``,
                ``n_residual_blocks``, ``large_patch``, ``lr``, ``b1``, ``b2``,
                ``n_epochs``, ``decay_epoch``, ``lambda_id`` and
                ``rotate_degree``.

        Side effects:
            Creates ``saved_models/<model_name>`` and ``images/<model_name>``
            if they do not exist.
        """
        self.config = opt

        # Make output dirs
        os.makedirs('saved_models/%s' % (opt.model_name), exist_ok=True)
        os.makedirs('images/%s' % (opt.model_name), exist_ok=True)

        # Any non-negative gpu_id enables CUDA.
        self.cuda = opt.gpu_id > -1

        # Gs and Ds
        self.G_AB = GeneratorResNet(res_blocks=opt.n_residual_blocks)
        self.G_BA = GeneratorResNet(res_blocks=opt.n_residual_blocks)
        if opt.large_patch:
            self.D_A = LargePatchDiscriminator()
            self.D_B = LargePatchDiscriminator()
        else:
            self.D_A = Discriminator()
            self.D_B = Discriminator()

        # Shape of the adversarial target tensors built for each batch.
        if opt.large_patch:
            self.patch = (1, 64, 64)
        else:
            self.patch = (1, 16, 16)

        # Weight init
        self.G_AB.apply(weights_init_normal)
        self.G_BA.apply(weights_init_normal)
        self.D_A.apply(weights_init_normal)
        self.D_B.apply(weights_init_normal)

        # Loss
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()

        if self.cuda:
            self.G_AB = self.G_AB.cuda()
            self.G_BA = self.G_BA.cuda()
            self.D_A = self.D_A.cuda()
            self.D_B = self.D_B.cuda()
            self.criterion_GAN = self.criterion_GAN.cuda()
            self.criterion_cycle = self.criterion_cycle.cuda()
            self.criterion_identity = self.criterion_identity.cuda()

        # Optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.G_AB.parameters(), self.G_BA.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.b1, opt.b2))
        # Both discriminators must be registered here: with only D_A's
        # parameters, D_B would keep its initial weights for the whole run.
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.D_A.parameters(), self.D_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.b1, opt.b2))
        # Learning rate update schedulers
        self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=LambdaLR(opt.n_epochs, 0, opt.decay_epoch).step)
        self.lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=LambdaLR(opt.n_epochs, 0, opt.decay_epoch).step)
        self.Tensor = torch.cuda.FloatTensor if self.cuda else torch.Tensor
        # Loss weights
        self.lambda_cyc = 10
        self.lambda_id = opt.lambda_id * self.lambda_cyc

        # Buffers of previously generated samples
        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Image transformations
        A_transforms_ = [
            transforms.CenterCrop((178, 178)),
            transforms.Resize((300, 300)),
            transforms.RandomCrop((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=opt.rotate_degree,
                                    fillcolor=(255, 255, 255)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        B_transforms_ = [
            transforms.Resize((360, 360)),
            transforms.RandomCrop((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=opt.rotate_degree,
                                    fillcolor=(255, 255, 255)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
        # Training data loader
        self.train_dataloader = DataLoader(
            ImageDataset("./data/",
                         A_transforms_=A_transforms_,
                         B_transforms_=B_transforms_),
            batch_size=1,
            shuffle=True,
        )
        # Test data loader (uses the same transform lists as training)
        self.val_dataloader = DataLoader(ImageDataset(
            "./data/",
            A_transforms_=A_transforms_,
            B_transforms_=B_transforms_,
            mode='test'),
                                         batch_size=1)
class Trainer():
    """Trainer for two generators (A->B, B->A) and two discriminators.

    The generators are optimised jointly with an adversarial, a cycle and an
    identity loss; both discriminators share one optimizer and are trained
    against real images and replayed generated images.
    """

    def __init__(self, opt):
        """Create networks, losses, optimizers, schedulers and data loaders.

        Args:
            opt: Parsed options. This method reads ``model_name``, ``gpu_id``,
                ``n_residual_blocks``, ``large_patch``, ``lr``, ``b1``, ``b2``,
                ``n_epochs``, ``decay_epoch``, ``lambda_id`` and
                ``rotate_degree``.

        Side effects:
            Creates ``saved_models/<model_name>`` and ``images/<model_name>``
            if they do not exist.
        """
        self.config = opt

        # Make output dirs
        os.makedirs('saved_models/%s' % (opt.model_name), exist_ok=True)
        os.makedirs('images/%s' % (opt.model_name), exist_ok=True)

        # Any non-negative gpu_id enables CUDA.
        self.cuda = opt.gpu_id > -1

        # Gs and Ds
        self.G_AB = GeneratorResNet(res_blocks=opt.n_residual_blocks)
        self.G_BA = GeneratorResNet(res_blocks=opt.n_residual_blocks)
        if opt.large_patch:
            self.D_A = LargePatchDiscriminator()
            self.D_B = LargePatchDiscriminator()
        else:
            self.D_A = Discriminator()
            self.D_B = Discriminator()

        # Shape of the adversarial target tensors built for each batch.
        if opt.large_patch:
            self.patch = (1, 64, 64)
        else:
            self.patch = (1, 16, 16)

        # Weight init
        self.G_AB.apply(weights_init_normal)
        self.G_BA.apply(weights_init_normal)
        self.D_A.apply(weights_init_normal)
        self.D_B.apply(weights_init_normal)

        # Loss
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()

        if self.cuda:
            self.G_AB = self.G_AB.cuda()
            self.G_BA = self.G_BA.cuda()
            self.D_A = self.D_A.cuda()
            self.D_B = self.D_B.cuda()
            self.criterion_GAN = self.criterion_GAN.cuda()
            self.criterion_cycle = self.criterion_cycle.cuda()
            self.criterion_identity = self.criterion_identity.cuda()

        # Optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.G_AB.parameters(), self.G_BA.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.b1, opt.b2))
        # Both discriminators must be registered here: with only D_A's
        # parameters, D_B would keep its initial weights for the whole run.
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.D_A.parameters(), self.D_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.b1, opt.b2))
        # Learning rate update schedulers
        self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=LambdaLR(opt.n_epochs, 0, opt.decay_epoch).step)
        self.lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=LambdaLR(opt.n_epochs, 0, opt.decay_epoch).step)
        self.Tensor = torch.cuda.FloatTensor if self.cuda else torch.Tensor
        # Loss weights
        self.lambda_cyc = 10
        self.lambda_id = opt.lambda_id * self.lambda_cyc

        # Buffers of previously generated samples
        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Image transformations
        A_transforms_ = [
            transforms.CenterCrop((178, 178)),
            transforms.Resize((300, 300)),
            transforms.RandomCrop((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=opt.rotate_degree,
                                    fillcolor=(255, 255, 255)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        B_transforms_ = [
            transforms.Resize((360, 360)),
            transforms.RandomCrop((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=opt.rotate_degree,
                                    fillcolor=(255, 255, 255)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
        # Training data loader
        self.train_dataloader = DataLoader(
            ImageDataset("./data/",
                         A_transforms_=A_transforms_,
                         B_transforms_=B_transforms_),
            batch_size=1,
            shuffle=True,
        )
        # Test data loader (uses the same transform lists as training)
        self.val_dataloader = DataLoader(ImageDataset(
            "./data/",
            A_transforms_=A_transforms_,
            B_transforms_=B_transforms_,
            mode='test'),
                                         batch_size=1)

    def train_epoch(self, epoch):
        """Run one pass over the training loader.

        For every batch the generators are updated first, then both
        discriminators. Progress is written to stdout, a sample grid is saved
        every ``config.sample_interval`` batches, both schedulers are stepped
        at the end of the epoch, and ``G_AB`` is saved every
        ``config.checkpoint_interval`` epochs (``-1`` disables saving).

        Args:
            epoch: Index of the current epoch; used for logging, the ETA
                estimate, the sample file name and the checkpoint file name.
        """
        prev_time = time.time()
        for i, batch in enumerate(self.train_dataloader):

            # Model input
            real_A = Variable(batch['A'].type(self.Tensor))
            real_B = Variable(batch['B'].type(self.Tensor))

            # Adversarial ground truths
            valid = Variable(self.Tensor(np.ones(
                (real_A.size(0), *self.patch))),
                             requires_grad=False)
            fake = Variable(self.Tensor(np.zeros(
                (real_A.size(0), *self.patch))),
                            requires_grad=False)

            #  Train Generators

            self.optimizer_G.zero_grad()

            # GAN loss
            fake_B = self.G_AB(real_A)
            loss_GAN_AB = self.criterion_GAN(self.D_B(fake_B), valid)
            fake_A = self.G_BA(real_B)
            loss_GAN_BA = self.criterion_GAN(self.D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = self.G_BA(fake_B)
            loss_cycle_A = self.criterion_cycle(recov_A, real_A)
            recov_B = self.G_AB(fake_A)
            loss_cycle_B = self.criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Identity loss
            loss_id_A = self.criterion_identity(self.G_BA(real_A), real_A)
            loss_id_B = self.criterion_identity(self.G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # Total loss
            loss_G = loss_GAN + self.lambda_cyc * loss_cycle + self.lambda_id * loss_identity
            loss_G.backward()
            self.optimizer_G.step()

            #  Train Discriminators

            # Also clears the gradients that loss_G.backward() left on the
            # discriminators.
            self.optimizer_D.zero_grad()

            # Discriminator A: real loss, then fake loss on a batch of
            # previously generated samples.
            loss_real = self.criterion_GAN(self.D_A(real_A), valid)
            fake_A_ = self.fake_A_buffer.push_and_pop(fake_A)
            loss_fake = self.criterion_GAN(self.D_A(fake_A_.detach()), fake)
            loss_D_A = (loss_real + loss_fake) / 2
            loss_D_A.backward()

            # Discriminator B
            loss_real = self.criterion_GAN(self.D_B(real_B), valid)
            fake_B_ = self.fake_B_buffer.push_and_pop(fake_B)
            loss_fake = self.criterion_GAN(self.D_B(fake_B_.detach()), fake)
            loss_D_B = (loss_real + loss_fake) / 2
            loss_D_B.backward()

            # The two losses touch disjoint parameter sets, so one step of
            # the shared optimizer updates each discriminator with its own
            # gradients only.
            self.optimizer_D.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # Determine approximate time left
            batches_done = epoch * len(self.train_dataloader) + i
            batches_left = self.config.n_epochs * len(
                self.train_dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (epoch, self.config.n_epochs, i, len(self.train_dataloader),
                   loss_D.item(), loss_G.item(), loss_GAN.item(),
                   loss_cycle.item(), loss_identity.item(), time_left))

            if batches_done % self.config.sample_interval == 0:
                # Sample a picture; no gradients are needed for it.
                imgs = next(iter(self.val_dataloader))
                with torch.no_grad():
                    real_A = Variable(imgs['A'].type(self.Tensor))
                    fake_B = self.G_AB(real_A)
                    real_B = Variable(imgs['B'].type(self.Tensor))
                    fake_A = self.G_BA(real_B)
                img_sample = torch.cat(
                    (real_A.data, fake_B.data, real_B.data, fake_A.data), 0)
                save_image(img_sample,
                           'images/%s/%s.png' %
                           (self.config.model_name, batches_done),
                           nrow=4,
                           normalize=True)

        self.lr_scheduler_G.step()
        self.lr_scheduler_D.step()

        if self.config.checkpoint_interval != -1 and epoch % self.config.checkpoint_interval == 0:
            torch.save(
                self.G_AB.state_dict(), 'saved_models/%s/G_AB_%d.pth' %
                (self.config.model_name, epoch))
示例#9
0
import torch
from models import GeneratorResNet

# Command-line options for translating a single image with a trained G_AB.
parser = argparse.ArgumentParser()
parser.add_argument('--check_point', type=str, default='saved_models/G_AB_10.pth',
                    help='check point from which load trained model')
parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
parser.add_argument('--A_file', type=str, default='test.png', help='path of the data')
parser.add_argument('--img_height', type=int, default=256, help='size of image height')
parser.add_argument('--img_width', type=int, default=256, help='size of image width')
parser.add_argument('--gpu_id', type=int, default=-1, help='GPU id')
opt = parser.parse_args()
# Any non-negative gpu_id enables CUDA; the default (-1) runs on the CPU.
cuda = opt.gpu_id > -1

# Load pretrained model G_AB
G_AB = GeneratorResNet()
if cuda:
    G_AB = G_AB.cuda()
# Read the checkpoint onto the CPU first: without map_location, a checkpoint
# saved from a GPU cannot be loaded on a CPU-only run (the default here).
# load_state_dict then copies the values into the model's own parameters,
# wherever they live.
G_AB.load_state_dict(torch.load(opt.check_point, map_location='cpu'))

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Image transformations
transforms_ = [transforms.Resize((opt.img_height, opt.img_width)),
               transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
img_transformer = transforms.Compose(transforms_)

# Test data

img = img_transformer(Image.open(opt.A_file).convert("RGB"))
示例#10
0
def main():
    """Train a CycleGAN: two generators (A->B, B->A) and two discriminators.

    Hyper-parameters are read from the module-level ``opt`` namespace, and the
    loss modules from the module-level ``criterion_GAN``, ``criterion_cycle``
    and ``criterion_identity`` objects.

    If ``opt.epoch`` is non-zero, the four networks are restored from
    ``saved_models/<dataset_name>/<net>_<epoch>.pth``; otherwise their weights
    are initialised with ``weights_init_normal``. During training the function
    writes preview grids to ``images/<dataset_name>/`` every
    ``opt.sample_interval`` batches and, unless ``opt.checkpoint_interval`` is
    -1, saves all four state dicts every ``opt.checkpoint_interval`` epochs.

    Returns:
        None. All results are written to disk / stdout.
    """
    cuda = torch.cuda.is_available()

    input_shape = (opt.channels, opt.img_height, opt.img_width)

    # Initialize generator and discriminator
    G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
    G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
    D_A = Discriminator(input_shape)
    D_B = Discriminator(input_shape)

    if cuda:
        G_AB = G_AB.cuda()
        G_BA = G_BA.cuda()
        D_A = D_A.cuda()
        D_B = D_B.cuda()
        criterion_GAN.cuda()
        criterion_cycle.cuda()
        criterion_identity.cuda()

    # Checkpoint name -> network, in the order they are loaded, initialised
    # and saved. Keeps the three per-network stanzas below in sync.
    networks = (("G_AB", G_AB), ("G_BA", G_BA), ("D_A", D_A), ("D_B", D_B))

    if opt.epoch != 0:
        # Load pretrained models. Deserialize onto the CPU so a checkpoint
        # written on a GPU machine can be resumed on a CPU-only one;
        # load_state_dict() copies the values onto each network's own device.
        for name, net in networks:
            net.load_state_dict(
                torch.load("saved_models/%s/%s_%d.pth" %
                           (opt.dataset_name, name, opt.epoch),
                           map_location="cpu"))
    else:
        # Initialize weights
        for _, net in networks:
            net.apply(weights_init_normal)

    # Optimizers: one shared by both generators, one per discriminator.
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(),
                                                   G_BA.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.b1, opt.b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.b1, opt.b2))

    # Learning rate update schedulers (stepped once per epoch below).
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Image transformations
    transforms_ = [
        transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opt.img_height, opt.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name,
                     transforms_=transforms_,
                     unaligned=True),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )
    # Test data loader
    val_dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name,
                     transforms_=transforms_,
                     unaligned=True,
                     mode="test"),
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )

    def sample_images(batches_done):
        """Saves a generated sample from the test set"""
        imgs = next(iter(val_dataloader))
        G_AB.eval()
        G_BA.eval()
        # Preview images never need gradients; skipping autograd avoids
        # building (and holding on to) a graph for every sample.
        with torch.no_grad():
            real_A = Variable(imgs["A"].type(Tensor))
            fake_B = G_AB(real_A)
            real_B = Variable(imgs["B"].type(Tensor))
            fake_A = G_BA(real_B)
        # Arrange images along x-axis
        real_A = make_grid(real_A, nrow=5, normalize=True)
        real_B = make_grid(real_B, nrow=5, normalize=True)
        fake_A = make_grid(fake_A, nrow=5, normalize=True)
        fake_B = make_grid(fake_B, nrow=5, normalize=True)
        # Arrange images along y-axis
        image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
        save_image(image_grid,
                   "images/%s/%s.png" % (opt.dataset_name, batches_done),
                   normalize=False)

    # ----------
    #  Training
    # ----------
    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths, sized per batch because the last
            # batch of an epoch may be smaller than opt.batch_size.
            valid = Variable(Tensor(
                np.ones((real_A.size(0), *D_A.output_shape))),
                             requires_grad=False)
            fake = Variable(Tensor(
                np.zeros((real_A.size(0), *D_A.output_shape))),
                            requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            # sample_images() leaves the generators in eval mode, so switch
            # back to train mode on every iteration.
            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples); detach so
            # the discriminator update does not back-propagate into G_BA.
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left, extrapolating from the
            # duration of the batch that just finished.
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                ))

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            for name, net in networks:
                torch.save(
                    net.state_dict(),
                    "saved_models/%s/%s_%d.pth" %
                    (opt.dataset_name, name, epoch))
示例#11
0
                    help='rotate degree')
def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``False``/``1``/``0``.

    ``type=bool`` is not usable with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy. An empty string maps to False,
    which matches what ``bool('')`` returned before.

    Raises:
        ValueError: if the text is not a recognised boolean; argparse reports
            this as an invalid argument value.
    """
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError('expected a boolean value, got %r' % (value,))


parser.add_argument('--lambda_id', type=float, default=0.5, help='lambda_id')
parser.add_argument('--large_patch',
                    type=_str2bool,
                    default=False,
                    help='whether use large patch')
opt = parser.parse_args()

# make output dirs
os.makedirs('saved_models/%s' % (opt.model_name), exist_ok=True)
os.makedirs('images/%s' % (opt.model_name), exist_ok=True)

# Any non-negative GPU id enables CUDA.
cuda = opt.gpu_id > -1

# Gs and Ds
G_AB = GeneratorResNet(res_blocks=opt.n_residual_blocks)
G_BA = GeneratorResNet(res_blocks=opt.n_residual_blocks)
# Both discriminators always use the same architecture.
if opt.large_patch:
    D_A = LargePatchDiscriminator()
    D_B = LargePatchDiscriminator()
else:
    D_A = Discriminator()
    D_B = Discriminator()

# Loss modules: MSE for the adversarial term, L1 for cycle and identity terms.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
 def _define_generator(self, path_to_model):
     """Build a default ``GeneratorResNet`` and load its trained weights.

     Args:
         path_to_model: path of a state-dict checkpoint readable by
             ``torch.load``.

     Returns:
         The generator with the checkpoint's weights loaded.
     """
     gen = GeneratorResNet()
     # Deserialize onto the CPU so a checkpoint saved from a GPU run also
     # loads on a CPU-only machine; load_state_dict() copies the values
     # into the generator's own parameters.
     gen.load_state_dict(torch.load(path_to_model, map_location='cpu'))
     return gen
示例#13
0
    def _prepare_generating(self):
        """Build the TensorFlow graph and session used for generating images.

        Creates the generator for ``self.arch``, wires its input as the
        concatenation of ``self.latent`` and a one-hot encoding of
        ``self.char_ids``, stores the output in ``self.generated_imgs``,
        opens ``self.sess`` and restores pretrained variables from the
        ``result.ckpt-10000`` checkpoint under ``self.src_log``.
        """
        # Generator input width: style part plus one-hot character part.
        self.z_size = self.style_z_size + self.char_embedding_n

        # NOTE(review): if self.arch is neither 'DCGAN' nor 'ResNet',
        # `generator` stays unbound and the call further down raises NameError.
        if self.arch == 'DCGAN':
            generator = GeneratorDCGAN(img_size=(self.img_width,
                                                 self.img_height),
                                       img_dim=self.img_dim,
                                       z_size=self.z_size,
                                       layer_n=4,
                                       k_size=3,
                                       smallest_hidden_unit_n=64,
                                       is_bn=False)
        elif self.arch == 'ResNet':
            generator = GeneratorResNet(k_size=3, smallest_unit_n=64)

        # Random initial style embeddings drawn uniformly between -1 and 1.
        # The number of rows depends on whether a "walk" is being generated.
        if FLAGS.generate_walk:
            style_embedding_np = np.random.uniform(
                -1, 1, (FLAGS.char_img_n // self.walk_step,
                        self.style_z_size)).astype(np.float32)
        else:
            style_embedding_np = np.random.uniform(
                -1, 1,
                (self.style_ids_n, self.style_z_size)).astype(np.float32)

        # NOTE(review): `style_embedding` is not referenced again in the
        # visible part of this method; the style input below comes from
        # self.latent, which is set elsewhere. The variable still exists in
        # the graph, so it takes part in the variable lists built below.
        with tf.variable_scope('embeddings'):
            style_embedding = tf.Variable(style_embedding_np,
                                          name='style_embedding')

        # One integer character id per batch element.
        self.char_ids = tf.placeholder(tf.int32, (self.batch_size, ),
                                       name='char_ids')

        style_z = self.latent
        char_z = tf.one_hot(self.char_ids, self.char_embedding_n)

        # Join style and character codes along the feature axis.
        z = tf.concat([style_z, char_z], axis=1)

        self.generated_imgs = generator(z, is_train=False)

        # An empty gpu_ids flag means CPU only (GPU device count forced to 0);
        # otherwise only the listed GPUs are made visible to the session.
        if FLAGS.gpu_ids == "":
            sess_config = tf.ConfigProto(device_count={"GPU": 0},
                                         log_device_placement=True)
        else:
            sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
                visible_device_list=FLAGS.gpu_ids))
        self.sess = tf.Session(config=sess_config)
        # Initialise everything first; restored variables are overwritten below.
        self.sess.run(tf.global_variables_initializer())

        # When generating a walk, variables with 'embedding' in their name are
        # excluded from the restore and so keep their freshly initialised values.
        if FLAGS.generate_walk:
            var_list = [
                var for var in tf.global_variables()
                if 'embedding' not in var.name
            ]
        else:
            var_list = [var for var in tf.global_variables()]

        pretrained_saver = tf.train.Saver(var_list=var_list)

        # The checkpoint name (and therefore the training step) is hard-coded.
        ckpt_filename = "result.ckpt-10000"
        ckpt_filepath = os.path.join(self.src_log, ckpt_filename)
        pretrained_saver.restore(self.sess, ckpt_filepath)