Example #1
def generate_from_ckpt():
    with tf.device(device):
        noises = tf.random_normal([batch_size, z_dim], mean=0.0, stddev=1.0)
        generate_images = generator(noises)
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    session_config.gpu_options.allow_growth = True
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.6

    if ckpt_dir is not None:
        with tf.Session(config=session_config) as sess:
            saver = tf.train.Saver()
            latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
            if latest_ckpt is not None:
                saver.restore(sess, latest_ckpt)
                images = sess.run(generate_images)
                save_sample(images, [8, 8], './generated_image.jpg')
                print('generated an image:', './generated_image.jpg')
            else:
                print('no checkpoint file found in:', ckpt_dir)
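The save_sample helper itself is not shown in these snippets. A minimal sketch consistent with the call above — assuming an NHWC float batch in [0, 1], a [rows, cols] grid, and an output path, all of which are guesses — might look like:

import numpy as np
from PIL import Image

def save_sample(images, grid, path):
    """Tile a batch of images into a rows x cols montage and save it."""
    rows, cols = grid
    n, h, w, c = images.shape
    canvas = np.zeros((rows * h, cols * w, c), dtype=np.uint8)
    for idx in range(min(n, rows * cols)):
        r, col = divmod(idx, cols)
        canvas[r * h:(r + 1) * h, col * w:(col + 1) * w] = np.clip(
            images[idx] * 255.0, 0, 255).astype(np.uint8)
    Image.fromarray(canvas.squeeze()).save(path)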
Example #2
def generate(config_path,
             iteration,
             trunc,
             debug,
             image_n_frames,
             video_n_frames,
             change_modes,
             inversed=False,
             homography_dir=None,
             separate_files=False,
             num_files=1,
             save_frames=False):
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    model_name = os.path.basename(config_path)[:-len('.yaml')]

    os.makedirs(constants.LOG_DIR, exist_ok=True)
    setup_logger(out_file=os.path.join(constants.LOG_DIR,
                                       'gen_' + model_name + '.log'),
                 stdout_level=logging.DEBUG if debug else logging.INFO)

    gen_model = get_model(model_name=model_name,
                          config=config,
                          iteration=iteration)
    gen_path = os.path.join(constants.GEN_DIR, model_name)
    os.makedirs(gen_path, exist_ok=True)

    generator = gen_model['g_running'].eval()
    code_size = config.get('code_size', constants.DEFAULT_CODE_SIZE)
    alpha = gen_model['alpha']
    step = gen_model['step']
    resolution = gen_model['resolution']
    iteration = gen_model['iteration']

    for mode in change_modes.split(','):
        assert mode in available_modes, mode
        if mode.startswith('noise'):
            style_change_mode = StyleChangeMode.REPEAT
        else:
            style_change_mode = StyleChangeMode.INTERPOLATE

        if mode == 'style':
            noise_change_mode = NoiseChangeMode.FIXED
        elif mode.endswith('homography'):
            noise_change_mode = NoiseChangeMode.HOMOGRAPHY
            assert homography_dir is not None, 'The homography mode needs a path to a homography directory!'
        else:
            noise_change_mode = NoiseChangeMode.SHIFT
        noise_change_modes = [noise_change_mode] * constants.MAX_LAYERS_NUM

        if mode == 'images':
            save_video = False
            save_images = True
        else:
            save_video = True
            save_images = False

        save_dir = os.path.join(gen_path, mode)
        if mode.endswith('homography'):
            save_dir = os.path.join(save_dir, os.path.basename(homography_dir))

        save_sample(generator,
                    alpha,
                    step,
                    code_size,
                    resolution,
                    save_dir=save_dir,
                    name=('inversed_' if inversed else '') +
                    str(iteration + 1).zfill(6),
                    sample_size=constants.SAMPLE_SIZE,
                    truncation_psi=trunc,
                    images_n_frames=image_n_frames,
                    video_n_frames=video_n_frames,
                    save_images=save_images,
                    save_video=save_video,
                    style_change_mode=style_change_mode,
                    noise_change_modes=noise_change_modes,
                    inversed=inversed,
                    homography_dir=homography_dir,
                    separate_files=separate_files,
                    num_files=num_files,
                    save_frames=save_frames)
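A hypothetical invocation of generate, assuming a config at configs/stylegan.yaml and that 'style' and 'noise' are among available_modes (both the path and the mode names are illustrative, inferred from the branches above):

generate(config_path='configs/stylegan.yaml',
         iteration=150000,
         trunc=0.7,
         debug=False,
         image_n_frames=1,
         video_n_frames=32,
         change_modes='style,noise')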
Example #3
def train():
    opt = parse_args()
    cuda = torch.cuda.is_available()

    input_shape = (opt.channels, opt.img_width, opt.img_height)
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    transform = transforms.Compose([
        transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opt.img_height, opt.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Get dataloader
    train_loader = coco_loader(opt, mode='train', transform=transform)
    test_loader = coco_loader(opt, mode='test', transform=transform)

    # Get vgg (eval() so it always acts as a frozen feature extractor)
    vgg = VGGNet().eval()

    # Initialize two generators and the discriminator
    shared_E = Encoder(opt.channels, opt.dim, opt.n_downsample)
    shared_D = Decoder(3, 256, opt.n_upsample)

    G_A = GeneratorA(opt.n_residual, 256, shared_E, shared_D)
    G_B = GeneratorB(opt.n_residual, 256, shared_E, shared_D)

    D_B = Discriminator(input_shape)

    # Initialize weights
    G_A.apply(weights_init_normal)
    G_B.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_pixel = torch.nn.L1Loss()

    if cuda:
        vgg = vgg.cuda()
        G_A = G_A.cuda()
        G_B = G_B.cuda()
        D_B = D_B.cuda()
        criterion_GAN.cuda()
        criterion_pixel.cuda()

    optimizer_G = torch.optim.Adam(itertools.chain(G_A.parameters(),
                                                   G_B.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(D_B.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Compute the style features in advance
    style_img = Variable(load_img(opt.style_img, transform).type(FloatTensor))
    style_feature = vgg(style_img)

    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for batch_i, content_img in enumerate(train_loader):
            content_img = Variable(content_img.type(FloatTensor))

            valid = Variable(FloatTensor(
                np.ones((content_img.size(0), *D_B.output_shape))),
                             requires_grad=False)
            fake = Variable(FloatTensor(
                np.zeros((content_img.size(0), *D_B.output_shape))),
                            requires_grad=False)

            # ---------------------
            #  Train Generators
            # ---------------------

            optimizer_G.zero_grad()

            # The generated image is not de-normalized, so content, style,
            # generated image, and image preprocessing must stay consistent!
            stylized_img = G_A(content_img)

            target_feature = vgg(stylized_img)
            content_feature = vgg(content_img)
            loss_st = opt.lambda_st * vgg.compute_st_loss(
                target_feature, content_feature, style_feature,
                opt.lambda_style)

            reconstructed_img = G_B(stylized_img)
            loss_adv = opt.lambda_adv * criterion_GAN(D_B(reconstructed_img),
                                                      valid)

            loss_G = loss_st + loss_adv
            loss_G.backward(retain_graph=True)
            optimizer_G.step()

            # ----------------------
            #  Train Discriminator
            # ----------------------

            optimizer_D.zero_grad()

            loss_D = criterion_GAN(D_B(content_img), valid) + criterion_GAN(
                D_B(reconstructed_img.detach()), fake)
            loss_D.backward()
            optimizer_D.step()

            # ------------------
            # Log Information
            # ------------------

            batches_done = epoch * len(train_loader) + batch_i
            batches_left = opt.n_epochs * len(train_loader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
                % (epoch, opt.n_epochs, batch_i, len(train_loader),
                   loss_D.item(), loss_G.item(), time_left))

            if batches_done % opt.sample_interval == 0:
                save_sample(opt.style_name, test_loader, batches_done, G_A,
                            G_B, FloatTensor)

            if batches_done % opt.checkpoint_interval == 0:
                torch.save(
                    G_A.state_dict(),
                    "checkpoints/%s/G_A_%d.pth" % (opt.style_name, epoch))
                torch.save(
                    G_B.state_dict(),
                    "checkpoints/%s/G_B_%d.pth" % (opt.style_name, epoch))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

    torch.save(G_A.state_dict(),
               "checkpoints/%s/G_A_done.pth" % opt.style_name)
    torch.save(G_B.state_dict(),
               "checkpoints/%s/G_B_done.pth" % opt.style_name)
    print("Training Process has been Done!")
Example #4
def train():
    with tf.device(device):
        gen_opt, dis_opt, fake_images, gen_loss_summary, dis_loss_summary = build_graph(
        )
    saver = tf.train.Saver()
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    session_config.gpu_options.allow_growth = True
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.8

    with tf.Session(config=session_config) as sess:
        # add tf.train.start_queue_runners, it's important to start queue for tf.train.shuffle_batch
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        iter_start = 0
        latest_ckpt = tf.train.latest_checkpoint(ckpt_dir) if load_model else None
        if latest_ckpt is not None:
            saver.restore(sess, latest_ckpt)
            print('loaded model:', latest_ckpt)
            iter_start = int(latest_ckpt.split('/')[-1].split('-')[-1]) + 1
        else:
            # initialize from scratch; this also covers load_model == False,
            # which previously left the variables uninitialized
            print('init global variables')
            sess.run(tf.global_variables_initializer())
        for iter_count in range(iter_start, max_iter):

            # train discriminator
            if iter_count % 100 == 99:
                _, summary = sess.run([dis_opt, dis_loss_summary])
                summary_writer.add_summary(summary, iter_count)
            else:
                sess.run(dis_opt)

            # train generator; run gen_opt twice so that d_loss does not go
            # to zero (differs from the paper)
            sess.run(gen_opt)
            if iter_count % 100 == 99:
                _, summary = sess.run([gen_opt, gen_loss_summary])
                summary_writer.add_summary(summary, iter_count)
            else:
                sess.run(gen_opt)

            # save sample
            if iter_count % 1000 == 999:
                sample_path = os.path.join(sample_dir, '%d.jpg' % iter_count)
                sample = sess.run(fake_images)
                sample = (sample + 1.0) / 2.0
                save_sample(sample, [4, 4], sample_path)
                print('save sample:', sample_path)

            # save model
            if iter_count % 1000 == 999:
                os.makedirs(ckpt_dir, exist_ok=True)
                ckpt_path = os.path.join(ckpt_dir, "model.ckpt")
                saver.save(sess, ckpt_path, global_step=iter_count)
                print('saved ckpt:', ckpt_path)
        coord.request_stop()
        coord.join(threads)
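As an aside, the per-session GPU options used above are TF1-specific. Under TF2 the rough equivalent of allow_growth would be (a reference sketch, not part of this snippet):

import tensorflow as tf

for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)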
Example #5
    def run(self):
        try:
            # setting variables and constants
            model = self.model
            generator = model.generator.train()
            g_running = model.g_running
            discriminator = model.discriminator
            n_frames_discriminator = model.n_frames_discriminator
            g_optimizer = model.g_optimizer
            d_optimizer = model.d_optimizer
            nfd_optimizer = model.nfd_optimizer
            used_samples = model.used_samples
            step = model.step
            resolution = model.resolution
            iteration = model.iteration

            n_critic = constants.N_CRITIC

            config = self.config
            code_size = config.get('code_size', constants.DEFAULT_CODE_SIZE)
            lr = config.get('lr', constants.LR)
            batch_size = config.get('batch_size', constants.BATCH_SIZE)
            init_size = config.get('init_size', constants.INIT_SIZE)
            n_gen_steps = config.get('n_gen_steps', 1)
            max_size = config['max_size']
            max_iterations = config.get('max_iterations',
                                        constants.MAX_ITERATIONS)
            samples_per_phase = config['samples_per_phase']
            loss_fn = config['loss_fn']

            n_frames_params = config.get('n_frames_params', dict())
            n_frames = n_frames_params.get('n', 1)
            n_frames_loss_coef = n_frames_params.get('loss_coef', 0)
            n_frames_final_freq = n_frames_params.get('final_freq', 0)
            n_frames_decay_duration = n_frames_params.get('decay_duration', 0)
            crop_freq = n_frames_params.get('crop_freq', 0)
            mixing = config.get('mixing', False)

            # getting data
            cur_batch_size = batch_size[resolution]
            images_dataloader = CycleLoader(self.images_dataset,
                                            cur_batch_size, resolution)

            if n_frames_loss_coef > 0:
                n_frames_dataloader = CycleLoader(self.n_frames_dataset,
                                                  cur_batch_size, resolution)
                if crop_freq > 0:
                    n_crops_dataloader = CycleLoader(self.n_crops_dataset,
                                                     cur_batch_size,
                                                     resolution)

            if iteration == 0:
                self.adjust_lr(lr, resolution)

            pbar = tqdm.trange(iteration, max_iterations, initial=iteration)

            requires_grad(generator, False)
            requires_grad(discriminator, True)

            discr_loss_val = 0
            gen_loss_val = 0
            grad_loss_val = 0

            max_step = int(math.log2(max_size)) - 2
            final_progress = False

            for iteration in pbar:
                model.iteration = iteration

                # update alpha, step and resolution
                alpha = min(1, 1 / samples_per_phase * (used_samples + 1))
                if resolution == init_size or final_progress:
                    alpha = 1
                if not final_progress and used_samples > samples_per_phase * 2:
                    LOGGER.debug(f'Used samples: {used_samples}.')
                    used_samples = 0
                    step += 1
                    if step > max_step:
                        step = max_step
                        final_progress = True
                        LOGGER.info('Final progress.')
                    else:
                        alpha = 0
                        LOGGER.info(
                            f'Changing resolution from {resolution} to {resolution * 2}.'
                        )
                    resolution = 4 * 2**step
                    model.step = step
                    model.resolution = resolution
                    model.used_samples = used_samples
                    LOGGER.debug(
                        f'Used samples on saving: {model.used_samples}.')
                    self.save_model(step=step)
                    self.adjust_lr(lr, resolution)

                    # set up loaders
                    cur_batch_size = batch_size[resolution]
                    images_dataloader = CycleLoader(self.images_dataset,
                                                    cur_batch_size, resolution)
                    if n_frames_loss_coef > 0:
                        n_frames_dataloader = CycleLoader(
                            self.n_frames_dataset, cur_batch_size, resolution)
                        if crop_freq > 0:
                            n_crops_dataloader = CycleLoader(
                                self.n_crops_dataset, cur_batch_size,
                                resolution)

                # decide if need to use n_frames on this iteration
                if final_progress or n_frames_decay_duration == 0:
                    n_frames_freq = n_frames_final_freq
                else:
                    n_frames_freq = 0.5 - min(1, used_samples / n_frames_decay_duration) *\
                        (0.5 - n_frames_final_freq)
                n_frames_iteration = random.random() < n_frames_freq
                if n_frames_iteration:
                    cur_discr = n_frames_discriminator
                    cur_dataloader = n_frames_dataloader
                    cur_n_frames = n_frames
                    cur_d_optimizer = nfd_optimizer
                else:
                    cur_discr = discriminator
                    cur_dataloader = images_dataloader
                    cur_n_frames = 1
                    cur_d_optimizer = d_optimizer

                cur_discr.zero_grad()
                real_image = next(cur_dataloader)
                LOGGER.debug(f'n_frames iteration: {n_frames_iteration}')
                LOGGER.debug(f'cur_discr: {type(cur_discr.module)}')
                LOGGER.debug(
                    f'real_image shape {real_image.shape}; resolution {resolution}'
                )

                # discriminator step
                real_predict, real_grad_loss_val = discr_backward_real(
                    cur_discr, loss_fn, real_image, step, alpha)
                if mixing and random.random() < 0.9:
                    num_latents = 2
                else:
                    num_latents = 1
                LOGGER.debug(f'Batch size: {cur_batch_size}')
                latents = get_latents(cur_batch_size, code_size,
                                      2 * num_latents)
                gen_in1 = latents[:num_latents]
                gen_in2 = latents[num_latents:]
                LOGGER.debug(f'Latents shape: {gen_in1[0].shape}')
                fake_image = generator(gen_in1,
                                       step=step,
                                       alpha=alpha,
                                       n_frames=cur_n_frames)

                crop_iteration = False
                if n_frames_iteration:
                    if random.random() < crop_freq:
                        crop_iteration = True
                        fake_image = next(n_crops_dataloader)
                discr_loss_val, fake_grad_loss_val = discr_backward_fake(
                    cur_discr, loss_fn, fake_image, real_image, real_predict,
                    step, alpha, False)
                grad_loss_val = real_grad_loss_val or fake_grad_loss_val
                cur_d_optimizer.step()

                # generator step
                if (iteration + 1) % n_critic == 0:
                    for gen_step in range(n_gen_steps):
                        generator.zero_grad()

                        requires_grad(generator, True)
                        requires_grad(cur_discr, False)

                        fake_image = generator(gen_in2,
                                               step=step,
                                               alpha=alpha,
                                               n_frames=cur_n_frames)
                        LOGGER.debug(
                            f'fake image shape when gen {fake_image.shape}')

                        predict = cur_discr(fake_image, step=step, alpha=alpha)
                        if loss_fn == 'wgan-gp':
                            loss = -predict.mean()
                        elif loss_fn == 'r1':
                            loss = F.softplus(-predict).mean()

                        if n_frames_iteration:
                            loss *= n_frames_loss_coef
                        gen_loss_val = loss.item()

                        loss.backward()
                        g_optimizer.step()
                        LOGGER.debug('generator optimizer step')
                        accumulate(to_model=g_running,
                                   from_model=generator.module)

                        requires_grad(generator, False)
                        requires_grad(cur_discr, True)

                used_samples += real_image.shape[0]
                model.used_samples = used_samples

                if (iteration + 1) % constants.SAMPLE_FREQUENCY == 0:
                    LOGGER.info(
                        f'Saving samples on {iteration + 1} iteration.')
                    save_sample(generator=g_running,
                                alpha=alpha,
                                step=step,
                                code_size=code_size,
                                resolution=resolution,
                                save_dir=os.path.join(self.sample_dir),
                                name=f'{str(iteration + 1).zfill(6)}',
                                sample_size=constants.SAMPLE_SIZE,
                                images_n_frames=n_frames,
                                video_n_frames=32)

                if (iteration + 1) % constants.SAVE_FREQUENCY == 0:
                    self.save_model(iteration=iteration + 1)

                if n_frames_iteration:
                    prefix = 'NF'
                    suffix = 'n_frames'
                else:
                    prefix = ''
                    suffix = 'loss'

                state_msg = f'Size: {resolution}; {prefix}G: {gen_loss_val:.3f}; {prefix}D: {discr_loss_val:.3f}; ' +\
                            f'{prefix}Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}'
                pbar.set_description(state_msg)

                if iteration % constants.LOG_LOSS_FREQUENCY == 0:
                    self.summary_writer.add_scalar('size', resolution,
                                                   iteration)
                    self.summary_writer.add_scalar(f'G/{suffix}', gen_loss_val,
                                                   iteration)
                    self.summary_writer.add_scalar(f'D/{suffix}',
                                                   discr_loss_val, iteration)
                    self.summary_writer.add_scalar(f'Grad/{suffix}',
                                                   grad_loss_val, iteration)
                    self.summary_writer.add_scalar('alpha', alpha, iteration)
                    if n_frames_iteration and crop_freq > 0:
                        if crop_iteration:
                            suffix = 'crop'
                        else:
                            suffix = 'no_crop'
                        self.summary_writer.add_scalar(f'D/{suffix}',
                                                       discr_loss_val,
                                                       iteration)

        except KeyboardInterrupt:
            LOGGER.warning('Interrupted by user')
            self.save_model(iteration=iteration)
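requires_grad and accumulate are standard helpers in progressive/style-based GAN loops like this one. Sketches of the usual definitions (the 0.999 EMA decay is the conventional default, assumed here):

def requires_grad(model, flag=True):
    # toggle gradient computation for every parameter of a module
    for p in model.parameters():
        p.requires_grad = flag

def accumulate(to_model, from_model, decay=0.999):
    # exponential moving average of generator weights (the g_running copy)
    to_params = dict(to_model.named_parameters())
    from_params = dict(from_model.named_parameters())
    for name, param in from_params.items():
        to_params[name].data.mul_(decay).add_(param.data, alpha=1 - decay)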
Example #6
def train():
    opt = parse_args()

    os.makedirs("images/%s" % (opt.dataset), exist_ok=True)
    os.makedirs("checkpoints/%s" % (opt.dataset), exist_ok=True)

    cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # get dataloader
    train_loader = commic2human_loader(opt, mode='train')
    test_loader = commic2human_loader(opt, mode='test')

    # Dimensionality
    input_shape = (opt.channels, opt.img_height, opt.img_width)
    shared_dim = opt.dim * (2**opt.n_downsample)

    # Initialize generator and discriminator
    shared_E = ResidualBlock(in_channels=shared_dim)
    E1 = Encoder(dim=opt.dim,
                 n_downsample=opt.n_downsample,
                 shared_block=shared_E)
    E2 = Encoder(dim=opt.dim,
                 n_downsample=opt.n_downsample,
                 shared_block=shared_E)

    shared_G = ResidualBlock(in_channels=shared_dim)
    G1 = Generator(dim=opt.dim,
                   n_upsample=opt.n_upsample,
                   shared_block=shared_G)
    G2 = Generator(dim=opt.dim,
                   n_upsample=opt.n_upsample,
                   shared_block=shared_G)

    D1 = Discriminator(input_shape)
    D2 = Discriminator(input_shape)

    # Initialize weights
    E1.apply(weights_init_normal)
    E2.apply(weights_init_normal)
    G1.apply(weights_init_normal)
    G2.apply(weights_init_normal)
    D1.apply(weights_init_normal)
    D2.apply(weights_init_normal)

    # Loss function
    adversarial_loss = torch.nn.MSELoss()
    pixel_loss = torch.nn.L1Loss()

    if cuda:
        E1 = E1.cuda()
        E2 = E2.cuda()
        G1 = G1.cuda()
        G2 = G2.cuda()
        D1 = D1.cuda()
        D2 = D2.cuda()
        adversarial_loss = adversarial_loss.cuda()
        pixel_loss = pixel_loss.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(E1.parameters(),
                                                   E2.parameters(),
                                                   G1.parameters(),
                                                   G2.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D1 = torch.optim.Adam(D1.parameters(),
                                    lr=opt.lr,
                                    betas=(opt.b1, opt.b2))
    optimizer_D2 = torch.optim.Adam(D2.parameters(),
                                    lr=opt.lr,
                                    betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(opt.epochs, 0, opt.decay_epoch).step)
    lr_scheduler_D1 = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D1, lr_lambda=LambdaLR(opt.epochs, 0, opt.decay_epoch).step)
    lr_scheduler_D2 = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D2, lr_lambda=LambdaLR(opt.epochs, 0, opt.decay_epoch).step)

    prev_time = time.time()
    for epoch in range(opt.epochs):
        for i, (img_A, img_B) in enumerate(train_loader):

            # Model inputs
            X1 = Variable(img_A.type(FloatTensor))
            X2 = Variable(img_B.type(FloatTensor))

            # Adversarial ground truths
            valid = Variable(FloatTensor(img_A.shape[0],
                                         *D1.output_shape).fill_(1.0),
                             requires_grad=False)
            fake = Variable(FloatTensor(img_A.shape[0],
                                        *D1.output_shape).fill_(0.0),
                            requires_grad=False)

            # -----------------------------
            # Train Encoders and Generators
            # -----------------------------

            optimizer_G.zero_grad()

            # Get shared latent representation
            mu1, Z1 = E1(X1)
            mu2, Z2 = E2(X2)

            # Reconstruct images
            recon_X1 = G1(Z1)
            recon_X2 = G2(Z2)

            # Translate images
            fake_X1 = G1(Z2)
            fake_X2 = G2(Z1)

            # Cycle translation
            mu1_, Z1_ = E1(fake_X1)
            mu2_, Z2_ = E2(fake_X2)
            cycle_X1 = G1(Z2_)
            cycle_X2 = G2(Z1_)

            # Losses for encoder and generator
            id_loss_1 = opt.lambda_id * pixel_loss(recon_X1, X1)
            id_loss_2 = opt.lambda_id * pixel_loss(recon_X2, X2)

            adv_loss_1 = opt.lambda_adv * adversarial_loss(D1(fake_X1), valid)
            adv_loss_2 = opt.lambda_adv * adversarial_loss(D2(fake_X2), valid)

            cyc_loss_1 = opt.lambda_cyc * pixel_loss(cycle_X1, X1)
            cyc_loss_2 = opt.lambda_cyc * pixel_loss(cycle_X2, X2)

            KL_loss_1 = opt.lambda_KL1 * compute_KL(mu1)
            KL_loss_2 = opt.lambda_KL1 * compute_KL(mu2)
            KL_loss_1_ = opt.lambda_KL2 * compute_KL(mu1_)
            KL_loss_2_ = opt.lambda_KL2 * compute_KL(mu2_)

            # total loss for encoder and generator
            G_loss = id_loss_1 + id_loss_2 \
                     + adv_loss_1 + adv_loss_2 \
                     + cyc_loss_1 + cyc_loss_2 + \
                     KL_loss_1 + KL_loss_2 + KL_loss_1_ + KL_loss_2_

            G_loss.backward()
            optimizer_G.step()

            # ----------------------
            # Train Discriminator 1
            # ----------------------

            optimizer_D1.zero_grad()

            D1_loss = adversarial_loss(D1(X1), valid) + adversarial_loss(
                D1(fake_X1.detach()), fake)
            D1_loss.backward()

            optimizer_D1.step()

            # ----------------------
            # Train Discriminator 2
            # ----------------------

            optimizer_D2.zero_grad()

            D2_loss = adversarial_loss(D2(X2), valid) + adversarial_loss(
                D2(fake_X2.detach()), fake)
            D2_loss.backward()

            optimizer_D2.step()

            # ------------------
            # Log Information
            # ------------------

            batches_done = epoch * len(train_loader) + i
            batches_left = opt.epochs * len(train_loader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
                % (epoch, opt.epochs, i, len(train_loader),
                   (D1_loss + D2_loss).item(), G_loss.item(), time_left))

            if batches_done % opt.sample_interval == 0:
                save_sample(opt.dataset, test_loader, batches_done, E1, E2, G1,
                            G2, FloatTensor)

            if batches_done % opt.checkpoint_interval == 0:
                torch.save(E1.state_dict(),
                           "checkpoints/%s/E1_%d.pth" % (opt.dataset, epoch))
                torch.save(E2.state_dict(),
                           "checkpoints/%s/E2_%d.pth" % (opt.dataset, epoch))
                torch.save(G1.state_dict(),
                           "checkpoints/%s/G1_%d.pth" % (opt.dataset, epoch))
                torch.save(G2.state_dict(),
                           "checkpoints/%s/G2_%d.pth" % (opt.dataset, epoch))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D1.step()
        lr_scheduler_D2.step()

    torch.save(shared_E.state_dict(),
               "checkpoints/%s/shared_E_done.pth" % opt.dataset)
    torch.save(shared_G.state_dict(),
               "checkpoints/%s/shared_G_done.pth" % opt.dataset)
    torch.save(E1.state_dict(), "checkpoints/%s/E1_done.pth" % opt.dataset)
    torch.save(E2.state_dict(), "checkpoints/%s/E2_done.pth" % opt.dataset)
    torch.save(G1.state_dict(), "checkpoints/%s/G1_done.pth" % opt.dataset)
    torch.save(G2.state_dict(), "checkpoints/%s/G2_done.pth" % opt.dataset)
    print("Training Process has been Done!")
Example #7
def train():
    os.makedirs("images", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    opt = parse_args()
    cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # Calculate output of image discriminator (PatchGAN)
    patch_h, patch_w = int(opt.img_height / 2**4), int(opt.img_width / 2**4)
    patch = (1, patch_h, patch_w)

    # get dataloader
    train_loader = facades_loader(opt, mode='train')
    val_loader = facades_loader(opt, mode='val')

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Loss function
    adversarial_loss = torch.nn.MSELoss()
    pixelwise_loss = torch.nn.L1Loss()

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()
        pixelwise_loss.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    prev_time = time.time()
    for epoch in range(opt.epochs):
        for i, (img_A, img_B) in enumerate(train_loader):

            # Model inputs
            img_A = Variable(img_A.type(FloatTensor))
            img_B = Variable(img_B.type(FloatTensor))

            # Adversarial ground truths
            valid = Variable(FloatTensor(img_A.shape[0], *patch).fill_(1.0),
                             requires_grad=False)
            fake = Variable(FloatTensor(img_A.shape[0], *patch).fill_(0.0),
                            requires_grad=False)

            # Configure input
            gen_imgs = generator(img_A)

            # ------------------
            # Train Generator
            # ------------------

            optimizer_G.zero_grad()

            # Loss for generator
            g_adv = adversarial_loss(discriminator(gen_imgs, img_A), valid)
            g_pixel = pixelwise_loss(gen_imgs, img_B)

            g_loss = g_adv + opt.lambda_pixel * g_pixel

            # Update parameters
            g_loss.backward()
            optimizer_G.step()

            # ------------------
            # Train Discriminator
            # ------------------

            optimizer_D.zero_grad()

            real_loss = adversarial_loss(discriminator(img_B, img_A), valid)
            fake_loss = adversarial_loss(
                discriminator(gen_imgs.detach(), img_A), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # ------------------
            # Log Information
            # ------------------

            batches_done = epoch * len(train_loader) + i
            batches_left = opt.epochs * len(train_loader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G adv: %f, pixel: %f] ETA: %s"
                % (epoch, opt.epochs, i, len(train_loader), d_loss.item(),
                   g_adv.item(), g_pixel.item(), time_left))

            if batches_done % opt.sample_interval == 0:
                save_sample(val_loader, batches_done, generator, FloatTensor)

            if batches_done % opt.checkpoint_interval == 0:
                torch.save(generator.state_dict(),
                           "checkpoints/generator_%d.pth" % epoch)
                # torch.save(discriminator.state_dict(), "checkpoints/discriminator_%d.pth" % epoch)

    torch.save(generator.state_dict(), "checkpoints/generator_done.pth")
    print("Training Process has been Done!")
Example #8
def train():
    os.makedirs("images", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    opt = parse_args()
    cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # Calculate output of image discriminator (PatchGAN)
    patch_h, patch_w = int(opt.mask_size / 2**3), int(opt.mask_size / 2**3)
    patch = (1, patch_h, patch_w)

    # get dataloader
    train_loader = celeba_loader(opt, mode='train')
    test_loader = celeba_loader(opt, mode='test')

    # Initialize generator and discriminator
    generator = Generator(opt.channels)
    discriminator = Discriminator(opt.channels)

    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Loss function
    adversarial_loss = torch.nn.MSELoss()
    pixelwise_loss = torch.nn.L1Loss()

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()
        pixelwise_loss.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    for epoch in range(opt.epochs):
        for i, (imgs, masked_imgs, masked_parts) in enumerate(train_loader):

            # Adversarial ground truths
            valid = Variable(FloatTensor(imgs.shape[0], *patch).fill_(1.0),
                             requires_grad=False)
            fake = Variable(FloatTensor(imgs.shape[0], *patch).fill_(0.0),
                            requires_grad=False)

            # Configure input
            imgs = Variable(imgs.type(FloatTensor))
            masked_imgs = Variable(masked_imgs.type(FloatTensor))
            masked_parts = Variable(masked_parts.type(FloatTensor))
            gen_parts = generator(masked_imgs)

            # ------------------
            # Train Discriminator
            # ------------------

            optimizer_D.zero_grad()

            # shape of masked_parts and valid [-1, 1, 8, 8]
            real_loss = adversarial_loss(discriminator(masked_parts), valid)
            fake_loss = adversarial_loss(discriminator(gen_parts.detach()),
                                         fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # ------------------
            # Train Generator
            # ------------------

            if i % opt.n_critic == 0:
                optimizer_G.zero_grad()

                # Loss for generator
                g_adv = adversarial_loss(discriminator(gen_parts), valid)
                g_pixel = pixelwise_loss(gen_parts, masked_parts)

                g_loss = 0.001 * g_adv + 0.999 * g_pixel

                # Update parameters
                g_loss.backward()
                optimizer_G.step()

            # ------------------
            # Log Information
            # ------------------

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G adv: %f, pixel: %f]"
                % (epoch, opt.epochs, i, len(train_loader), d_loss.item(),
                   g_adv.item(), g_pixel.item()))

            batches_done = epoch * len(train_loader) + i
            if batches_done % opt.sample_interval == 0:
                save_sample(opt, test_loader, batches_done, generator,
                            FloatTensor)

            if batches_done % opt.checkpoint_interval == 0:
                torch.save(generator.state_dict(),
                           "checkpoints/generator_%d.pth" % epoch)
                # torch.save(discriminator.state_dict(), "checkpoints/discriminator_%d.pth" % epoch)

    torch.save(generator.state_dict(), "checkpoints/generator_done.pth")
    print("Training Process has been Done!")
Example #9
    parser.add_argument('--src_path', type=str, required=True)
    parser.add_argument('--tgt_path', type=str, required=True)
    parser.add_argument('--config_path', type=str, required=True)
    parser.add_argument('--ckpt_path', type=str, required=True)
    parser.add_argument('--output_dir', type=str, default='./outputs')
    args = parser.parse_args()

    params = get_config(args.config_path)

    output_dir = pathlib.Path(args.output_dir) / args.ckpt_path.split('/')[-2]

    if not output_dir.exists():
        output_dir.mkdir(parents=True)

    print('Build model')
    model = module_from_config(params)
    model = model.load_from_checkpoint(args.ckpt_path)
    model.freeze()

    print(model.hparams)

    print('Inference')
    wav = model(args.src_path, args.tgt_path)

    print('Saving')
    src_wav, tgt_wav = get_wav(args.src_path), get_wav(args.tgt_path)
    save_sample(str(output_dir / 'src.wav'), src_wav)
    save_sample(str(output_dir / 'tgt.wav'), tgt_wav)
    save_sample(str(output_dir / 'gen.wav'), wav)
    print('End')
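Here save_sample writes audio rather than images. The helpers are not shown; a plausible stand-in using librosa/soundfile (both the libraries and the 22.05 kHz rate are assumptions):

import librosa
import soundfile as sf

def get_wav(path, sr=22050):
    wav, _ = librosa.load(path, sr=sr)
    return wav

def save_sample(path, wav, sr=22050):
    sf.write(path, wav, sr)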
Example #10
def train():
    opt = parse_args()

    os.makedirs("images/%s" % (opt.dataset), exist_ok=True)
    os.makedirs("checkpoints/%s" % (opt.dataset), exist_ok=True)

    cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # get dataloader
    train_loader = celeba_loader(opt, mode='train')
    val_loader = celeba_loader(opt, mode='val')

    # Dimensionality
    c_dim = len(opt.selected_attrs)

    # Initialize generator and discriminator
    generator = Generator(opt.channels, opt.residual_blocks, c_dim)
    discriminator = Discriminator(opt.channels, opt.img_height, c_dim)

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Loss functions. criterion_cls is used below but never defined in the
    # original snippet; StarGAN's attribute-classification loss is binary
    # cross-entropy with logits, which is assumed here.
    cycle_loss = torch.nn.L1Loss()
    criterion_cls = torch.nn.BCEWithLogitsLoss()

    if cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        cycle_loss.cuda()
        criterion_cls.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    # ------------
    #  Training
    # ------------

    prev_time = time.time()
    for epoch in range(opt.epochs):
        for i, (imgs, labels) in enumerate(train_loader):

            # Model inputs
            imgs = Variable(imgs.type(FloatTensor))
            labels = Variable(labels.type(FloatTensor))

            # Sample target labels as generator input and generate a fake batch of images
            sampled_c = Variable(
                FloatTensor(np.random.randint(0, 2, (imgs.size(0), c_dim))))
            fake_imgs = generator(imgs, sampled_c)

            # ----------------------
            # Train Discriminator
            # ----------------------

            optimizer_D.zero_grad()

            real_validity, pred_cls = discriminator(imgs)
            fake_validity, _ = discriminator(fake_imgs.detach())
            gradient_penalty = compute_gradient_penalty(
                discriminator, imgs.data, fake_imgs.data, FloatTensor)

            d_adv_loss = -torch.mean(real_validity) + torch.mean(
                fake_validity) + opt.lambda_gp * gradient_penalty
            d_cls_loss = criterion_cls(pred_cls, labels)
            D_loss = d_adv_loss + opt.lambda_cls * d_cls_loss

            D_loss.backward()
            optimizer_D.step()

            # -----------------------------
            # Train Generators
            # -----------------------------
            if i % opt.n_critic == 0:
                optimizer_G.zero_grad()
                gen_imgs = generator(imgs, sampled_c)
                recov_imgs = generator(gen_imgs, labels)

                fake_validity, pred_cls = discriminator(gen_imgs)

                g_adv_loss = -torch.mean(fake_validity)
                g_cls_loss = criterion_cls(pred_cls, sampled_c)
                g_rec_loss = cycle_loss(recov_imgs, imgs)
                G_loss = g_adv_loss + opt.lambda_cls * g_cls_loss + opt.lambda_rec * g_rec_loss

                G_loss.backward()
                optimizer_G.step()

                # ------------------
                # Log Information
                # ------------------

                batches_done = epoch * len(train_loader) + i
                batches_left = opt.epochs * len(train_loader) - batches_done
                time_left = datetime.timedelta(seconds=batches_left *
                                               (time.time() - prev_time))
                prev_time = time.time()

                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, aux: %f] [G loss: %f, aux: %f, cycle: %f] ETA: %s"
                    % (epoch, opt.epochs, i, len(train_loader), D_loss.item(),
                       d_cls_loss.item(), G_loss.item(), g_cls_loss.item(),
                       g_rec_loss.item(), time_left))

                if batches_done % opt.sample_interval == 0:
                    save_sample(opt.dataset, val_loader, batches_done,
                                generator, FloatTensor)

                if batches_done % opt.checkpoint_interval == 0:
                    torch.save(
                        generator.state_dict(),
                        "checkpoints/%s/G_%d.pth" % (opt.dataset, epoch))

    torch.save(generator.state_dict(),
               "checkpoints/%s/G_done.pth" % opt.dataset)
    print("Training process complete!")