Example #1
def main():
    # define actor/critic/discriminator net and optimizer
    policy = Policy(discrete_action_sections, discrete_state_sections)
    value = Value()
    discriminator = Discriminator()
    optimizer_policy = torch.optim.Adam(policy.parameters(), lr=args.policy_lr)
    optimizer_value = torch.optim.Adam(value.parameters(), lr=args.value_lr)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.discrim_lr)
    discriminator_criterion = nn.BCELoss()
    writer = SummaryWriter()

    # load expert data
    dataset = ExpertDataSet(args.expert_activities_data_path,
                            args.expert_cost_data_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=args.expert_batch_size,
                                  shuffle=False,
                                  num_workers=1)

    # load models
    # discriminator.load_state_dict(torch.load('./model_pkl/Discriminator_model_3.pkl'))
    # policy.transition_net.load_state_dict(torch.load('./model_pkl/Transition_model_3.pkl'))
    # policy.policy_net.load_state_dict(torch.load('./model_pkl/Policy_model_3.pkl'))
    # value.load_state_dict(torch.load('./model_pkl/Value_model_3.pkl'))

    print('#############  start training  ##############')

    # update discriminator
    num = 0
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        start_time = time.time()
        memory = policy.collect_samples(args.ppo_buffer_size, size=10000)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        continuous_state = torch.stack(
            batch.continuous_state).squeeze(1).detach()
        discrete_action = torch.stack(
            batch.discrete_action).squeeze(1).detach()
        continuous_action = torch.stack(
            batch.continuous_action).squeeze(1).detach()
        next_discrete_state = torch.stack(
            batch.next_discrete_state).squeeze(1).detach()
        next_continuous_state = torch.stack(
            batch.next_continuous_state).squeeze(1).detach()
        old_log_prob = torch.stack(batch.old_log_prob).detach()
        mask = torch.stack(batch.mask).squeeze(1).detach()
        discrete_state = torch.stack(batch.discrete_state).squeeze(1).detach()
        d_loss = torch.empty(0, device=device)
        p_loss = torch.empty(0, device=device)
        v_loss = torch.empty(0, device=device)
        gen_r = torch.empty(0, device=device)
        expert_r = torch.empty(0, device=device)
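        # a single pass over the expert batches per outer epoch; widen range(1) for more discriminator updates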
        for _ in range(1):
            for expert_state_batch, expert_action_batch in data_loader:
                gen_state = torch.cat((discrete_state, continuous_state),
                                      dim=-1)
                gen_action = torch.cat((discrete_action, continuous_action),
                                       dim=-1)
                gen_r = discriminator(gen_state, gen_action)
                expert_r = discriminator(expert_state_batch,
                                         expert_action_batch)
                optimizer_discriminator.zero_grad()
                d_loss = discriminator_criterion(gen_r,
                                                 torch.zeros(gen_r.shape, device=device)) + \
                         discriminator_criterion(expert_r,
                                                 torch.ones(expert_r.shape, device=device))
                total_d_loss = d_loss - 10 * torch.var(gen_r.to(device))
                d_loss.backward()
                # total_d_loss.backward()
                optimizer_discriminator.step()
        writer.add_scalar('d_loss', d_loss, ep)
        # writer.add_scalar('total_d_loss', total_d_loss, ep)
        writer.add_scalar('expert_r', expert_r.mean(), ep)

        # update PPO
        gen_r = discriminator(
            torch.cat((discrete_state, continuous_state), dim=-1),
            torch.cat((discrete_action, continuous_action), dim=-1))
        optimize_iter_num = int(
            math.ceil(discrete_state.shape[0] / args.ppo_mini_batch_size))
        for ppo_ep in range(args.ppo_optim_epoch):
            for i in range(optimize_iter_num):
                num += 1
                index = slice(
                    i * args.ppo_mini_batch_size,
                    min((i + 1) * args.ppo_mini_batch_size,
                        discrete_state.shape[0]))
                discrete_state_batch, continuous_state_batch, discrete_action_batch, continuous_action_batch, \
                old_log_prob_batch, mask_batch, next_discrete_state_batch, next_continuous_state_batch, gen_r_batch = \
                    discrete_state[index], continuous_state[index], discrete_action[index], continuous_action[index], \
                    old_log_prob[index], mask[index], next_discrete_state[index], next_continuous_state[index], gen_r[
                        index]
                v_loss, p_loss = ppo_step(
                    policy, value, optimizer_policy, optimizer_value,
                    discrete_state_batch, continuous_state_batch,
                    discrete_action_batch, continuous_action_batch,
                    next_discrete_state_batch, next_continuous_state_batch,
                    gen_r_batch, old_log_prob_batch, mask_batch,
                    args.ppo_clip_epsilon)
            writer.add_scalar('p_loss', p_loss, num)
            writer.add_scalar('v_loss', v_loss, num)
            writer.add_scalar('gen_r', gen_r.mean(), num)

        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('d_loss', d_loss.item())
        # print('p_loss', p_loss.item())
        # print('v_loss', v_loss.item())
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())

        memory.clear_memory()
        # save models
        torch.save(discriminator.state_dict(),
                   './model_pkl/Discriminator_model_4.pkl')
        torch.save(policy.transition_net.state_dict(),
                   './model_pkl/Transition_model_4.pkl')
        torch.save(policy.policy_net.state_dict(),
                   './model_pkl/Policy_model_4.pkl')
        torch.save(value.state_dict(), './model_pkl/Value_model_4.pkl')
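The snippet assumes an `ExpertDataSet` that yields (state, action) pairs from the two expert files. A minimal sketch of such a dataset is below; the file format, the numpy loading, and which file holds states versus actions are assumptions, not details from the original project.

import numpy as np
import torch
from torch.utils import data

class ExpertDataSet(data.Dataset):
    """Hypothetical expert dataset reconstructed from its call site above."""
    def __init__(self, activities_path, cost_path):
        # assume one row per timestep in both files
        self.states = torch.from_numpy(np.loadtxt(activities_path)).float()
        self.actions = torch.from_numpy(np.loadtxt(cost_path)).float()
        assert len(self.states) == len(self.actions)

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        return self.states[idx], self.actions[idx]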
Example #2
                writer.add_images(f'img_{depth}/const', y, global_idx)

                writer.flush()
        # save after every epoch
        saves = glob.glob(f'logs/{logs_idx}/*.pt')
        if len(saves) >= 10:  # keep only the ten most recent checkpoints
            saves.sort(key=os.path.getmtime)
            os.remove(saves[0])
        if epoch == epoch_size - 1:
            torch.save(
                {
                    'depth': depth + 1,
                    'epoch': 0,
                    'global_idx': global_idx,
                    'generator_state_dict': generator.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'g_optimizer_state_dict': g_optimizer.state_dict(),
                    'd_optimizer_state_dict': d_optimizer.state_dict()
                }, f'logs/{logs_idx}/model_{depth + 1}_{0}.pt')
        else:
            torch.save(
                {
                    'depth': depth,
                    'epoch': epoch + 1,
                    'global_idx': global_idx,
                    'generator_state_dict': generator.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'g_optimizer_state_dict': g_optimizer.state_dict(),
                    'd_optimizer_state_dict': d_optimizer.state_dict()
                }, f'logs/{logs_idx}/model_{depth}_{epoch + 1}.pt')
    # reset startpoint
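Since both branches above write the same schema and only differ in the (depth, epoch) they record, resuming amounts to reading those fields back from the newest file. A minimal sketch under that schema; the helper name and return convention are assumptions:

import glob
import os
import torch

def resume_latest(logs_idx, generator, discriminator, g_optimizer, d_optimizer):
    """Load the most recent checkpoint written by the loop above, if any."""
    saves = glob.glob(f'logs/{logs_idx}/*.pt')
    if not saves:
        return 0, 0, 0  # depth, epoch, global_idx defaults
    ckpt = torch.load(max(saves, key=os.path.getmtime))
    generator.load_state_dict(ckpt['generator_state_dict'])
    discriminator.load_state_dict(ckpt['discriminator_state_dict'])
    g_optimizer.load_state_dict(ckpt['g_optimizer_state_dict'])
    d_optimizer.load_state_dict(ckpt['d_optimizer_state_dict'])
    return ckpt['depth'], ckpt['epoch'], ckpt['global_idx']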
Example #3
        writer.add_scalar('score_fake',
                          torch.mean(score_wrong_logit).cpu().data.numpy(),
                          global_step=global_step)

        # pass


global_step = 1
while data_iter.epoch <= epoch_total:
    # train
    loss, score_real_logit, score_wrong_logit = train_dis()
    # train_select()

    # print(global_step)
    if (global_step % display_step) == 0:
        display(global_step, data_iter.epoch, loss, score_real_logit,
                score_wrong_logit)

    if data_iter.epoch % save_step == 0:
        if data_iter.epoch == 0:
            continue
        if os.path.isfile(
                os.path.join(save_path, 'save-fea-%d' % data_iter.epoch)):
            continue
        torch.save(feature_extractor.state_dict(),
                   os.path.join(save_path, 'save-fea-%d' % data_iter.epoch))
        torch.save(dis.state_dict(),
                   os.path.join(save_path, 'save-dis-%d' % data_iter.epoch))

    global_step += 1
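The `display` helper is not part of the snippet; a plausible reconstruction from its call site is sketched below (the formatting is an assumption):

def display(step, epoch, loss, score_real_logit, score_wrong_logit):
    # float() handles both plain numbers and 0-dim tensors
    print('step %d | epoch %d | d_loss %.4f | real_logit %.4f | fake_logit %.4f'
          % (step, epoch, float(loss),
             float(score_real_logit.mean()),
             float(score_wrong_logit.mean())))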
Example #4
for epoch in tqdm(range(epoch_num)):
    log_loss_G, log_loss_D = train_dcgan(model_G, model_D, params_G, params_D, data_loader)

    print("{}/{} epoch finished".format(epoch + 1, epoch_num))

    with open("loss.log", mode="a") as f:
        f.write("{}/{} epoch G_loss : {}, D_loss : {} \n".format(epoch+1, epoch_num, log_loss_G, log_loss_D))
    print("G_loss : {}, D_loss : {}".format(log_loss_G, log_loss_D))

    with open("loss_G.log", mode="a") as f:
        f.write(str(log_loss_G) + "\n")

    with open("loss_D.log", mode="a") as f:
        f.write(str(log_loss_D) + "\n")   

    # save intermediate models and generated images during training
    if epoch % 3 == 0:
        torch.save(
            model_G.state_dict(),
            "Weight_Generator/" + datetime.datetime.now(timezone('Asia/Tokyo')).strftime("%Y_%m_%d_%H_%M_%S") + "_G_{:03d}.prm".format(epoch),
            pickle_protocol=4)
        torch.save(
            model_D.state_dict(),
            "Weight_Discriminator/" + datetime.datetime.now(timezone('Asia/Tokyo')).strftime("%Y_%m_%d_%H_%M_%S") + "_D_{:03d}.prm".format(epoch),
            pickle_protocol=4)

        generated_img = model_G(check_z)
        save_image(generated_img,
                   "Generated_Image/" + datetime.datetime.now(timezone('Asia/Tokyo')).strftime("%Y_%m_%d_%H_%M_%S") + "{:03d}.jpg".format(epoch))
Example #5
        # save the weights and images every 200th iteration of an epoch
        if i % 200 == 0:
            for j, test_data in enumerate(test_loader):
                gen.eval()

                test_hazy_images, test_clear_images = test_data
                vutils.save_image(test_clear_images,
                                  results_path + 'real_samples.png',
                                  normalize=True)

                fake = gen(test_hazy_images)
                vutils.save_image(
                    fake.data,
                    results_path +
                    'fake_samples_epoch{}_{}_{}.png'.format(e, i, j),
                    normalize=True)
                gen.train()

                #save batches
                torch.save(
                    {
                        'epochs': e,
                        'gen_state_dict': gen.state_dict(),
                        'disc_state_dict': disc.state_dict(),
                        'optimizerG_state_dict': optimizerG.state_dict(),
                        'optimizerD_state_dict': optimizerD.state_dict(),
                        'errorG': errorG,
                        'errorD': errorD
                    }, checkpoint_tar)

                break  # only the first test batch is evaluated and saved
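Restoring is the mirror image of the torch.save call above; a sketch using the same keys (checkpoint_tar, gen, disc, and the optimizers come from the snippet):

ckpt = torch.load(checkpoint_tar)
gen.load_state_dict(ckpt['gen_state_dict'])
disc.load_state_dict(ckpt['disc_state_dict'])
optimizerG.load_state_dict(ckpt['optimizerG_state_dict'])
optimizerD.load_state_dict(ckpt['optimizerD_state_dict'])
start_epoch = ckpt['epochs'] + 1  # resume after the saved epoch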
Example #6
class Trial:
    def __init__(self,
                 data_dir: str = './dataset',
                 log_dir: str = './logs',
                 device: str = "cuda:0",
                 batch_size: int = 2,
                 init_lr: float = 0.5,
                 G_lr: float = 0.0004,
                 D_lr: float = 0.0008,
                 level: str = "O1",
                 patch: bool = False,
                 init_training_epoch: int = 10,
                 train_epoch: int = 10,
                 optim_type: str = "ADAM",
                 pin_memory: bool = True,
                 grad_set_to_none: bool = True):

        # self.config = config
        self.data_dir = data_dir

        self.dataset = Dataset(root=data_dir + "/Shinkai",
                               style_transform=tr.transform,
                               smooth_transform=tr.transform)

        self.pin_memory = pin_memory
        self.batch_size = batch_size

        self.dataloader = DataLoader(self.dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=4,
                                     pin_memory=pin_memory)

        self.device = torch.device(
            device) if torch.cuda.is_available() else torch.device('cpu')

        self.G = Generator().to(self.device)
        self.patch = patch
        if self.patch:
            self.D = PatchDiscriminator().to(self.device)
        else:
            self.D = Discriminator().to(self.device)

        self.init_model_weights()

        self.optimizer_G = GANOptimizer(optim_type,
                                        self.G.parameters(),
                                        lr=G_lr,
                                        betas=(0.5, 0.999),
                                        amsgrad=False)
        self.optimizer_D = GANOptimizer(optim_type,
                                        self.D.parameters(),
                                        lr=D_lr,
                                        betas=(0.5, 0.999),
                                        amsgrad=True)

        self.loss = Loss(device=self.device).to(self.device)

        self.init_lr = init_lr
        self.G_lr = G_lr
        self.D_lr = D_lr
        self.grad_set_to_none = grad_set_to_none

        self.writer = tensorboard.SummaryWriter(log_dir=log_dir)
        self.init_train_epoch = init_training_epoch
        self.train_epoch = train_epoch

        self.init_time = None
        self.level = level

        if self.level != "O0" and device != "cpu":
            self.fp16 = True
            [self.G,
             self.D], [self.optimizer_G, self.optimizer_D
                       ] = amp.initialize([self.G, self.D],
                                          [self.optimizer_G, self.optimizer_D],
                                          opt_level=self.level)
        else:
            self.fp16 = False

    def init_model_weights(self):
        self.G.apply(weights_init)
        self.D.apply(weights_init)

    @classmethod
    def from_config(cls):
        pass

    def init_train(self, con_weight: float = 1.0):

        test_img = self.get_test_image()
        meter = AverageMeter("Loss")
        self.writer.flush()
        lr_scheduler = OneCycleLR(self.optimizer_G,
                                  max_lr=0.9999,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=self.init_train_epoch)

        for g in self.optimizer_G.param_groups:
            g['lr'] = self.init_lr

        for epoch in tqdm(range(self.init_train_epoch)):

            meter.reset()

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # train = transform(test_img).unsqueeze(0)
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)

                generator_output = self.G(train)
                # content_loss = loss.reconstruction_loss(generator_output, train) * con_weight
                content_loss = self.loss.content_loss(generator_output,
                                                      train) * con_weight
                # content_loss = F.mse_loss(train, generator_output) * con_weight
                content_loss.backward()
                self.optimizer_G.step()
                lr_scheduler.step()

                meter.update(content_loss.detach())

            self.writer.add_scalar(f"Loss : {self.init_time}",
                                   meter.sum.item(), epoch)
            self.write_weights(epoch + 1, write_D=False)
            self.eval_image(epoch, f'{self.init_time} reconstructed img',
                            test_img)

        for g in self.optimizer_G.param_groups:
            g['lr'] = self.G_lr

        # self.save_trial(self.init_train_epoch, "init")

    def eval_image(self, epoch: int, caption, img):
        """Feeds in one single image to process and save."""
        self.G.eval()
        styled_test_img = tr.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            styled_test_img = self.G(styled_test_img)
            styled_test_img = styled_test_img.to('cpu').squeeze()
        self.write_image(styled_test_img, caption, epoch + 1)
        self.writer.flush()
        self.G.train()

    def write_image(self,
                    image: torch.Tensor,
                    img_caption: str = "sample_image",
                    step: int = 0):

        image = torch.clip(tr.inv_norm(image).to(torch.float), 0,
                           1)  # [-1, 1] -> [0, 1]
        image *= 255.  # [0, 1] -> [0, 255]
        image = image.permute(1, 2, 0).to(dtype=torch.uint8)
        self.writer.add_image(img_caption, image, step, dataformats='HWC')
        self.writer.flush()

    def write_weights(self, epoch: int, write_D=True, write_G=True):

        if write_D:
            for name, weight in self.D.named_parameters():
                if 'depthwise' in name or 'pointwise' in name:
                    self.writer.add_histogram(
                        f"Discriminator {name} {self.init_time}", weight,
                        epoch)
                    self.writer.add_histogram(
                        f"Discriminator {name}.grad {self.init_time}",
                        weight.grad, epoch)
                    self.writer.flush()

        if write_G:
            for name, weight in self.G.named_parameters():
                self.writer.add_histogram(f"Generator {name} {self.init_time}",
                                          weight, epoch)
                self.writer.add_histogram(
                    f"Generator {name}.grad {self.init_time}", weight.grad,
                    epoch)
                self.writer.flush()

    def train_1(
        self,
        adv_weight: float = 300.,
        con_weight: float = 1.5,
        gra_weight: float = 3.,
        col_weight: float = 10.,
    ):

        test_img_dir = Path(
            self.data_dir).joinpath('test/test_photo256').resolve()
        test_img_dir = random.choice(list(test_img_dir.glob('**/*')))
        test_img = Image.open(test_img_dir)
        self.writer.add_image(f'test image {self.init_time}',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()

        for epoch in tqdm(range(self.train_epoch)):

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):

                self.D.zero_grad()
                style = style.to(self.device)
                smooth = smooth.to(self.device)
                train = train.to(self.device)

                # style image to discriminator(Not Gram Matrix Loss)
                style_loss_value = self.D(style).view(-1)
                generator_output = self.G(train)
                # generated image to discriminator
                real_output = self.D(generator_output.detach()).view(-1)
                # greyscale_output = D(transforms.functional.rgb_to_grayscale(train, num_output_channels=3)).view(-1) #greyscale adversarial loss
                gray_train = tr.inv_gray_transform(train)
                greyscale_output = self.D(gray_train).view(-1)
                smoothed_loss = self.D(smooth).view(-1)  # smoothed image loss
                # loss_D_real = adversarial_loss(output, label)

                dis_adv_loss = adv_weight * (
                    torch.pow(style_loss_value - 1, 2).mean() +
                    torch.pow(real_output, 2).mean())
                dis_gray_loss = torch.pow(greyscale_output, 2).mean()
                dis_edge_loss = torch.pow(smoothed_loss, 2).mean()
                discriminator_loss = dis_adv_loss + dis_gray_loss + dis_edge_loss
                discriminator_loss.backward()
                self.optimizer_D.step()

                if i % 200 == 0 and i != 0:
                    self.writer.add_scalars(
                        f'{self.init_time} Discriminator losses', {
                            'adversarial loss': dis_adv_loss.item(),
                            'grayscale loss': dis_gray_loss.item(),
                            'edge loss': dis_edge_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()

                real_output = self.D(generator_output).view(-1)
                per_loss = self.loss.perceptual_loss(
                    train, generator_output)  # loss for G
                style_loss = self.loss.style_loss(generator_output, style)
                content_loss = self.loss.content_loss(generator_output, train)
                recon_loss = self.loss.reconstruction_loss(
                    generator_output, train)
                tv_loss = self.loss.tv_loss(generator_output)
                # print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                #       % (epoch, num_epoch, i, len(data_loader),
                #          loss_D.item(), loss_G.item(), D_x, D_G_z1, D_G_z2))

                self.G.zero_grad()
                gen_adv_loss = adv_weight * torch.pow(real_output - 1,
                                                      2).mean()
                gen_con_loss = con_weight * content_loss
                gen_sty_loss = gra_weight * style_loss
                gen_rec_loss = col_weight * recon_loss
                gen_per_loss = per_loss
                gen_tv_loss = tv_loss
                generator_loss = gen_adv_loss + gen_con_loss + gen_sty_loss + gen_rec_loss + gen_per_loss
                generator_loss.backward()
                self.optimizer_G.step()

                if i % 200 == 0 and i != 0:

                    self.writer.add_scalars(
                        f'generator losses {self.init_time}', {
                            'adversarial loss': gen_adv_loss.item(),
                            'content loss': gen_con_loss.item(),
                            'style loss': gen_sty_loss.item(),
                            'reconstruction loss': gen_rec_loss.item(),
                            'perceptual loss': gen_per_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()

            self.write_weights(epoch + 1)
            self.eval_image(epoch, f'{self.init_time} style img', test_img)

    def train_2(self,
                adv_weight: float = 1.0,
                threshold: float = 3.,
                G_train_iter: int = 1,
                D_train_iter: int = 1
                ):  # if threshold is 0., set to half of adversarial loss

        test_img_dir = Path(self.data_dir).joinpath('test', 'test_photo256')
        test_img_dir = random.choice(list(test_img_dir.glob('**/*')))
        test_img = Image.open(test_img_dir)

        if self.init_time is None:
            self.init_time = datetime.datetime.now().strftime("%H:%M")

        self.writer.add_image(f'sample_image {self.init_time}',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()

        perception_weight = 0.
        keep_constant = False

        for epoch in tqdm(range(self.train_epoch)):

            total_dis_loss = 0.

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):

                self.D.zero_grad()

                train = train.to(self.device)
                style = style.to(self.device)
                # smooth = smooth.to(device)

                for _ in range(D_train_iter):
                    style_loss_value = self.D(style).view(-1)
                    generator_output = self.G(train)
                    real_output = self.D(generator_output.detach()).view(-1)
                    dis_adv_loss = adv_weight * \
                        (torch.pow(style_loss_value - 1, 2).mean() + torch.pow(real_output, 2).mean())
                    total_dis_loss += dis_adv_loss.item()
                    dis_adv_loss.backward()
                self.optimizer_D.step()

                self.G.zero_grad()
                for _ in range(G_train_iter):
                    generator_output = self.G(train)
                    real_output = self.D(generator_output).view(-1)
                    per_loss = perception_weight * \
                        self.loss.perceptual_loss(train, generator_output)
                    gen_adv_loss = adv_weight * torch.pow(real_output - 1,
                                                          2).mean()
                    gen_loss = gen_adv_loss + per_loss
                    gen_loss.backward()
                self.optimizer_G.step()

                if i % 200 == 0 and i != 0:
                    self.writer.add_scalars(
                        f'generator losses  {self.init_time}', {
                            'adversarial loss': dis_adv_loss.item(),
                            'Generator adversarial loss': gen_adv_loss.item(),
                            'perceptual loss': per_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()

            if total_dis_loss > threshold and not keep_constant:
                perception_weight += 0.05
            else:
                keep_constant = True

            self.writer.add_scalar(
                f'total discriminator loss {self.init_time}', total_dis_loss,
                i + epoch * len(self.dataloader))

            self.write_weights(epoch + 1)
            self.G.eval()

            styled_test_img = tr.transform(test_img).unsqueeze(0).to(
                self.device)
            with torch.no_grad():
                styled_test_img = self.G(styled_test_img)

            styled_test_img = styled_test_img.to('cpu').squeeze()
            self.write_image(styled_test_img, f'styled image {self.init_time}',
                             epoch + 1)

            self.G.train()

    def __call__(self):
        self.init_train()
        self.train_1()

    def save_trial(self, epoch: int, train_type: str):
        save_dir = Path(f"{train_type}_{self.level}.pt")
        training_details = {
            "epoch": epoch,
            "gen": {
                "gen_state_dict": self.G.state_dict(),
                "optim_G_state_dict": self.optimizer_G.state_dict()
            },
            "dis": {
                "dis_state_dict": self.D.state_dict(),
                "optim_D_state_dict": self.optimizer_D.state_dict()
            }
        }
        if self.fp16:
            training_details["amp"] = amp.state_dict()

        torch.save(training_details, save_dir.as_posix())

    def load_trial(self, dir: Path):
        assert dir.is_file(), "No such file"
        assert dir.suffix == ".pt", "Filetype not compatible"
        state_dicts = torch.load(dir.as_posix())
        self.G.load_state_dict(state_dicts["gen"]["gen_state_dict"])
        self.optimizer_G.load_state_dict(
            state_dicts["gen"]["optim_G_state_dict"])
        self.D.load_state_dict(state_dicts["dis"]["dis_state_dict"])
        self.optimizer_D.load_state_dict(
            state_dicts["dis"]["optim_D_state_dict"])
        if self.fp16:
            amp.load_state_dict(state_dicts["amp"])
        typer.echo("Loaded Weights")

    def Generator_NOGAN(self,
                        epochs: int = 1,
                        style_weight: float = 20.,
                        content_weight: float = 1.2,
                        recon_weight: float = 10.,
                        tv_weight: float = 1e-6,
                        loss: List[str] = ['content_loss']):
        """Training Generator in NOGAN manner (Feature Loss only)."""
        for g in self.optimizer_G.param_groups:
            g['lr'] = self.G_lr
        test_img = self.get_test_image()
        max_lr = self.G_lr * 10.

        lr_scheduler = OneCycleLR(self.optimizer_G,
                                  max_lr=max_lr,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=epochs)

        meter = LossMeters(*loss)
        total_loss_arr = np.array([])

        for epoch in tqdm(range(epochs)):

            total_losses = 0
            meter.reset()

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # train = transform(test_img).unsqueeze(0)
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)

                generator_output = self.G(train)
                if 'style_loss' in loss:
                    style = style.to(self.device)
                    style_loss = self.loss.style_loss(generator_output,
                                                      style) * style_weight
                else:
                    style_loss = 0.

                if 'content_loss' in loss:
                    content_loss = self.loss.content_loss(
                        generator_output, train) * content_weight
                else:
                    content_loss = 0.

                if 'recon_loss' in loss:
                    recon_loss = self.loss.reconstruction_loss(
                        generator_output, train) * recon_weight
                else:
                    recon_loss = 0.

                if 'tv_loss' in loss:
                    tv_loss = self.loss.tv_loss(generator_output) * tv_weight
                else:
                    tv_loss = 0.

                total_loss = content_loss + tv_loss + recon_loss + style_loss
                if self.fp16:
                    with amp.scale_loss(total_loss,
                                        self.optimizer_G) as scaled_loss:
                        scaled_loss.backward()
                else:
                    total_loss.backward()

                self.optimizer_G.step()
                lr_scheduler.step()
                total_losses += total_loss.detach()
                loss_dict = {
                    'content_loss': content_loss,
                    'style_loss': style_loss,
                    'recon_loss': recon_loss,
                    'tv_loss': tv_loss
                }

                losses = [loss_dict[loss_type].detach() for loss_type in loss]
                meter.update(*losses)

            total_loss_arr = np.append(total_loss_arr, total_losses.item())
            self.writer.add_scalars(f'{self.init_time} NOGAN generator losses',
                                    meter.as_dict('sum'), epoch)

            self.write_weights(epoch + 1, write_D=False)
            self.eval_image(epoch, f'{self.init_time} reconstructed img',
                            test_img)
            if epoch > 2:
                fig = plt.figure(figsize=(8, 8))
                X = np.arange(len(total_loss_arr))
                Y = np.gradient(total_loss_arr)
                plt.plot(X, Y)
                thresh = -1.0
                plt.axhline(thresh, c='r')
                plt.title(f"{self.init_time}")
                self.writer.add_figure(f"{self.init_time}", fig, epoch)
                if Y[-1] > thresh:
                    break

        self.save_trial(epoch, f'G_NG_{self.init_time}')

    def Discriminator_NOGAN(
            self,
            epochs: int = 3,
            adv_weight: float = 1.0,
            edge_weight: float = 1.0,
            loss: List[str] = ['real_adv_loss', 'fake_adv_loss', 'gray_loss']):
        """https://discuss.pytorch.org/t/scheduling-batch-size-in-dataloader/46443/2"""

        for g in self.optimizer_D.param_groups:
            g['lr'] = self.D_lr

        max_lr = self.D_lr * 10.
        lr_scheduler = OneCycleLR(self.optimizer_D,
                                  max_lr=max_lr,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=epochs)
        meter = LossMeters(*loss)
        total_loss_arr = np.array([])
        if self.init_time is None:
            self.init_time = datetime.datetime.now().strftime("%H:%M")

        for epoch in tqdm(range(epochs)):

            meter.reset()
            total_losses = 0.  # accumulated per epoch for the early-stop check below

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # train = transform(test_img).unsqueeze(0)
                self.D.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                style = style.to(self.device)

                generator_output = self.G(train)
                real_adv_loss = self.D(style).view(-1)
                fake_adv_loss = self.D(generator_output.detach()).view(-1)
                real_adv_loss = torch.pow(real_adv_loss - 1,
                                          2).mean() * 1.7 * adv_weight
                fake_adv_loss = torch.pow(fake_adv_loss,
                                          2).mean() * 1.7 * adv_weight
                gray_train = tr.inv_gray_transform(style)
                greyscale_output = self.D(gray_train).view(-1)
                gray_loss = torch.pow(greyscale_output,
                                      2).mean() * 1.7 * adv_weight
                "According to AnimeGANv2 implementation, every loss is scaled by individual weights and then scaled with adv_weight"
                "https://github.com/TachibanaYoshino/AnimeGANv2/blob/5946b6afcca5fc28518b75a763c0f561ff5ce3d6/tools/ops.py#L217"
                total_loss = real_adv_loss + fake_adv_loss + gray_loss
                if self.fp16:
                    with amp.scale_loss(total_loss,
                                        self.optimizer_D) as scaled_loss:
                        scaled_loss.backward()
                else:
                    total_loss.backward()
                self.optimizer_D.step()
                lr_scheduler.step()
                total_losses = total_losses + total_loss.detach()

                loss_dict = {
                    'real_adv_loss': real_adv_loss,
                    'fake_adv_loss': fake_adv_loss,
                    'gray_loss': gray_loss
                }

                losses = [loss_dict[loss_type].detach() for loss_type in loss]
                meter.update(*losses)

            self.writer.add_scalars(
                f'{self.init_time} NOGAN discriminator loss',
                meter.as_dict('sum'), epoch)
            self.writer.flush()
            total_loss_arr = np.append(total_loss_arr, float(total_losses))
            if epoch > 2:
                fig = plt.figure(figsize=(8, 8))
                X = np.arange(len(total_loss_arr))
                Y = np.gradient(total_loss_arr)
                plt.plot(X, Y)
                thresh = -1.0
                plt.axhline(thresh, c='r')
                plt.title(f"{self.init_time}")
                self.writer.add_figure(f"{self.init_time}", fig, epoch)
                if Y[-1] > thresh:
                    break

    def GAN_NOGAN(
        self,
        epochs: int = 1,
        GAN_G_lr: float = 0.00008,
        GAN_D_lr: float = 0.000016,
        D_loss: List[str] = [
            "real_adv_loss", "fake_adv_loss", "gray_loss", "edge_loss"
        ],
        adv_weight: float = 300.,
        edge_weight: float = 0.1,
        G_loss: List[str] = [
            "adv_loss", "content_loss", "style_loss", "recon_loss"
        ],
        style_weight: float = 20.,
        content_weight: float = 1.2,
        recon_weight: float = 10.,
        tv_weight: float = 1e-6,
    ):

        test_img = self.get_test_image()
        dis_meter = LossMeters(*D_loss)
        gen_meter = LossMeters(*G_loss)

        for g in self.optimizer_G.param_groups:
            g['lr'] = GAN_G_lr

        for g in self.optimizer_D.param_groups:
            g['lr'] = GAN_D_lr

        update_duration = len(self.dataloader) // 20

        for epoch in tqdm(range(epochs)):

            G_loss_arr = np.array([])
            dis_meter.reset()
            count = 0

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                self.D.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                style = style.to(self.device)
                smooth = smooth.to(self.device)

                generator_output = self.G(train)
                real_adv_loss = self.D(style).view(-1)
                fake_adv_loss = self.D(generator_output.detach()).view(-1)
                G_adv_loss = self.D(generator_output).view(-1)
                gray_train = tr.inv_gray_transform(style)
                grayscale_output = self.D(gray_train).view(-1)
                gray_smooth_data = tr.inv_gray_transform(smooth)
                smoothed_output = self.D(gray_smooth_data).view(-1)  # grayscale smoothed images feed the edge loss

                real_adv_loss = torch.square(real_adv_loss -
                                             1.).mean() * 1.7 * adv_weight
                fake_adv_loss = torch.square(
                    fake_adv_loss).mean() * 1.7 * adv_weight
                gray_loss = torch.square(
                    grayscale_output).mean() * 1.7 * adv_weight
                edge_loss = torch.square(
                    smoothed_output).mean() * 1.0 * adv_weight

                total_D_loss = real_adv_loss + fake_adv_loss + gray_loss + edge_loss
                total_D_loss.backward()
                self.optimizer_D.step()

                D_loss_dict = {
                    'real_adv_loss': real_adv_loss,
                    'fake_adv_loss': fake_adv_loss,
                    'gray_loss': gray_loss,
                    'edge_loss': edge_loss
                }

                loss = [value.detach() for value in D_loss_dict.values()]  # detach so the meter holds no graph references

                dis_meter.update(*loss)

                if i % update_duration == 0 and i != 0:
                    self.writer.add_scalars(f'{self.init_time} NOGAN Dis loss',
                                            dis_meter.as_dict('val'),
                                            i + epoch * len(self.dataloader))
                    self.writer.flush()

                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                G_adv_loss = torch.square(G_adv_loss - 1.).mean() * adv_weight

                if 'style_loss' in G_loss:
                    style_loss = self.loss.style_loss(generator_output,
                                                      style) * style_weight
                else:
                    style_loss = 0.

                if 'content_loss' in G_loss:
                    content_loss = self.loss.content_loss(
                        generator_output, train) * content_weight
                else:
                    content_loss = 0.

                if 'recon_loss' in G_loss:
                    recon_loss = self.loss.reconstruction_loss(
                        generator_output, train) * recon_weight
                else:
                    recon_loss = 0.

                if 'tv_loss' in G_loss:
                    tv_loss = self.loss.tv_loss(generator_output) * tv_weight
                else:
                    tv_loss = 0.

                total_G_loss = G_adv_loss + content_loss + tv_loss + recon_loss + style_loss
                total_G_loss.backward()
                self.optimizer_G.step()

                G_loss_dict = {
                    'adv_loss': G_adv_loss,
                    'content_loss': content_loss,
                    'style_loss': style_loss,
                    'recon_loss': recon_loss,
                    'tv_loss': tv_loss
                }

                losses = [
                    G_loss_dict[loss_type].detach() for loss_type in G_loss
                ]
                gen_meter.update(*losses)

                if i % update_duration == 0 and i != 0:
                    self.writer.add_scalars(f'{self.init_time} NOGAN Gen loss',
                                            gen_meter.as_dict('val'),
                                            i + epoch * len(self.dataloader))
                    self.writer.flush()
                    G_loss_arr = np.append(G_loss_arr, G_adv_loss.item())
                    self.eval_image(i + epoch * len(self.dataloader),
                                    f'{self.init_time} reconstructed img',
                                    test_img)

        self.save_trial(epoch, f'GAN_NG_{self.init_time}')

    def get_test_image(self):
        """Get random test image."""
        test_img_dir = Path(self.data_dir).joinpath('test/test_photo256')
        test_img_dir = random.choice(list(test_img_dir.glob('**/*')))
        test_img = Image.open(test_img_dir)
        self.init_time = datetime.datetime.now().strftime("%H:%M")
        self.writer.add_image(f'{self.init_time} sample_image',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()
        return test_img
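Example #6 leans on AverageMeter and LossMeters helpers that are not shown. The stand-ins below are reconstructed purely from the call sites (reset(), update(*values), a .sum that supports .item(), and as_dict('val'|'sum')); the real implementations may differ.

import torch

class AverageMeter:
    """Tracks the latest value and running sum of one loss term."""
    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        self.val = None
        self.sum = None
        self.count = 0

    def update(self, value, n=1):
        self.val = value
        self.sum = value * n if self.sum is None else self.sum + value * n
        self.count += n

class LossMeters:
    """One AverageMeter per named loss, updated positionally."""
    def __init__(self, *names):
        self.meters = {name: AverageMeter(name) for name in names}

    def reset(self):
        for meter in self.meters.values():
            meter.reset()

    def update(self, *values):
        for meter, value in zip(self.meters.values(), values):
            meter.update(value)

    def as_dict(self, field='val'):
        # add_scalars wants plain floats; float() also accepts 0-dim tensors
        return {name: float(getattr(meter, field))
                for name, meter in self.meters.items()}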
Example #7
                           torch.sigmoid(logit_fake)).data.cpu().numpy()[0],
                       torch.mean(reward).data.cpu().numpy()[0]))

                # w
                if writer is not None:
                    d = {
                        'loss_g': loss_g.data,
                        'loss_d': loss_d.data,
                        'logit_real': torch.mean(logit_real).data,
                        'real': torch.mean(torch.sigmoid(logit_real)).data,
                        'logit_fake': torch.mean(logit_fake).data,
                        'fake': torch.mean(torch.sigmoid(logit_fake)).data,
                        'reward': torch.mean(reward).data
                    }
                    # writer.add_scalars('print', d, epoch_done * K * M + subset_idx * M + dist_idx)
                    writer.add_scalar(
                        'a',
                        loss_g.data.cpu().numpy(),
                        epoch_done * K * M + subset_idx * M + dist_idx)

        # save each subset
        torch.save(
            dis.state_dict(),
            os.path.join(save_path,
                         'save-dis_%d-%d' % (epoch_done, subset_idx)))
        torch.save(
            selector.state_dict(),
            os.path.join(save_path,
                         'save-sel_%d-%d' % (epoch_done, subset_idx)))
        print('Saved! Epoch: %d, subset_idx: %d' % (epoch_done, subset_idx))
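The .data.cpu().numpy()[0] pattern in this example dates from pre-0.4 PyTorch, where scalars were 1-element tensors; on current versions the same extraction is simply .item(). A sketch of the writer call in modern form (names follow the snippet):

if writer is not None:
    step = epoch_done * K * M + subset_idx * M + dist_idx
    writer.add_scalar('a', loss_g.item(), step)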
Example #8
class Trainer(nn.Module):
    def __init__(self, model_dir, g_optimizer, d_optimizer, lr, warmup,
                 max_iters):
        super().__init__()
        self.model_dir = model_dir
        if not os.path.exists(f'checkpoints/{model_dir}'):
            os.makedirs(f'checkpoints/{model_dir}')
        self.logs_dir = f'checkpoints/{model_dir}/logs'
        if not os.path.exists(self.logs_dir):
            os.makedirs(self.logs_dir)
        self.writer = SummaryWriter(self.logs_dir)

        self.arcface = ArcFaceNet(50, 0.6, 'ir_se').cuda()
        self.arcface.eval()
        self.arcface.load_state_dict(torch.load(
            'checkpoints/model_ir_se50.pth', map_location='cuda'),
                                     strict=False)

        self.mobiface = MobileFaceNet(512).cuda()
        self.mobiface.eval()
        self.mobiface.load_state_dict(torch.load(
            'checkpoints/mobilefacenet.pth', map_location='cuda'),
                                      strict=False)

        self.generator = Generator().cuda()
        self.discriminator = Discriminator().cuda()

        self.adversarial_weight = 1
        self.src_id_weight = 5
        self.tgt_id_weight = 1
        self.attributes_weight = 10
        self.reconstruction_weight = 10

        self.lr = lr
        self.warmup = warmup
        self.g_optimizer = g_optimizer(self.generator.parameters(),
                                       lr=lr,
                                       betas=(0, 0.999))
        self.d_optimizer = d_optimizer(self.discriminator.parameters(),
                                       lr=lr,
                                       betas=(0, 0.999))

        self.generator, self.g_optimizer = amp.initialize(self.generator,
                                                          self.g_optimizer,
                                                          opt_level="O1")
        self.discriminator, self.d_optimizer = amp.initialize(
            self.discriminator, self.d_optimizer, opt_level="O1")

        self._iter = nn.Parameter(torch.tensor(1), requires_grad=False)
        self.max_iters = max_iters

        if torch.cuda.is_available():
            self.cuda()

    @property
    def iter(self):
        return self._iter.item()

    @property
    def device(self):
        return next(self.parameters()).device

    def adapt(self, args):
        device = self.device
        return [arg.to(device) for arg in args]

    def train_loop(self, dataloaders, eval_every, generate_every, save_every):
        for batch in tqdm(dataloaders['train']):
            self._iter.add_(1)  # in-place increment of the persistent iteration counter
            # generator step
            # if self.iter % 2 == 0:
            # self.adjust_lr(self.g_optimizer)
            g_losses = self.g_step(self.adapt(batch))
            g_stats = self.get_opt_stats(self.g_optimizer, type='generator')
            self.write_logs(losses=g_losses, stats=g_stats, type='generator')

            # #discriminator step
            # if self.iter % 2 == 1:
            # self.adjust_lr(self.d_optimizer)
            d_losses = self.d_step(self.adapt(batch))
            d_stats = self.get_opt_stats(self.d_optimizer,
                                         type='discriminator')
            self.write_logs(losses=d_losses,
                            stats=d_stats,
                            type='discriminator')

            if self.iter % eval_every == 0:
                discriminator_acc = self.evaluate_discriminator_accuracy(
                    dataloaders['val'])
                identification_acc = self.evaluate_identification_similarity(
                    dataloaders['val'])
                metrics = {**discriminator_acc, **identification_acc}
                self.write_logs(metrics=metrics)

            if self.iter % generate_every == 0:
                self.generate(*self.adapt(batch))

            if self.iter % save_every == 0:
                self.save_discriminator()
                self.save_generator()

    def g_step(self, batch):
        self.generator.train()
        self.g_optimizer.zero_grad()
        L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator = self.g_loss(
            *batch)
        with amp.scale_loss(L_generator, self.g_optimizer) as scaled_loss:
            scaled_loss.backward()
        self.g_optimizer.step()

        losses = {
            'adv': L_adv.item(),
            'src_id': L_src_id.item(),
            'tgt_id': L_tgt_id.item(),
            'attributes': L_attr.item(),
            'reconstruction': L_rec.item(),
            'total_loss': L_generator.item()
        }
        return losses

    def d_step(self, batch):
        self.discriminator.train()
        self.d_optimizer.zero_grad()
        L_fake, L_real, L_discriminator = self.d_loss(*batch)
        with amp.scale_loss(L_discriminator, self.d_optimizer) as scaled_loss:
            scaled_loss.backward()
        self.d_optimizer.step()

        losses = {
            'hinge_fake': L_fake.item(),
            'hinge_real': L_real.item(),
            'total_loss': L_discriminator.item()
        }
        return losses

    def g_loss(self, Xs, Xt, same_person):
        with torch.no_grad():
            src_embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
            tgt_embed = self.arcface(
                F.interpolate(Xt[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))

        Y_hat, Xt_attr = self.generator(Xt, src_embed, return_attributes=True)

        Di = self.discriminator(Y_hat)

        L_adv = 0
        for di in Di:
            L_adv += hinge_loss(di[0], True)

        fake_embed = self.arcface(
            F.interpolate(Y_hat[:, :, 19:237, 19:237], [112, 112],
                          mode='bilinear',
                          align_corners=True))
        L_src_id = (
            1 - torch.cosine_similarity(src_embed, fake_embed, dim=1)).mean()
        L_tgt_id = (
            1 - torch.cosine_similarity(tgt_embed, fake_embed, dim=1)).mean()

        batch_size = Xs.shape[0]
        Y_hat_attr = self.generator.get_attr(Y_hat)
        L_attr = 0
        for i in range(len(Xt_attr)):
            L_attr += torch.mean(torch.pow(Xt_attr[i] - Y_hat_attr[i],
                                           2).reshape(batch_size, -1),
                                 dim=1).mean()
        L_attr /= 2.0

        L_rec = torch.sum(
            0.5 * torch.mean(torch.pow(Y_hat - Xt, 2).reshape(batch_size, -1),
                             dim=1) * same_person) / (same_person.sum() + 1e-6)
        L_generator = (self.adversarial_weight *
                       L_adv) + (self.src_id_weight * L_src_id) + (
                           self.tgt_id_weight *
                           L_tgt_id) + (self.attributes_weight * L_attr) + (
                               self.reconstruction_weight * L_rec)
        return L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator

    def d_loss(self, Xs, Xt, same_person):
        with torch.no_grad():
            src_embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
        Y_hat = self.generator(Xt, src_embed, return_attributes=False)

        fake_D = self.discriminator(Y_hat.detach())
        L_fake = 0
        for di in fake_D:
            L_fake += hinge_loss(di[0], False)
        real_D = self.discriminator(Xs)
        L_real = 0
        for di in real_D:
            L_real += hinge_loss(di[0], True)

        L_discriminator = 0.5 * (L_real + L_fake)
        return L_fake, L_real, L_discriminator

    def evaluate_discriminator_accuracy(self, val_dataloader):
        real_acc = 0
        fake_acc = 0
        self.generator.eval()
        self.discriminator.eval()
        for batch in tqdm(val_dataloader):
            Xs, Xt, _ = self.adapt(batch)

            with torch.no_grad():
                embed = self.arcface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                Y_hat = self.generator(Xt, embed, return_attributes=False)
                fake_D = self.discriminator(Y_hat)
                real_D = self.discriminator(Xs)

            fake_multiscale_acc = 0
            for di in fake_D:
                fake_multiscale_acc += torch.mean((di[0] < 0).float())
            fake_acc += fake_multiscale_acc / len(fake_D)

            real_multiscale_acc = 0
            for di in real_D:
                real_multiscale_acc += torch.mean((di[0] > 0).float())
            real_acc += real_multiscale_acc / len(real_D)

        self.generator.train()
        self.discriminator.train()

        metrics = {
            'fake_acc': 100 * (fake_acc / len(val_dataloader)).item(),
            'real_acc': 100 * (real_acc / len(val_dataloader)).item()
        }
        return metrics

    def evaluate_identification_similarity(self, val_dataloader):
        src_id_sim = 0
        tgt_id_sim = 0
        self.generator.eval()
        for batch in tqdm(val_dataloader):
            Xs, Xt, _ = self.adapt(batch)
            with torch.no_grad():
                src_embed = self.arcface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                Y_hat = self.generator(Xt, src_embed, return_attributes=False)

                src_embed = self.mobiface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                tgt_embed = self.mobiface(
                    F.interpolate(Xt[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                fake_embed = self.mobiface(
                    F.interpolate(Y_hat[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))

            src_id_sim += (torch.cosine_similarity(src_embed,
                                                   fake_embed,
                                                   dim=1)).float().mean()
            tgt_id_sim += (torch.cosine_similarity(tgt_embed,
                                                   fake_embed,
                                                   dim=1)).float().mean()

        self.generator.train()

        metrics = {
            'src_similarity': 100 * (src_id_sim / len(val_dataloader)).item(),
            'tgt_similarity': 100 * (tgt_id_sim / len(val_dataloader)).item()
        }
        return metrics

    def generate(self, Xs, Xt, same_person):
        def get_grid_image(X):
            X = X[:8]
            X = torchvision.utils.make_grid(X.detach().cpu(), nrow=X.shape[0])
            X = (X * 0.5 + 0.5) * 255
            return X

        def make_image(Xs, Xt, Y_hat):
            Xs = get_grid_image(Xs)
            Xt = get_grid_image(Xt)
            Y_hat = get_grid_image(Y_hat)
            return torch.cat((Xs, Xt, Y_hat), dim=1).numpy()

        with torch.no_grad():
            embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
            self.generator.eval()
            Y_hat = self.generator(Xt, embed, return_attributes=False)
            self.generator.train()

        image = make_image(Xs, Xt, Y_hat)
        if not os.path.exists(f'results/{self.model_dir}'):
            os.makedirs(f'results/{self.model_dir}')
        cv2.imwrite(f'results/{self.model_dir}/{self.iter}.jpg',
                    image.transpose([1, 2, 0]))

    def get_opt_stats(self, optimizer, type=''):
        stats = {f'{type}_lr': optimizer.param_groups[0]['lr']}
        return stats

    def adjust_lr(self, optimizer):
        if self.iter <= self.warmup:
            lr = self.lr * self.iter / self.warmup
        else:
            lr = self.lr * (1 + cos(pi * (self.iter - self.warmup) /
                                    (self.max_iters - self.warmup))) / 2

        for group in optimizer.param_groups:
            group['lr'] = lr
        return lr

    def write_logs(self, losses=None, metrics=None, stats=None, type='loss'):
        if losses:
            for name, value in losses.items():
                self.writer.add_scalar(f'{type}/{name}', value, self.iter)
        if metrics:
            for name, value in metrics.items():
                self.writer.add_scalar(f'metric/{name}', value, self.iter)
        if stats:
            for name, value in stats.items():
                self.writer.add_scalar(f'stats/{name}', value, self.iter)

    def save_generator(self, max_checkpoints=100):
        checkpoints = sorted(glob.glob(f'checkpoints/{self.model_dir}/*.pt'),
                             key=os.path.getctime)
        if len(checkpoints) > max_checkpoints:
            os.remove(checkpoints[0])  # drop the oldest checkpoint, not an arbitrary one
        with open(f'checkpoints/{self.model_dir}/generator_{self.iter}.pt',
                  'wb') as f:
            torch.save(self.generator.state_dict(), f)

    def save_discriminator(self, max_checkpoints=100):
        checkpoints = sorted(glob.glob(f'checkpoints/{self.model_dir}/*.pt'),
                             key=os.path.getctime)
        if len(checkpoints) > max_checkpoints:
            os.remove(checkpoints[0])  # drop the oldest checkpoint, not an arbitrary one
        with open(f'checkpoints/{self.model_dir}/discriminator_{self.iter}.pt',
                  'wb') as f:
            torch.save(self.discriminator.state_dict(), f)

    def load_discriminator(self, path, load_last=True):
        if load_last:
            try:
                checkpoints = glob.glob(f'{path}/discriminator*.pt')
                path = max(checkpoints, key=os.path.getctime)
            except (ValueError):
                print(f'Directory is empty: {path}')

        try:
            self.discriminator.load_state_dict(torch.load(path))
            self.cuda()
        except (FileNotFoundError):
            print(f'No such file: {path}')

    def load_generator(self, path, load_last=True):
        if load_last:
            try:
                checkpoints = glob.glob(f'{path}/generator*.pt')
                path = max(checkpoints, key=os.path.getctime)
            except ValueError:
                print(f'Directory is empty: {path}')
                return

        try:
            self.generator.load_state_dict(torch.load(path))
            # recover the iteration counter from digits in the checkpoint file name
            iter_str = ''.join(filter(str.isdigit, os.path.basename(path)))
            self._iter = nn.Parameter(torch.tensor(int(iter_str)),
                                      requires_grad=False)
            self.cuda()
        except FileNotFoundError:
            print(f'No such file: {path}')
Example #9
File: main.py  Project: FengHZ/DCGAN
def main(args=args):
    dataset_base_path = path.join(args.base_path, "dataset", "celeba")
    image_base_path = path.join(dataset_base_path, "img_align_celeba")
    split_dataset_path = path.join(dataset_base_path, "Eval", "list_eval_partition.txt")
    with open(split_dataset_path, "r") as f:
        split_annotation = f.read().splitlines()
    # create the data name lists for train, test and valid
    train_data_name_list = []
    test_data_name_list = []
    valid_data_name_list = []
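    # partition codes in list_eval_partition.txt: 0 = train, 1 = validation, 2 = test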
    for item in split_annotation:
        item = item.split(" ")
        if item[1] == '0':
            train_data_name_list.append(item[0])
        elif item[1] == '1':
            valid_data_name_list.append(item[0])
        else:
            test_data_name_list.append(item[0])
    attribute_annotation_dict = None
    if args.need_label:
        attribute_annotation_path = path.join(dataset_base_path, "Anno", "list_attr_celeba.txt")
        with open(attribute_annotation_path, "r") as f:
            attribute_annotation = f.read().splitlines()
        attribute_annotation = attribute_annotation[2:]
        attribute_annotation_dict = {}
        for item in attribute_annotation:
            img_name, attribute = item.split(" ", 1)
            # each attribute is a "1" / "-1" flag
            attribute = tuple(int(attr) for attr in attribute.split(" ") if attr != "")
            assert len(attribute) == 40, "the attribute of item {} is not equal to 40".format(img_name)
            attribute_annotation_dict[img_name] = attribute
    discriminator = Discriminator(num_channel=args.num_channel, num_feature=args.dnf,
                                  data_parallel=args.data_parallel).cuda()
    generator = Generator(latent_dim=args.latent_dim, num_feature=args.gnf,
                          num_channel=args.num_channel, data_parallel=args.data_parallel).cuda()
    input("Begin the {} time's training, the train dataset has {} images and the valid has {} images".format(
        args.train_time, len(train_data_name_list), len(valid_data_name_list)))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    d_scheduler = ExponentialLR(d_optimizer, gamma=args.decay_lr)
    g_scheduler = ExponentialLR(g_optimizer, gamma=args.decay_lr)
    writer_log_dir = "{}/DCGAN/runs/train_time:{}".format(args.base_path, args.train_time)
    # Here we implement the resume part
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if not args.not_resume_arg:
                args = checkpoint['args']
                args.start_epoch = checkpoint['epoch']
            discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
            generator.load_state_dict(checkpoint["generator_state_dict"])
            d_optimizer.load_state_dict(checkpoint['discriminator_optimizer'])
            g_optimizer.load_state_dict(checkpoint['generator_optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError("Checkpoint Resume File {} Not Found".format(args.resume))
    else:
        if os.path.exists(writer_log_dir):
            flag = input("DCGAN train_time:{} will be removed, input yes to continue:".format(
                args.train_time))
            if flag == "yes":
                shutil.rmtree(writer_log_dir, ignore_errors=True)
    writer = SummaryWriter(log_dir=writer_log_dir)
    # Here we just use the train dset in training
    train_dset = CelebADataset(base_path=image_base_path, data_name_list=train_data_name_list,
                               image_size=args.image_size,
                               label_dict=attribute_annotation_dict)
    train_dloader = DataLoader(dataset=train_dset, batch_size=args.batch_size, shuffle=True,
                               num_workers=args.workers, pin_memory=True)
    criterion = nn.BCELoss()
    for epoch in range(args.start_epoch, args.epochs):
        train(train_dloader, generator, discriminator, g_optimizer, d_optimizer, criterion, writer, epoch)
        # adjust lr
        d_scheduler.step()
        g_scheduler.step()
        # save parameters
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            "discriminator_state_dict": discriminator.state_dict(),
            "generator_state_dict": generator.state_dict(),
            'discriminator_optimizer': d_optimizer.state_dict(),
            'generator_optimizer': g_optimizer.state_dict()
        })
Example #10
def main():
    ## load std models
    # policy_log_std = torch.load('./model_pkl/policy_net_action_std_model_1.pkl')
    # transition_log_std = torch.load('./model_pkl/transition_net_state_std_model_1.pkl')

    # load expert data
    print(args.data_set_path)
    dataset = ExpertDataSet(args.data_set_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=args.expert_batch_size,
                                  shuffle=True,
                                  num_workers=0)
    # define actor/critic/discriminator net and optimizer
    policy = Policy(onehot_action_sections,
                    onehot_state_sections,
                    state_0=dataset.state)
    value = Value()
    discriminator = Discriminator()
    optimizer_policy = torch.optim.Adam(policy.parameters(), lr=args.policy_lr)
    optimizer_value = torch.optim.Adam(value.parameters(), lr=args.value_lr)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.discrim_lr)
    discriminator_criterion = nn.BCELoss()
    if write_scalar:
        writer = SummaryWriter(log_dir='runs/' + model_name)

    # load net  models
    if load_model:
        discriminator.load_state_dict(
            torch.load('./model_pkl/Discriminator_model_' + model_name +
                       '.pkl'))
        policy.transition_net.load_state_dict(
            torch.load('./model_pkl/Transition_model_' + model_name + '.pkl'))
        policy.policy_net.load_state_dict(
            torch.load('./model_pkl/Policy_model_' + model_name + '.pkl'))
        value.load_state_dict(
            torch.load('./model_pkl/Value_model_' + model_name + '.pkl'))

        policy.policy_net_action_std = torch.load(
            './model_pkl/Policy_net_action_std_model_' + model_name + '.pkl')
        policy.transition_net_state_std = torch.load(
            './model_pkl/Transition_net_state_std_model_' + model_name +
            '.pkl')
    print('#############  start training  ##############')

    # update discriminator
    num = 0
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        policy.train()
        value.train()
        discriminator.train()
        start_time = time.time()
        memory, n_trajs = policy.collect_samples(
            batch_size=args.sample_batch_size)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        onehot_state = torch.cat(batch.onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_state = torch.cat(batch.multihot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_state = torch.cat(batch.continuous_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()

        onehot_action = torch.cat(batch.onehot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_action = torch.cat(batch.multihot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_action = torch.cat(batch.continuous_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_onehot_state = torch.cat(batch.next_onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_multihot_state = torch.cat(batch.next_multihot_state,
                                        dim=1).reshape(
                                            n_trajs * args.sample_traj_length,
                                            -1).detach()
        next_continuous_state = torch.cat(
            batch.next_continuous_state,
            dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()

        old_log_prob = torch.cat(batch.old_log_prob, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        mask = torch.cat(batch.mask,
                         dim=1).reshape(n_trajs * args.sample_traj_length,
                                        -1).detach()
        gen_state = torch.cat((onehot_state, multihot_state, continuous_state),
                              dim=-1)
        gen_action = torch.cat(
            (onehot_action, multihot_action, continuous_action), dim=-1)
        if ep % 1 == 0:  # always true: the discriminator is updated every epoch
            # if (d_slow_flag and ep % 50 == 0) or (not d_slow_flag and ep % 1 == 0):
            d_loss = torch.empty(0, device=device)
            p_loss = torch.empty(0, device=device)
            v_loss = torch.empty(0, device=device)
            gen_r = torch.empty(0, device=device)
            expert_r = torch.empty(0, device=device)
            for expert_state_batch, expert_action_batch in data_loader:
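                # add Gaussian instance noise to both generated and expert samples to soften the discriminator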
                noise1 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_state.shape,
                                      device=device)
                noise2 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_action.shape,
                                      device=device)
                noise3 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_state_batch.shape,
                                      device=device)
                noise4 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_action_batch.shape,
                                      device=device)
                gen_r = discriminator(gen_state + noise1, gen_action + noise2)
                expert_r = discriminator(
                    expert_state_batch.to(device) + noise3,
                    expert_action_batch.to(device) + noise4)

                # gen_r = discriminator(gen_state, gen_action)
                # expert_r = discriminator(expert_state_batch.to(device), expert_action_batch.to(device))
                optimizer_discriminator.zero_grad()
                d_loss = discriminator_criterion(gen_r, torch.zeros(gen_r.shape, device=device)) + \
                         discriminator_criterion(expert_r, torch.ones(expert_r.shape, device=device))
                variance = 0.5 * torch.var(gen_r.to(device)) + 0.5 * torch.var(
                    expert_r.to(device))
                total_d_loss = d_loss - 10 * variance
                d_loss.backward()
                # total_d_loss.backward()
                optimizer_discriminator.step()
            if write_scalar:
                writer.add_scalar('d_loss', d_loss, ep)
                writer.add_scalar('total_d_loss', total_d_loss, ep)
                writer.add_scalar('variance', 10 * variance, ep)
        if ep % 1 == 0:  # always true: the PPO update runs every epoch
            # update PPO
            noise1 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_state.shape,
                                  device=device)
            noise2 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_action.shape,
                                  device=device)
            gen_r = discriminator(gen_state + noise1, gen_action + noise2)
            #if gen_r.mean().item() < 0.1:
            #    d_stop = True
            #if d_stop and gen_r.mean()
            optimize_iter_num = int(
                math.ceil(onehot_state.shape[0] / args.ppo_mini_batch_size))
            # gen_r = -(1 - gen_r + 1e-10).log()
            for ppo_ep in range(args.ppo_optim_epoch):
                for i in range(optimize_iter_num):
                    num += 1
                    index = slice(
                        i * args.ppo_mini_batch_size,
                        min((i + 1) * args.ppo_mini_batch_size,
                            onehot_state.shape[0]))
                    onehot_state_batch = onehot_state[index]
                    multihot_state_batch = multihot_state[index]
                    continuous_state_batch = continuous_state[index]
                    onehot_action_batch = onehot_action[index]
                    multihot_action_batch = multihot_action[index]
                    continuous_action_batch = continuous_action[index]
                    old_log_prob_batch = old_log_prob[index]
                    mask_batch = mask[index]
                    next_onehot_state_batch = next_onehot_state[index]
                    next_multihot_state_batch = next_multihot_state[index]
                    next_continuous_state_batch = next_continuous_state[index]
                    gen_r_batch = gen_r[index]
                    v_loss, p_loss = ppo_step(
                        policy, value, optimizer_policy, optimizer_value,
                        onehot_state_batch, multihot_state_batch,
                        continuous_state_batch, onehot_action_batch,
                        multihot_action_batch, continuous_action_batch,
                        next_onehot_state_batch, next_multihot_state_batch,
                        next_continuous_state_batch, gen_r_batch,
                        old_log_prob_batch, mask_batch, args.ppo_clip_epsilon)
                    if write_scalar:
                        writer.add_scalar('p_loss', p_loss, ep)
                        writer.add_scalar('v_loss', v_loss, ep)
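        # switch to eval mode and compare discriminator scores with and without instance noise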
        policy.eval()
        value.eval()
        discriminator.eval()
        noise1 = torch.normal(0,
                              args.noise_std,
                              size=gen_state.shape,
                              device=device)
        noise2 = torch.normal(0,
                              args.noise_std,
                              size=gen_action.shape,
                              device=device)
        gen_r = discriminator(gen_state + noise1, gen_action + noise2)
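        # reuse the last expert batch and noise tensors from the discriminator loop for logging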
        expert_r = discriminator(
            expert_state_batch.to(device) + noise3,
            expert_action_batch.to(device) + noise4)
        gen_r_noise = gen_r.mean().item()
        expert_r_noise = expert_r.mean().item()
        gen_r = discriminator(gen_state, gen_action)
        expert_r = discriminator(expert_state_batch.to(device),
                                 expert_action_batch.to(device))
        if write_scalar:
            writer.add_scalar('gen_r', gen_r.mean(), ep)
            writer.add_scalar('expert_r', expert_r.mean(), ep)
            writer.add_scalar('gen_r_noise', gen_r_noise, ep)
            writer.add_scalar('expert_r_noise', expert_r_noise, ep)
        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('gen_r_noise', gen_r_noise)
        print('expert_r_noise', expert_r_noise)
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())
        print('d_loss', d_loss.item())
        # save models
        if model_name is not None:
            torch.save(
                discriminator.state_dict(),
                './model_pkl/Discriminator_model_' + model_name + '.pkl')
            torch.save(policy.transition_net.state_dict(),
                       './model_pkl/Transition_model_' + model_name + '.pkl')
            torch.save(policy.policy_net.state_dict(),
                       './model_pkl/Policy_model_' + model_name + '.pkl')
            torch.save(
                policy.policy_net_action_std,
                './model_pkl/Policy_net_action_std_model_' + model_name +
                '.pkl')
            torch.save(
                policy.transition_net_state_std,
                './model_pkl/Transition_net_state_std_model_' + model_name +
                '.pkl')
            torch.save(value.state_dict(),
                       './model_pkl/Value_model_' + model_name + '.pkl')
        memory.clear_memory()
Example #11
            print('------------------------------')

            _, (sampled, s1, s2) = validate(2, args.use_cuda, need_samples=True)

            for i in range(len(sampled)):
                result = generator.sample_with_pair(batch_loader, 20, args.use_cuda, s1[i], s2[i])

                print('source: ' + s1[i])
                print('target: ' + s2[i])
                print('sampled: ' + result)
                print('...........................')

        # save model
        if (iteration % 10000 == 0 and iteration != 0) or iteration == (args.num_iterations - 1):
            t.save(generator.state_dict(),
                   'saved_models/trained_generator_' + args.model_name + '_' + str(iteration // 1000))
            t.save(discriminator.state_dict(),
                   'saved_models/trained_discriminator_' + args.model_name + '_' + str(iteration // 1000))
            np.save('logs/{}/ce_result_valid.npy'.format(args.model_name), np.array(ce_result_valid))
            np.save('logs/{}/ce_result_train.npy'.format(args.model_name), np.array(ce_result_train))
            np.save('logs/{}/kld_result_valid'.format(args.model_name), np.array(kld_result_valid))
            np.save('logs/{}/kld_result_train'.format(args.model_name), np.array(kld_result_train))
            np.save('logs/{}/ce2_result_valid.npy'.format(args.model_name), np.array(ce2_result_valid))
            np.save('logs/{}/ce2_result_train.npy'.format(args.model_name), np.array(ce2_result_train))
            np.save('logs/{}/dg_result_valid.npy'.format(args.model_name), np.array(dg_result_valid))
            np.save('logs/{}/dg_result_train.npy'.format(args.model_name), np.array(dg_result_train))
            np.save('logs/{}/d_result_valid.npy'.format(args.model_name), np.array(d_result_valid))
            np.save('logs/{}/d_result_train.npy'.format(args.model_name), np.array(d_result_train))

        # intermediate sampling
        if (iteration % 10000 == 0 and iteration != 0) or iteration == (args.num_iterations - 1):
            if args.interm_sampling:
                args.seq_len = 30
Example #12
def main(data_dir):
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((416, 416)),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))

    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=1)

    # 2) Intermediate Processing.
    transform_img = transforms.Compose([
        #transforms.ToTensor(),
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    g_x = ResFCN256(resolution_input=416,
                    resolution_output=416,
                    channel=3,
                    size=16)
    g_y = ResFCN256(resolution_input=416,
                    resolution_output=416,
                    channel=3,
                    size=16)
    d_x = Discriminator()
    d_y = Discriminator()

    # Load the pre-trained weight
    if FLAGS['resume'] != "" and os.path.exists(
            os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        try:
            g_x.load_state_dict(state['g_x'])
            g_y.load_state_dict(state['g_y'])
            d_x.load_state_dict(state['d_x'])
            d_y.load_state_dict(state['d_y'])
        except Exception:
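            # fall back to a checkpoint that stores only the PRNet generator weights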
            g_x.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO("Pre-trained weights could not be loaded; training from scratch!")

    if torch.cuda.device_count() > 1:
        # wrap each network for multi-GPU training
        g_x = torch.nn.DataParallel(g_x)
        g_y = torch.nn.DataParallel(g_y)
        d_x = torch.nn.DataParallel(d_x)
        d_y = torch.nn.DataParallel(d_y)
    g_x.to(FLAGS["device"])
    g_y.to(FLAGS["device"])
    d_x.to(FLAGS["device"])
    d_y.to(FLAGS["device"])

    optimizer_g = torch.optim.Adam(itertools.chain(g_x.parameters(),
                                                   g_y.parameters()),
                                   lr=FLAGS["lr"],
                                   betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(itertools.chain(d_x.parameters(),
                                                   d_y.parameters()),
                                   lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)

    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(FLAGS["device"])
    l1_loss = nn.L1Loss().to(FLAGS["device"])
    lambda_X = 10
    lambda_Y = 10
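    # lambda_X / lambda_Y weight the cycle-consistency L1 terms against the adversarial losses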
    # adversarial + cycle-consistency training loop
    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_cycle_x = []
        loss_list_cycle_y = []
        loss_list_d_x = []
        loss_list_d_y = []
        real_label = torch.ones(FLAGS['batch_size']).to(FLAGS['device'])
        fake_label = torch.zeros(FLAGS['batch_size']).to(FLAGS['device'])
        for i, sample in enumerate(bar):
            real_y, real_x = sample['uv_map'].to(
                FLAGS['device']), sample['origin'].to(FLAGS['device'])
            # x -> y' -> x^
            optimizer_g.zero_grad()
            fake_y = g_x(real_x)
            prediction = d_x(fake_y)
            loss_g_x = bce_loss(prediction, real_label)
            x_hat = g_y(fake_y)
            loss_cycle_x = l1_loss(x_hat, real_x) * lambda_X
            loss_x = loss_g_x + loss_cycle_x
            loss_x.backward(retain_graph=True)
            optimizer_g.step()
            loss_list_cycle_x.append(loss_x.item())
            # y -> x' -> y^
            optimizer_g.zero_grad()
            fake_x = g_y(real_y)
            prediction = d_y(fake_x)
            loss_g_y = bce_loss(prediction, real_label)
            y_hat = g_x(fake_x)
            loss_cycle_y = l1_loss(y_hat, real_y) * lambda_Y
            loss_y = loss_g_y + loss_cycle_y
            loss_y.backward(retain_graph=True)
            optimizer_g.step()
            loss_list_cycle_y.append(loss_y.item())
            # d_x
            optimizer_d.zero_grad()
            pred_real = d_x(real_y)
            loss_d_x_real = bce_loss(pred_real, real_label)
            # detach so the discriminator update does not backpropagate into g_x
            pred_fake = d_x(fake_y.detach())
            loss_d_x_fake = bce_loss(pred_fake, fake_label)
            loss_d_x = (loss_d_x_real + loss_d_x_fake) * 0.5
            loss_d_x.backward()
            loss_list_d_x.append(loss_d_x.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                for p in d_x.parameters():
                    p.data.clamp_(-1, 1)
            # d_y
            optimizer_d.zero_grad()
            pred_real = d_y(real_x)
            loss_d_y_real = bce_loss(pred_real, real_label)
            # detach so the discriminator update does not backpropagate into g_y
            pred_fake = d_y(fake_x.detach())
            loss_d_y_fake = bce_loss(pred_fake, fake_label)
            loss_d_y = (loss_d_y_real + loss_d_y_fake) * 0.5
            loss_d_y.backward()
            loss_list_d_y.append(loss_d_y.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                for p in d_y.parameters():
                    p.data.clamp_(-1, 1)

        if ep % FLAGS["save_interval"] == 0:

            with torch.no_grad():
                print(
                    " {} [Loss_X] {} [Loss_Y] {} [Loss_D_X] {} [Loss_D_Y] {}"
                    .format(ep, loss_list_cycle_x[-1], loss_list_cycle_y[-1],
                            loss_list_d_x[-1], loss_list_d_y[-1]))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)

                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)

                origin_in = origin.unsqueeze_(0).cuda()
                pred_uv_map = g_x(origin_in).detach().cpu()

                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model
            print("Save model")
            state = {
                'g_x': g_x.state_dict(),
                'g_y': g_y.state_dict(),
                'd_x': d_x.state_dict(),
                'd_y': d_y.state_dict(),
                'start_epoch': ep,
            }
            torch.save(state, os.path.join(FLAGS['images'],
                                           '{}.pth'.format(ep)))

        # decay the generator learning rate once per epoch
        scheduler.step()
Example #13
class RefSRSolver(BaseSolver):
    def __init__(self, cfg):
        super(RefSRSolver, self).__init__(cfg)

        self.srntt = SRNTT(cfg['model']['n_resblocks'],
                           cfg['schedule']['use_weights'],
                           cfg['schedule']['concat']).cuda()
        # self.discriminator = None
        self.discriminator = Discriminator(cfg['data']['input_size']).cuda()
        # self.vgg = None
        self.vgg = VGG19(cfg['model']['final_layer'],
                         cfg['model']['prev_layer'], True).cuda()
        params = list(self.srntt.texture_transfer.parameters()) + list(self.srntt.texture_fusion_medium.parameters()) +\
                 list(self.srntt.texture_fusion_large.parameters()) + list(self.srntt.srntt_out.parameters())
        self.init_epoch = self.cfg['schedule']['init_epoch']
        self.num_epochs = self.cfg['schedule']['num_epochs']
        self.optimizer_init = torch.optim.Adam(params,
                                               lr=cfg['schedule']['lr'])
        self.optimizer = torch.optim.lr_scheduler.MultiStepLR(
            torch.optim.Adam(params, lr=cfg['schedule']['lr']),
            [self.num_epochs // 2], 0.1)
        self.optimizer_d = torch.optim.lr_scheduler.MultiStepLR(
            torch.optim.Adam(self.discriminator.parameters(),
                             lr=cfg['schedule']['lr']), [self.num_epochs // 2],
            0.1)
        self.reconst_loss = nn.L1Loss()
        self.bp_loss = BackProjectionLoss()
        self.texture_loss = TextureLoss(self.cfg['schedule']['use_weights'],
                                        80)
        self.adv_loss = AdvLoss(self.cfg['schedule']['is_WGAN_GP'])
        self.loss_weights = self.cfg['schedule']['loss_weights']
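        # assumed ordering of loss_weights, inferred from the indexing below:
        # [perceptual, texture, adversarial, back-projection, reconstruction]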

    def train(self):
        if self.epoch <= self.init_epoch:
            with tqdm(total=len(self.train_loader),
                      miniters=1,
                      desc='Initial Training Epoch: [{}/{}]'.format(
                          self.epoch, self.max_epochs)) as t:
                for data in self.train_loader:
                    lr, hr = data['lr'].cuda(), data['hr'].cuda()
                    maps, weight = data['map'].cuda(), data['weight'].cuda()
                    self.srntt.train()
                    self.optimizer_init.zero_grad()
                    sr, srntt_out = self.srntt(lr, weight, maps)
                    loss_reconst = self.reconst_loss(sr, hr)
                    loss_bp = self.bp_loss(lr, srntt_out)
                    loss = self.loss_weights[
                        4] * loss_reconst + self.loss_weights[3] * loss_bp
                    t.set_postfix_str("Batch loss {:.4f}".format(loss.item()))
                    t.update()

                    loss.backward()
                    self.optimizer_init.step()
        elif self.epoch <= self.num_epochs:
            with tqdm(total=len(self.train_loader),
                      miniters=1,
                      desc='Complete Training Epoch: [{}/{}]'.format(
                          self.epoch, self.max_epochs)) as t:
                for data in self.train_loader:
                    lr, hr = data['lr'].cuda(), data['hr'].cuda()
                    maps, weight = data['map'].cuda(), data['weight'].cuda()
                    self.srntt.train()
                    self.optimizer_init.zero_grad()
                    self.optimizer.optimizer.zero_grad()
                    self.optimizer_d.optimizer.zero_grad()
                    sr, srntt_out = self.srntt(lr, weight, maps)
                    sr_prevlayer, sr_lastlayer = self.vgg(srntt_out)
                    hr_prevlayer, hr_lastlayer = self.vgg(hr)
                    _, d_real_logits = self.discriminator(hr)
                    _, d_fake_logits = self.discriminator(srntt_out)
                    loss_reconst = self.reconst_loss(sr, hr)
                    loss_bp = self.bp_loss(lr, srntt_out)
                    loss_texture = self.texture_loss(sr_prevlayer, maps,
                                                     weight)
                    loss_d, loss_g = self.adv_loss(srntt_out, hr,
                                                   d_fake_logits,
                                                   d_real_logits,
                                                   self.discriminator)
                    loss_percep = torch.pow(sr_lastlayer - hr_lastlayer,
                                            2).mean()
                    if self.cfg['schedule']['use_lower_layers_in_per_loss']:
                        for l_sr, l_hr in zip(sr_prevlayer, hr_prevlayer):
                            loss_percep += torch.pow(l_sr - l_hr, 2).mean()
                        loss_percep = loss_percep / (len(sr_prevlayer) + 1)
                    # weight and sum the loss terms directly so they stay on the autograd graph
                    w = self.loss_weights
                    total_loss = w[0] * loss_percep + w[1] * loss_texture + \
                                 w[2] * loss_g + w[3] * loss_bp + w[4] * loss_reconst

                    t.set_postfix_str("Batch loss {:.4f}".format(
                        total_loss.item()))
                    t.update()

                    loss_d.backward(retain_graph=True)
                    total_loss.backward()
                    # step the wrapped optimizers; the MultiStepLR objects only schedule the learning rate
                    self.optimizer.optimizer.step()
                    self.optimizer_d.optimizer.step()
            # advance the learning-rate schedulers once per epoch
            self.optimizer.step()
            self.optimizer_d.step()
        else:
            pass

    def eval(self):
        with tqdm(total=len(self.val_loader),
                  miniters=1,
                  desc='Val Epoch: [{}/{}]'.format(self.epoch,
                                                   self.max_epochs)) as t:
            psnr_list, ssim_list, loss_list = [], [], []
            for lr, hr in self.val_loader:
                lr, hr = lr.cuda(), hr.cuda()
                self.srntt.eval()
                with torch.no_grad():
                    sr, _ = self.srntt(lr, None, None)
                    loss = self.reconst_loss(sr, hr)

                batch_psnr, batch_ssim = [], []
                for c in range(sr.shape[0]):
                    predict_sr = (sr[c, ...].cpu().numpy().transpose(
                        (1, 2, 0)) + 1) * 127.5
                    ground_truth = (hr[c, ...].cpu().numpy().transpose(
                        (1, 2, 0)) + 1) * 127.5
                    psnr = utils.calculate_psnr(predict_sr, ground_truth, 255)
                    ssim = utils.calculate_ssim(predict_sr, ground_truth, 255)
                    batch_psnr.append(psnr)
                    batch_ssim.append(ssim)
                avg_psnr = np.array(batch_psnr).mean()
                avg_ssim = np.array(batch_ssim).mean()
                psnr_list.extend(batch_psnr)
                ssim_list.extend(batch_ssim)
                t.set_postfix_str(
                    'Batch loss: {:.4f}, PSNR: {:.4f}, SSIM: {:.4f}'.format(
                        loss.item(), avg_psnr, avg_ssim))
                t.update()
            self.records['Epoch'].append(self.epoch)
            self.records['PSNR'].append(np.array(psnr_list).mean())
            self.records['SSIM'].append(np.array(ssim_list).mean())
            self.logger.log('Val Epoch {}: PSNR={}, SSIM={}'.format(
                self.epoch, self.records['PSNR'][-1],
                self.records['SSIM'][-1]))

    def save_checkpoint(self):
        super(RefSRSolver, self).save_checkpoint()
        self.ckp['srntt'] = self.srntt.state_dict()
        self.ckp['optimizer'] = self.optimizer.state_dict()
        self.ckp['optimizer_d'] = self.optimizer_d.state_dict()
        self.ckp['optimizer_init'] = self.optimizer_init.state_dict()
        if self.discriminator is not None:
            self.ckp['discriminator'] = self.discriminator.state_dict()
        if self.vgg is not None:
            self.ckp['vgg'] = self.vgg.state_dict()

        torch.save(self.ckp, os.path.join(self.checkpoint_dir, 'latest.pth'))
        if self.records['PSNR'][-1] == np.array(self.records['PSNR']).max():
            shutil.copy(os.path.join(self.checkpoint_dir, 'latest.pth'),
                        os.path.join(self.checkpoint_dir, 'best.pth'))

    def load_checkpoint(self, model_path):
        super(RefSRSolver, self).load_checkpoint(model_path)
        ckpt = torch.load(model_path)
        self.srntt.load_state_dict(ckpt['srntt'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        self.optimizer_d.load_state_dict(ckpt['optimizer_d'])
        self.optimizer_init.load_state_dict(ckpt['optimizer_init'])
        if 'vgg' in ckpt.keys() and self.vgg is not None:
            self.vgg.load_state_dict(ckpt['vgg'])
        if 'discriminator' in ckpt.keys() and self.discriminator is not None:
            self.discriminator.load_state_dict(ckpt['discriminator'])
Example #14
def main():

    # parse input size
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # cudnn.enabled = True
    # gpu = args.gpu

    # create segmentation network
    model = DeepLab(num_classes=args.num_classes)

    # load pretrained parameters
    # if args.restore_from[:4] == 'http' :
    #     saved_state_dict = model_zoo.load_url(args.restore_from)
    # else:
    #     saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    # new_params = model.state_dict().copy()
    # for name, param in new_params.items():
    #     if name in saved_state_dict and param.size() == saved_state_dict[name].size():
    #         new_params[name].copy_(saved_state_dict[name])
    # model.load_state_dict(new_params)

    model.train()
    model.cpu()
    # model.cuda(args.gpu)
    # cudnn.benchmark = True

    # create discriminator network
    model_D = Discriminator(num_classes=args.num_classes)
    # if args.restore_from_D is not None:
    #     model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cpu()
    # model_D.cuda(args.gpu)

    # MILESTONE 1
    print("Printing MODELS ...")
    print(model)
    print(model_D)

    # Create directory to save snapshots of the model
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Load train data and ground truth labels
    # train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
    #                 scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)
    # train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
    #                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    # trainloader = data.DataLoader(train_dataset,
    #                 batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=False)
    # trainloader_gt = data.DataLoader(train_gt_dataset,
    #                 batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=False)

    train_dataset = MyCustomDataset()
    train_gt_dataset = MyCustomDataset()

    trainloader = data.DataLoader(train_dataset, batch_size=5, shuffle=True)
    trainloader_gt = data.DataLoader(train_gt_dataset,
                                     batch_size=5,
                                     shuffle=True)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # MILESTONE 2
    print("Printing Loaders")
    print(trainloader_iter)
    print(trainloader_gt_iter)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # MILESTONE 3
    print("Printing OPTIMIZERS ...")
    print(optimizer)
    print(optimizer_D)

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    # nn.Upsample expects size as (height, width)
    interp = nn.Upsample(size=input_size,
                         mode='bilinear',
                         align_corners=True)

    # labels for adversarial training
    pred_label = 0
    gt_label = 1
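    # the discriminator learns to output 1 for ground-truth label maps and 0 for segmentation predictions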

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            # if (args.lambda_semi > 0 or args.lambda_semi_adv > 0 ) and i_iter >= args.semi_start_adv :
            #     try:
            #         _, batch = next(trainloader_remain_iter)
            #     except:
            #         trainloader_remain_iter = enumerate(trainloader_remain)
            #         _, batch = next(trainloader_remain_iter)

            #     # only access to img
            #     images, _, _, _ = batch
            #     images = Variable(images).cuda(args.gpu)

            #     pred = interp(model(images))
            #     pred_remain = pred.detach()

            #     D_out = interp(model_D(F.softmax(pred)))
            #     D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

            #     ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(np.bool)

            #     loss_semi_adv = args.lambda_semi_adv * bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
            #     loss_semi_adv = loss_semi_adv/args.iter_size

            #     #loss_semi_adv.backward()
            #     loss_semi_adv_value += loss_semi_adv.data.cpu().numpy()/args.lambda_semi_adv

            #     if args.lambda_semi <= 0 or i_iter < args.semi_start:
            #         loss_semi_adv.backward()
            #         loss_semi_value = 0
            #     else:
            #         # produce ignore mask
            #         semi_ignore_mask = (D_out_sigmoid < args.mask_T)

            #         semi_gt = pred.data.cpu().numpy().argmax(axis=1)
            #         semi_gt[semi_ignore_mask] = 255

            #         semi_ratio = 1.0 - float(semi_ignore_mask.sum())/semi_ignore_mask.size
            #         print('semi ratio: {:.4f}'.format(semi_ratio))

            #         if semi_ratio == 0.0:
            #             loss_semi_value += 0
            #         else:
            #             semi_gt = torch.FloatTensor(semi_gt)

            #             loss_semi = args.lambda_semi * loss_calc(pred, semi_gt, args.gpu)
            #             loss_semi = loss_semi/args.iter_size
            #             loss_semi_value += loss_semi.data.cpu().numpy()/args.lambda_semi
            #             loss_semi += loss_semi_adv
            #             loss_semi.backward()

            # else:
            #     loss_semi = None
            #     loss_semi_adv = None

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cpu()
            # images = Variable(images).cuda(args.gpu)
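            # pixels labeled 255 are void regions excluded from the adversarial label map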
            ignore_mask = (labels.numpy() == 255)

            # segmentation prediction
            pred = interp(model(images))
            # (spatial multi-class) cross entropy loss
            loss_seg = loss_calc(pred, labels)
            # loss_seg = loss_calc(pred, labels, args.gpu)

            # discriminator prediction
            D_out = interp(model_D(F.softmax(pred, dim=1)))
            # adversarial loss
            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            # multi-task loss
            # lambda_adv - weight for minimizing loss
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # loss normalization
            loss = loss / args.iter_size

            # back propagation
            loss.backward()

            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            # if args.D_remain:
            #     pred = torch.cat((pred, pred_remain), 0)
            #     ignore_mask = np.concatenate((ignore_mask,ignore_mask_remain), axis = 0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cpu()
            # D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')