Beispiel #1
0
def main():
    """Run the trained generator G_A on the validation split and save outputs.

    Builds the generator (ResNet-based, or U-Net when ``opt.u_net`` is set),
    freezes it, restores its weights from ``opt.net_GA``, then translates
    every validation image and writes the results to 'generated_image/'.
    """
    opt = get_opt()

    # Define the Generators, only G_A is used for testing
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf, opt.n_res, opt.dropout)
    if opt.u_net:
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)

    if opt.cuda:
        netG_A.cuda()
    # Do not need to track the gradients during testing
    utils.set_requires_grad(netG_A, False)
    netG_A.eval()
    netG_A.load_state_dict(torch.load(opt.net_GA))

    # Load the data
    transform = transforms.Compose([transforms.Resize((opt.sizeh, opt.sizew)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    dataloader = DataLoader(ImageDataset(opt.rootdir, transform=transform, mode='val'), batch_size=opt.batch_size,
                            shuffle=False, num_workers=opt.n_cpu)

    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    # Fixed: the buffer must match the transform's output size
    # (opt.sizeh x opt.sizew); the original allocated (opt.size, opt.size),
    # which makes copy_() fail whenever the resize target differs.
    input_A = Tensor(opt.batch_size, opt.input_nc, opt.sizeh, opt.sizew)

    for i, batch in enumerate(dataloader):
        name, image = batch
        real_A = input_A.copy_(image)
        fake_B = netG_A(real_A)
        batch_size = len(name)
        # Save the generated images, one file per input, named after the input.
        for j in range(batch_size):
            image_name = name[j].split('/')[-1]
            path = 'generated_image/' + image_name
            utils.save_image(fake_B[j, :, :, :], path)
Beispiel #2
0
    def train(self):
        """Train the conditional GAN for ``self.args.epochs`` epochs.

        Resumes from ``self.args.checkpoint`` when given.  Each batch first
        updates the discriminator on a detached fake pair and a real pair,
        then updates the generator (GAN loss + weighted L1).  LR schedulers
        step once per epoch; validation and checkpointing happen every
        ``self.args.save_freq`` epochs.
        """
        # load training info (resume) if a checkpoint path was supplied
        filepath = self.args.checkpoint
        if filepath is not None:
            self._load(filepath)

        while self.epoch < self.args.epochs:
            self.model.train()
            self.epoch += 1
            for iteration, batch in enumerate(self.train_dataloader):
                real_a, real_b = batch[0].to(self.device), batch[1].to(self.device)
                fake_b = self.generator(real_a)

                # update discriminator
                set_requires_grad(self.discriminator, True)
                self.optimizer_d.zero_grad()

                # Detach the fake pair so no gradients flow into the generator.
                fake_ab = torch.cat((real_a, fake_b), 1)
                pred_fake = self.discriminator(fake_ab.detach())
                loss_fake = self.criterion_gan(pred_fake, False)

                real_ab = torch.cat((real_a, real_b), 1)
                pred_real = self.discriminator(real_ab)
                loss_real = self.criterion_gan(pred_real, True)

                loss_d = (loss_fake + loss_real)/2
                loss_d.backward()
                self.optimizer_d.step()

                # update generator (discriminator frozen)
                set_requires_grad(self.discriminator, False)
                self.optimizer_g.zero_grad()

                fake_ab = torch.cat((real_a, fake_b), 1)
                pred_fake = self.discriminator(fake_ab)
                loss_gan = self.criterion_gan(pred_fake, True)
                loss_l1 = self.criterion_l1(fake_b, real_b) * self.args.lamb

                loss_g = loss_gan + loss_l1
                loss_g.backward()
                self.optimizer_g.step()

                if (iteration+1)%self.args.print_loss_freq == 0:
                    # Fixed: the original format string was missing the
                    # closing parenthesis after the batch counter.
                    print('Epoch[{0}]({1}/{2}) - Loss_D: {3}, Loss_G: {4}, Loss: {5}'.format(
                        self.epoch,
                        iteration+1,
                        len(self.train_dataloader),
                        loss_d.item(),
                        loss_g.item(),
                        loss_d.item()+loss_g.item()
                    ))

            self.scheduler_g.step()
            self.scheduler_d.step()

            if self.epoch % self.args.save_freq == 0:
                self._validate()
                self._save(str(self.epoch))
                self._save('last')
Beispiel #3
0
    def optimize_parameters(self):
        """Run one optimization step: forward pass, loss, backward, update.

        Gradients are (re-)enabled on ``self.net`` before the step in case a
        previous phase disabled them.
        """
        self.forward()
        set_requires_grad(self.net, True)
        self.optimizer.zero_grad()
        self.loss_calculate()  # expected to populate self.loss_all

        self.loss_all.backward()
        self.optimizer.step()
    def run(self):
        """Main CycleGAN training loop over ``self.train_iters`` steps.

        Each step alternates a generator update (discriminators frozen) and
        a discriminator update, then emits text logs, scalar summaries, and
        sample images at their configured frequencies.
        """
        iter_A = iter(self.dataloader_A)
        iter_B = iter(self.dataloader_B)
        iter_per_epoch = min(len(iter_A), len(iter_B))

        for step in range(self.start_iter, self.start_iter + self.train_iters):
            # Restart both iterators at every epoch boundary.
            if step % iter_per_epoch == 0:
                iter_A = iter(self.dataloader_A)
                iter_B = iter(self.dataloader_B)

            # Fixed: use the builtin next() -- the .next() method is
            # Python 2 only and raises AttributeError on Python 3.
            real_A = next(iter_A)
            real_B = next(iter_B)
            real_A, real_B = real_A.to(self.device), real_B.to(self.device)
            fake_B = self.G_AB(real_A)
            fake_A = self.G_BA(real_B)

            # train G (discriminators frozen so only the generators update)
            set_requires_grad([self.D_A, self.D_B], False)
            self.optimizer_G.zero_grad()
            loss_G_AB, loss_G_BA, loss_cycle_A, loss_cycle_B, loss_idt_A, loss_idt_B = (  # noqa
                self.backward_G(real_A, real_B, fake_A, fake_B))
            self.optimizer_G.step()

            # train D
            set_requires_grad([self.D_A, self.D_B], True)
            self.optimizer_D.zero_grad()
            loss_D_A = self.backward_D(self.D_A, real_A, fake_A)
            loss_D_B = self.backward_D(self.D_B, real_B, fake_B)
            self.optimizer_D.step()

            # logger
            if step % self.args.log_report_freq == 0:
                self.logger.info(
                    '{} Train: Step {} Loss/G_AB {:.4f}'.format(
                        self.args.exp_name, step, loss_G_AB))

            # writer: scalar summaries for every loss term
            if step % self.args.scalar_report_freq == 0:
                self.writer.add_scalar('Train/Loss/G_AB', loss_G_AB, step)
                self.writer.add_scalar('Train/Loss/cycle_A', loss_cycle_A, step)
                self.writer.add_scalar('Train/Loss/idt_A', loss_idt_A, step)
                self.writer.add_scalar('Train/Loss/G_BA', loss_G_BA, step)
                self.writer.add_scalar('Train/Loss/cycle_B', loss_cycle_B, step)
                self.writer.add_scalar('Train/Loss/idt_B', loss_idt_B, step)
                self.writer.add_scalar('Train/Loss/D_A', loss_D_A, step)
                self.writer.add_scalar('Train/Loss/D_B', loss_D_B, step)

            # image summaries: last sample of the batch, moved to CPU
            if step % self.args.image_report_freq == 0:
                real_A = real_A[-1].detach().cpu()
                real_B = real_B[-1].detach().cpu()
                fake_A = fake_A[-1].detach().cpu()
                fake_B = fake_B[-1].detach().cpu()
                self.writer.add_image('Train/real_A', real_A, step)
                self.writer.add_image('Train/fake_A', fake_A, step)
                self.writer.add_image('Train/real_B', real_B, step)
                self.writer.add_image('Train/fake_B', fake_B, step)
Beispiel #5
0
def execute(all_models, envname, savepath, eval_hp=None):
    """Evaluate the frozen models (inverse model / planner) in ``envname``."""
    hp = get_hp(eval_hp)

    ### Inverse Model Evaluation ###
    # Freeze every model and switch it to eval mode.
    for net in all_models:
        set_requires_grad(net, False)
        net.eval()

    # Unpack the individual models.
    model, c_model, actor = all_models

    # Viewer configuration for the environment renderer.
    viewer_cfg = {
        "visible": False,
        "init_width": hp["init_width"],
        "init_height": hp["init_height"],
        "go_fast": True
    }
    env = get_env(envname)
    env.reset()
    env.render(config=viewer_cfg)
    env.viewer_setup()
    # Warm up the renderer.
    for _ in range(100):
        env.render(config=viewer_cfg)

    test_mode = hp["test_mode"]
    if test_mode == "valid":
        n_test_locs = get_n_test_locs(envname)
    else:
        n_test_locs = hp["n_test_locs"]

    # Generate the test locations for this run.
    if test_mode == "valid":
        test_data = [set_test_loc(idx, envname, env, viewer_cfg)
                     for idx in range(n_test_locs)]
    else:
        test_data = [set_valid_loc(envname, env, viewer_cfg)
                     for _ in range(n_test_locs)]

    ### Planning ###
    if hp["run_inverse_model"]:
        run_inverse_model(env, viewer_cfg, test_data, n_test_locs, hp, actor)

    total_dist, total_success, n_test_locs = run_planning_and_inverse_model(
        env, envname, hp, n_test_locs, test_data, viewer_cfg, test_mode,
        model, c_model, actor, savepath)
    return total_dist, total_success, n_test_locs
Beispiel #6
0
    def optimize_D(self):
        """One discriminator update: penalize fake pairs, reward real pairs."""
        utils.set_requires_grad(self.D, True)
        self.optimizer_D.zero_grad()

        # Fake pair: input conditioned on the generated output, detached so
        # no gradients flow back into the generator.
        fake_pair = torch.cat((self.real_in, self.fake_out), 1)
        fake_score = self.D(fake_pair.detach())
        self.loss_D_fake = self.criterionGAN(fake_score, False)

        # Real pair: input conditioned on the ground-truth output.
        real_pair = torch.cat((self.real_in, self.real_out), 1)
        real_score = self.D(real_pair)
        self.loss_D_real = self.criterionGAN(real_score, True)

        # Average the two terms, backpropagate, and take a step.
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

        self.optimizer_D.step()

        # Freeze D again so the subsequent generator update skips it.
        utils.set_requires_grad(self.D, False)
Beispiel #7
0
def main(args):
    """Adversarial discriminative domain adaptation (ADDA).

    Aligns a trainable target feature extractor (MNIST-M) with a frozen
    source extractor (MNIST) via a domain discriminator, then splices the
    adapted extractor back into the source classifier and saves it to
    'trained_models/adda.pt'.
    """
    # Frozen source network: provides the reference feature distribution.
    source_model = Net().to(device)
    source_model.load_state_dict(torch.load(args.MODEL_FILE))
    source_model.eval()
    set_requires_grad(source_model, requires_grad=False)

    clf = source_model
    source_model = source_model.feature_extractor

    # Target network starts from the same weights; only its feature
    # extractor is trained here.
    target_model = Net().to(device)
    target_model.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_model.feature_extractor
    target_clf = clf.classifier

    # Domain discriminator over 320-dim features: source=1, target=0.
    discriminator = nn.Sequential(nn.Linear(320, 50), nn.ReLU(),
                                  nn.Linear(50, 20), nn.ReLU(),
                                  nn.Linear(20, 1)).to(device)

    # Each domain contributes half a batch.
    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist',
                           train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(),
                                              ToTensor()]))
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters())
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        total_accuracy = 0
        target_label_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator (target extractor frozen).
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Label source features 1 and target features 0.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                # Logit > 0 corresponds to predicting "source".
                total_accuracy += ((
                    preds >
                    0).long() == discriminator_y.long()).float().mean().item()

            # Train the target extractor (discriminator frozen): fool the
            # discriminator using flipped ("source") labels.
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, target_labels) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # flipped labels
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze()
                loss = criterion(preds, discriminator_y)

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

                # Track label accuracy through the frozen source classifier.
                target_label_preds = target_clf(target_features)
                target_label_accuracy += (target_label_preds.cpu().max(1)[1] ==
                                          target_labels).float().mean().item()

        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
        target_mean_accuracy = target_label_accuracy / (args.iterations *
                                                        args.k_clf)
        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
            f'discriminator_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}'
        )

        # Create the full target model and save it
        # NOTE(review): this runs at every epoch (inside the loop) -- verify
        # whether a final-only save was intended.
        clf.feature_extractor = target_model
        torch.save(clf.state_dict(), 'trained_models/adda.pt')
def main(args):
    """WDGRL: adapt a MNIST-trained classifier to MNIST-M.

    A small critic estimates the Wasserstein distance between source and
    target feature distributions; critic updates alternate with classifier
    updates that minimize supervised loss plus the weighted distance.  The
    adapted model is saved to 'trained_models/wdgrl.pt' every epoch.
    """
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))

    feature_extractor = clf_model.feature_extractor
    discriminator = clf_model.classifier  # label-classification head

    # Wasserstein critic over 320-dim features.
    critic = nn.Sequential(nn.Linear(320, 50), nn.ReLU(), nn.Linear(50, 20),
                           nn.ReLU(), nn.Linear(20, 1)).to(device)

    # Each domain contributes half a batch; drop_last keeps sizes equal.
    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist',
                           train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(),
                                              ToTensor()]))
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               drop_last=True,
                               shuffle=True,
                               num_workers=0,
                               pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               drop_last=True,
                               shuffle=True,
                               num_workers=0,
                               pin_memory=True)

    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        total_accuracy = 0  # NOTE(review): never updated in this function
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, _) = next(batch_iterator)
            # Train critic (feature extractor frozen)
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)

            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)

            # Features are constant during critic updates, so compute them
            # once without tracking gradients.
            with torch.no_grad():
                h_s = feature_extractor(source_x).data.view(
                    source_x.shape[0], -1)
                h_t = feature_extractor(target_x).data.view(
                    target_x.shape[0], -1)
            for _ in range(args.k_critic):
                # Gradient penalty keeps the critic approximately 1-Lipschitz.
                gp = gradient_penalty(critic, h_s, h_t)

                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()

                # Maximize the distance (minimize its negation) plus penalty.
                critic_cost = -wasserstein_distance + args.gamma * gp

                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()

                total_loss += critic_cost.item()

            # Train classifier (critic frozen)
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(
                    source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(
                    target_x.shape[0], -1)

                # Supervised loss on source labels + weighted distance term.
                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(
                    target_features).mean()

                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
        torch.save(clf_model.state_dict(), 'trained_models/wdgrl.pt')
Beispiel #9
0

if __name__ == '__main__':
    args = parse_args()

    # Read the list of input image paths.  Fixed: use a context manager so
    # the file handle is closed deterministically (the original open()
    # call leaked it).
    with open(args.input_list, 'r') as f:
        img_list = f.read().strip().split('\n')
    dataset = TestSaveDataset(img_list)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            drop_last=False,
                            num_workers=8,
                            pin_memory=True)

    # Frozen deblurring network restored from the resume checkpoint.
    net = SRNDeblurNet().cuda()
    set_requires_grad(net, False)
    last_epoch = load_model(net, args.resume, epoch=args.resume_epoch)

    psnr_list = []
    output_list = []

    tt = time()
    for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Move the image tensors to the GPU; no gradients are needed.
        for k in batch:
            if 'img' in k:
                batch[k] = batch[k].cuda()
                batch[k].requires_grad = False

        # Multi-scale forward pass; only the first output is kept.
        y, _, _ = net(batch['img256'], batch['img128'], batch['img64'])

        y.detach_()
Beispiel #10
0
def main(args):
    """WDGRL domain adaptation for one of three GTA model variants.

    ``args.model`` selects the architecture and checkpoint ('gta',
    'gta-res', or 'gta-vgg').  A Wasserstein critic aligns source and
    target feature distributions while the classifier keeps training on
    labeled source data; the adapted model is saved to the variant-specific
    ``out_file`` every epoch.
    """
    if args.model == 'gta':
        model_file = './trained_models/gta_source.pt'
        out_file = './trained_models/gta_wdgrl.pt'
        out_ftrs = 4375  # critic input dimension for this variant

        clf_model = GTANet().to(device)
        clf_model.load_state_dict(torch.load(model_file))
        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.classifier

    elif args.model == 'gta-res':
        model_file = './trained_models/gta_res_source.pt'
        out_file = './trained_models/gta_res_wdgrl.pt'

        clf_model = GTARes18Net(9, pretrained=False).to(device)
        out_ftrs = clf_model.fc.in_features
        clf_model.load_state_dict(torch.load(model_file))

        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.fc

    elif args.model == 'gta-vgg':
        model_file = './trained_models/gta_vgg_source.pt'
        out_file = './trained_models/gta_vgg_wdgrl.pt'

        clf_model = GTAVGG11Net(9, pretrained=False).to(device)
        out_ftrs = clf_model.classifier[0].in_features  # should be 512 * 7 * 7
        clf_model.load_state_dict(torch.load(model_file))
        # NOTE(review): only this branch freezes the whole model here; the
        # training loop below re-toggles grads -- confirm the asymmetry.
        set_requires_grad(clf_model, False)

        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.classifier

    else:
        raise ValueError(f'Unknown model type {args.model}')

    # Wasserstein critic over the extracted features.
    critic = nn.Sequential(
        nn.Linear(out_ftrs, 64),
        nn.ReLU(),
        nn.Linear(64, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
    ).to(device)

    # Each domain contributes half a batch.
    half_batch = args.batch_size // 2
    target_dataset = ImageFolder('./data',
                                 transform=Compose([
                                     Resize((398, 224)),
                                     RandomCrop(224),
                                     RandomHorizontalFlip(),
                                     ToTensor(),
                                     Normalize([0.485, 0.456, 0.406],
                                               [0.229, 0.224, 0.225]),
                                 ]))
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    source_dataset = ImageFolder('./t_data',
                                 transform=Compose([
                                     RandomCrop(224,
                                                pad_if_needed=True,
                                                padding_mode='reflect'),
                                     RandomHorizontalFlip(),
                                     ToTensor(),
                                     Normalize([0.485, 0.456, 0.406],
                                               [0.229, 0.224, 0.225]),
                                 ]))
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, _) = next(batch_iterator)
            # Train critic (feature extractor frozen)
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)

            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)

            # Features are constant during critic updates; skip grad tracking.
            with torch.no_grad():
                h_s = feature_extractor(source_x).data.view(
                    source_x.shape[0], -1)
                h_t = feature_extractor(target_x).data.view(
                    target_x.shape[0], -1)
            for _ in range(args.k_critic):
                # Gradient penalty keeps the critic approximately 1-Lipschitz.
                gp = gradient_penalty(critic, h_s, h_t)

                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()

                # Maximize the distance (minimize its negation) plus penalty.
                critic_cost = -wasserstein_distance + args.gamma * gp

                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()

                total_loss += critic_cost.item()

            # Train classifier (critic frozen)
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(
                    source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(
                    target_x.shape[0], -1)

                # Supervised loss on source labels + weighted distance term.
                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(
                    target_features).mean()

                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
        torch.save(clf_model.state_dict(), out_file)
    posterior.batch_size = opt.batch_size
    prior.batch_size = opt.batch_size
    opt.g_dim = tmp['opt'].g_dim
    opt.z_dim = tmp['opt'].z_dim
    opt.num_digits = tmp['opt'].num_digits

    # --------- transfer to gpu ------------------------------------
    frame_predictor.cuda()
    posterior.cuda()
    prior.cuda()
    encoder.cuda()
    decoder.cuda()

    nets = [frame_predictor, posterior, prior, encoder, decoder]
    # ---------------- discriminator ----------
    utils.set_requires_grad(nets, False)

#-------- load DSVG
elif opt.svg_name == 'DSVG':
    opt.model_path = 'logs/lp/smmnist-2/drsvg=dcgan64x64-n_past=5-n_future=10-z^p_dim=8-z^c_dim=64-last_frame_skip=False-beta=0.000100-random_seqs=True'
    
    tmp = torch.load('%s/opt.pth' % opt.model_path)
    import models.lstm as lstm_models
    if opt.model == 'dcgan':
        if opt.image_width == 64:
            import models.dcgan_64 as model 
        elif opt.image_width == 128:
            import models.dcgan_128 as model  
    elif opt.model == 'vgg':
        if opt.image_width == 64:
            import models.vgg_64 as model
def train():
    """Train the Unsupervised Attention-Guided GAN on horse2zebra.

    Each batch first updates the generators and attention networks
    (discriminators frozen), then the discriminators using pooled fakes.
    Sample images are saved periodically, checkpoints every
    ``config.save_every`` epochs, and a GIF plus loss plots at the end.
    """

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader('train', config.batch_size)
    val_horse_loader, val_zebra_loader = get_horse2zebra_loader('test', config.batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Image Pool (replays past fakes when training the discriminators) #
    masked_fake_A_pool = ImageMaskPool(config.pool_size)
    masked_fake_B_pool = ImageMaskPool(config.pool_size)

    # Prepare Networks #
    Attn_A = Attention()
    Attn_B = Attention()
    G_A2B = Generator()
    G_B2A = Generator()
    D_A = Discriminator()
    D_B = Discriminator()

    networks = [Attn_A, Attn_B, G_A2B, G_B2A, D_A, D_B]
    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()

    # Optimizers (one for both discriminators, one for generators+attention) #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters()), lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(Attn_A.parameters(), Attn_B.parameters(), G_A2B.parameters(), G_B2A.parameters()), lr=config.lr, betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists (full loss history, used for running averages below) #
    D_A_losses, D_B_losses = [], []
    G_A_losses, G_B_losses = [], []

    # Train #
    print("Training Unsupervised Attention-Guided GAN started with total epoch of {}.".format(config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_A, real_B) in enumerate(zip(train_horse_loader, train_zebra_loader)):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial Loss using real A #
            attn_A = Attn_A(real_A)
            fake_B = G_A2B(real_A)

            # Blend the fake and the input according to the attention map.
            masked_fake_B = fake_B * attn_A + real_A * (1-attn_A)

            # NOTE(review): attention is applied a second time after the
            # blend above -- confirm this double-weighting is intended
            # (the pooling path below repeats the same pattern).
            masked_fake_B *= attn_A
            prob_real_A = D_A(masked_fake_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)

            G_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Adversarial Loss using real B #
            attn_B = Attn_B(real_B)
            fake_A = G_B2A(real_B)

            masked_fake_A = fake_A * attn_B + real_B * (1-attn_B)

            masked_fake_A *= attn_B
            prob_real_B = D_B(masked_fake_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)

            G_loss_B = criterion_Adversarial(prob_real_B, real_labels)

            # Cycle Consistency Loss using real A #
            attn_ABA = Attn_B(masked_fake_B)
            fake_ABA = G_B2A(masked_fake_B)
            masked_fake_ABA = fake_ABA * attn_ABA + masked_fake_B * (1 - attn_ABA)

            # Cycle Consistency Loss using real B #
            attn_BAB = Attn_A(masked_fake_A)
            fake_BAB = G_A2B(masked_fake_A)
            masked_fake_BAB = fake_BAB * attn_BAB + masked_fake_A * (1 - attn_BAB)

            # Cycle Consistency Loss #
            G_cycle_loss_A = config.lambda_cycle * criterion_Cycle(masked_fake_ABA, real_A)
            G_cycle_loss_B = config.lambda_cycle * criterion_Cycle(masked_fake_BAB, real_B)

            # Total Generator Loss #
            G_loss = G_loss_A + G_loss_B + G_cycle_loss_A + G_cycle_loss_B

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            # Train Discriminator A using real B (D_A scores B-domain images) #
            prob_real_A = D_A(real_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_loss_real_A = criterion_Adversarial(prob_real_A, real_labels)

            # Add Pooling (may substitute a historical fake/attention pair) #
            masked_fake_B, attn_A = masked_fake_B_pool.query(masked_fake_B, attn_A)
            masked_fake_B *= attn_A

            # Train Discriminator A using fake B #
            prob_fake_B = D_A(masked_fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            D_loss_fake_A = criterion_Adversarial(prob_fake_B, fake_labels)

            D_loss_A = (D_loss_real_A + D_loss_fake_A).mean()

            # Train Discriminator B using real A (D_B scores A-domain images) #
            prob_real_B = D_B(real_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            D_loss_real_B = criterion_Adversarial(prob_real_B, real_labels)

            # Add Pooling #
            masked_fake_A, attn_B = masked_fake_A_pool.query(masked_fake_A, attn_B)
            masked_fake_A *= attn_B

            # Train Discriminator B using fake A #
            prob_fake_A = D_B(masked_fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_loss_fake_B = criterion_Adversarial(prob_fake_A, fake_labels)

            D_loss_B = (D_loss_real_B + D_loss_fake_B).mean()

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            # Add items to Lists #
            D_A_losses.append(D_loss_A.item())
            D_B_losses.append(D_loss_B.item())
            G_A_losses.append(G_loss_A.item())
            G_B_losses.append(G_loss_B.item())

            ####################
            # Print Statistics #
            ####################

            if (i+1) % config.print_every == 0:
                # Averages are over the entire history so far, not a window.
                print("UAG-GAN | Epoch [{}/{}] | Iteration [{}/{}] | D A Losses {:.4f} | D B Losses {:.4f} | G A Losses {:.4f} | G B Losses {:.4f}".
                      format(epoch+1, config.num_epochs, i+1, total_batch, np.average(D_A_losses), np.average(D_B_losses), np.average(G_A_losses), np.average(G_B_losses)))

                # Save Sample Images #
                save_samples(val_horse_loader, val_zebra_loader, G_A2B, G_B2A, Attn_A, Attn_B, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(G_A2B.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Generator_A2B_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(G_B2A.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Generator_B2A_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(Attn_A.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Attention_A_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(Attn_B.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Attention_B_Epoch_{}.pkl'.format(epoch+1)))

    # Make a GIF file #
    make_gifs_train("UAG-GAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_A_losses, D_B_losses, G_A_losses, G_B_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
Beispiel #13
0
                              lr=args.lr,
                              weight_decay=args.weight_decay)
 sched = lr_scheduler.StepLR(optimizer, step_size=50)
 ''' setup tensorboard '''
 writer = SummaryWriter(os.path.join(args.save_dir, 'train_info'))
 ''' load train and val features '''
 print('===> load train and val features and labels ...')
 train_features, train_label, valid_features, valid_labels = utils.load_features(
     args)
 ''' train model '''
 print('===> start training ...')
 iters = 0
 best_acc = 0
 for epoch in range(1, args.epoch + 1):
     FC.train()
     utils.set_requires_grad(FC, True)
     total_length = train_features.shape[0]
     perm_index = torch.randperm(total_length)
     train_X_sfl = train_features[perm_index]
     train_y_sfl = train_label[perm_index]
     # construct training batch
     for index in range(0, total_length, args.train_batch):
         train_info = 'Epoch: [{0}][{1}/{2}]'.format(
             epoch, index + 1, len(train_loader))
         iters += 1
         optimizer.zero_grad()
         if index + args.train_batch > total_length:
             #break
             input_X = train_X_sfl[index:]
             input_y = train_y_sfl[index:]
         else:
Beispiel #14
0
                              lr=args.lr,
                              weight_decay=args.weight_decay)
 sched = lr_scheduler.StepLR(optimizer, step_size=50)
 ''' setup tensorboard '''
 writer = SummaryWriter(os.path.join(args.save_dir, 'train_info'))
 ''' load train and val features '''
 print('===> load train and val features and labels ...')
 train_features, train_labels, valid_features, valid_labels = utils.load_features(
     args)
 ''' train model '''
 print('===> start training ...')
 iters = 0
 best_acc = 0
 for epoch in range(1, args.epoch + 1):
     BiRNN.train()
     utils.set_requires_grad(BiRNN, True)
     total_length = train_features.shape[0]
     # shuffle
     perm_index = np.random.permutation(len(train_features))
     train_X_sfl = [train_features[i] for i in perm_index]
     train_y_sfl = np.array(train_labels)[perm_index]
     # construct training batch
     for index in range(0, total_length, args.train_batch):
         train_info = 'Epoch: [{0}][{1}/{2}]'.format(
             epoch, index + 1, len(train_loader))
         iters += 1
         optimizer.zero_grad()
         if index + args.train_batch > total_length:
             input_X = train_X_sfl[index:]
             input_y = train_y_sfl[index:]
         else:
def train():
    """Train a CycleGAN on the horse2zebra dataset.

    Trains two generators (G_A2B, G_B2A) and two discriminators (D_A, D_B)
    with LSGAN-style adversarial loss (MSE), L1 cycle-consistency loss and
    L1 identity loss. Saves sample images, periodic generator checkpoints,
    a training GIF and a loss plot under the paths configured in `config`.
    """
    # Fix seed for reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Ensure samples, weights and plots directories exist #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare data loaders (unpaired: one loader per domain) #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader(
        purpose='train', batch_size=config.batch_size)
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader(
        purpose='test', batch_size=config.val_batch_size)
    # zip() below stops at the shorter loader, so this is the batch count
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Prepare networks #
    D_A = Discriminator()
    D_B = Discriminator()
    G_A2B = Generator()
    G_B2A = Generator()

    networks = [D_A, D_B, G_A2B, G_B2A]

    for network in networks:
        network.to(device)

    # Loss functions: MSE adversarial (LSGAN), L1 cycle and identity #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()
    criterion_Identity = nn.L1Loss()

    # Optimizers (both generators share one optimizer via chain()) #
    D_A_optim = torch.optim.Adam(D_A.parameters(),
                                 lr=config.lr,
                                 betas=(0.5, 0.999))
    D_B_optim = torch.optim.Adam(D_B.parameters(),
                                 lr=config.lr,
                                 betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_A_optim_scheduler = get_lr_scheduler(D_A_optim)
    D_B_optim_scheduler = get_lr_scheduler(D_B_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Running loss histories (appended every iteration, used for printing
    # running averages and for the final loss plot) #
    D_losses_A, D_losses_B, G_losses = [], [], []

    # Training #
    print("Training CycleGAN started with total epoch of {}.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):
        for i, (horse,
                zebra) in enumerate(zip(train_horse_loader,
                                        train_zebra_loader)):

            # Data preparation: domain A = horse, domain B = zebra #
            real_A = horse.to(device)
            real_B = zebra.to(device)

            # Initialize optimizers #
            G_optim.zero_grad()
            D_A_optim.zero_grad()
            D_B_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            # Discriminators need no gradients while optimizing generators
            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial loss: generators try to make D output "real" (1) #
            fake_A = G_B2A(real_B)
            prob_fake_A = D_A(fake_A)
            real_labels = torch.ones(prob_fake_A.size()).to(device)
            G_mse_loss_B2A = criterion_Adversarial(prob_fake_A, real_labels)

            fake_B = G_A2B(real_A)
            prob_fake_B = D_B(fake_B)
            real_labels = torch.ones(prob_fake_B.size()).to(device)
            G_mse_loss_A2B = criterion_Adversarial(prob_fake_B, real_labels)

            # Identity loss: G_B2A(A) should reproduce A (and vice versa) #
            identity_A = G_B2A(real_A)
            G_identity_loss_A = config.lambda_identity * criterion_Identity(
                identity_A, real_A)

            identity_B = G_A2B(real_B)
            G_identity_loss_B = config.lambda_identity * criterion_Identity(
                identity_B, real_B)

            # Cycle loss: A -> B -> A and B -> A -> B reconstructions #
            reconstructed_A = G_B2A(fake_B)
            G_cycle_loss_ABA = config.lambda_cycle * criterion_Cycle(
                reconstructed_A, real_A)

            reconstructed_B = G_A2B(fake_A)
            G_cycle_loss_BAB = config.lambda_cycle * criterion_Cycle(
                reconstructed_B, real_B)

            # Total generator loss #
            G_loss = G_mse_loss_B2A + G_mse_loss_A2B + G_identity_loss_A + G_identity_loss_B + G_cycle_loss_ABA + G_cycle_loss_BAB

            # Back propagation and update.
            # retain_graph is NOT needed: the discriminator steps below only
            # use fake_A/fake_B through .detach(), so the generator graph is
            # never reused. Retaining it only wasted GPU memory.
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            ## Train Discriminator A ##
            # Real loss: D_A should output 1 on real domain-A images #
            prob_real_A = D_A(real_A)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_real_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Fake loss: D_A should output 0 on generated images.
            # .detach() cuts gradient flow back into the generator.
            prob_fake_A = D_A(fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_fake_loss_A = criterion_Adversarial(prob_fake_A, fake_labels)

            # Total discriminator A loss.
            # NOTE(review): scaling by config.lambda_identity here looks
            # unintentional (standard CycleGAN halves the D loss with 0.5)
            # — kept as-is to preserve behavior; confirm with the author.
            D_loss_A = config.lambda_identity * (D_real_loss_A +
                                                 D_fake_loss_A).mean()

            # Back propagation and update (graph is independent of D_B's) #
            D_loss_A.backward()
            D_A_optim.step()

            ## Train Discriminator B ##
            # Real loss #
            prob_real_B = D_B(real_B)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            loss_real_B = criterion_Adversarial(prob_real_B, real_labels)

            # Fake loss #
            prob_fake_B = D_B(fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            loss_fake_B = criterion_Adversarial(prob_fake_B, fake_labels)

            # Total discriminator B loss (same scaling caveat as D_loss_A) #
            D_loss_B = config.lambda_identity * (loss_real_B +
                                                 loss_fake_B).mean()

            # Back propagation and update #
            D_loss_B.backward()
            D_B_optim.step()

            # Record per-iteration losses #
            D_losses_A.append(D_loss_A.item())
            D_losses_B.append(D_loss_B.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                # Averages are over the full history, not just this epoch #
                print(
                    "CycleGAN | Epoch [{}/{}] | Iterations [{}/{}] | D_A Loss {:.4f} | D_B Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses_A), np.average(D_losses_B),
                            np.average(G_losses)))

                # Save sample translations from the test loaders #
                sample_images(test_horse_loader, test_zebra_loader, G_A2B,
                              G_B2A, epoch, config.samples_path)

        # Adjust learning rate once per epoch #
        D_A_optim_scheduler.step()
        D_B_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save generator weights every `save_every` epochs #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G_A2B.state_dict(),
                os.path.join(
                    config.weights_path,
                    'CycleGAN_Generator_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(
                    config.weights_path,
                    'CycleGAN_Generator_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file from the saved samples #
    make_gifs_train("CycleGAN", config.samples_path)

    # Plot losses #
    plot_losses(D_losses_A, D_losses_B, G_losses, config.num_epochs,
                config.plots_path)

    print("Training finished.")
def main(args):
    """Run ADDA domain adaptation for a GTA-trained model.

    Loads a frozen source model (feature extractor + classifier), clones it
    into a trainable target feature extractor, and adversarially trains the
    target extractor against a domain discriminator so that target features
    become indistinguishable from source features. The adapted classifier is
    saved to `out_file` after each epoch.

    Supported `args.model` values: 'gta', 'gta-res', 'gta-vgg'.
    """
    if args.model == 'gta':
        model_file = './trained_models/gta_source.pt'
        out_file = './trained_models/gta_adda.pt'
        # Flattened feature dimension fed to the discriminator
        out_ftrs = 4375

        # Frozen source model: provides reference features and the classifier
        model = GTANet().to(device)
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)
        source_model = model.feature_extractor
        clf = model

        # Trainable target model, initialized from the same weights
        model_2 = GTANet().to(device)
        model_2.load_state_dict(torch.load(model_file))
        target_model = model_2.feature_extractor

    elif args.model == 'gta-res':
        model_file = './trained_models/gta_res_source.pt'
        out_file = './trained_models/gta_res_adda.pt'

        model = GTARes18Net(9, pretrained=False).to(device)
        out_ftrs = model.fc.in_features
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)

        source_model = model.feature_extractor
        clf = model

        model_2 = GTARes18Net(9, pretrained=False).to(device)
        model_2.load_state_dict(torch.load(model_file))
        target_model = model_2.feature_extractor

    elif args.model == 'gta-vgg':
        model_file = './trained_models/gta_vgg_source.pt'
        out_file = './trained_models/gta_vgg_adda.pt'

        model = GTAVGG11Net(9, pretrained=False).to(device)
        out_ftrs = model.classifier[0].in_features  # should be 512 * 7 * 7
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)

        # VGG has no single feature_extractor attribute, so wrap the
        # conv trunk + pooling + flatten in a closure.
        def source_model(x):
            x = model.features(x)
            x = model.avgpool(x)
            x = torch.flatten(x, 1)
            return x

        clf = model

        model_2 = GTAVGG11Net(9, pretrained=False).to(device)
        model_2.load_state_dict(torch.load(model_file))

        def target_model(x):
            x = model_2.features(x)
            x = model_2.avgpool(x)
            x = torch.flatten(x, 1)
            return x

    else:
        raise ValueError(f'Unknown model type {args.model}')

    # Domain discriminator: features -> single logit (source=1, target=0) #
    discriminator = nn.Sequential(
        nn.Linear(out_ftrs, 64),
        nn.ReLU(),
        nn.Linear(64, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
    ).to(device)

    # Each discriminator batch is half source + half target
    half_batch = args.batch_size // 2
    target_dataset = ImageFolder('./data',
                                 transform=Compose([
                                     Resize((398, 224)),
                                     RandomCrop(224),
                                     RandomHorizontalFlip(),
                                     ToTensor(),
                                     Normalize([0.485, 0.456, 0.406],
                                               [0.229, 0.224, 0.225]),
                                 ]))
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    source_dataset = ImageFolder('./t_data',
                                 transform=Compose([
                                     RandomCrop(224,
                                                pad_if_needed=True,
                                                padding_mode='reflect'),
                                     RandomHorizontalFlip(),
                                     ToTensor(),
                                     Normalize([0.485, 0.456, 0.406],
                                               [0.229, 0.224, 0.225]),
                                 ]))
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(model_2.parameters())
    # Discriminator emits raw logits, so use the with-logits loss
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        total_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator (target extractor frozen) #
            set_requires_grad(model_2, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Label source features 1, target features 0
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                # preds are logits, so threshold at 0
                total_accuracy += ((preds > 0).long() == discriminator_y.long()
                                  ).float().mean().item()

            # Train target extractor (discriminator frozen), flipped labels #
            set_requires_grad(model_2, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Flipped labels: target features should be classified "source"
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze()
                loss = criterion(preds, discriminator_y)

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
        tqdm.write(f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
                   f'discriminator_accuracy={mean_accuracy:.4f}')

        # Splice the adapted feature extractor into the classifier and save #
        if args.model == 'gta':
            clf.feature_extractor = target_model
        elif args.model == 'gta-res':
            clf.conv1 = model_2.conv1
            clf.bn1 = model_2.bn1
            clf.relu = model_2.relu
            clf.maxpool = model_2.maxpool

            clf.layer1 = model_2.layer1
            clf.layer2 = model_2.layer2
            clf.layer3 = model_2.layer3
            clf.layer4 = model_2.layer4

            clf.avgpool = model_2.avgpool
        elif args.model == 'gta-vgg':
            # Bug fix: the vgg branch was previously missing here, so the
            # saved checkpoint contained the ORIGINAL (un-adapted) weights.
            clf.features = model_2.features
            clf.avgpool = model_2.avgpool

        torch.save(clf.state_dict(), out_file)
Beispiel #17
0
def main():
    """Train a CycleGAN on the dataset at `opt.rootdir`.

    Builds two generators (netG_A: A->B, netG_B: B->A) and two
    discriminators, trains them with GAN + cycle-consistency (+ optional
    identity) losses, saves per-epoch checkpoints and loss histories under
    'model/', and periodic numbered snapshots every 10 epochs.
    """
    # Get training options
    opt = get_opt()

    # Define the networks
    # netG_A: used to transfer image from domain A to domain B
    # netG_B: used to transfer image from domain B to domain A
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    netG_B = networks.Generator(opt.output_nc, opt.input_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    if opt.u_net:
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)
        netG_B = networks.U_net(opt.output_nc, opt.input_nc, opt.ngf)

    # netD_A: used to test whether an image is from domain B
    # netD_B: used to test whether an image is from domain A
    netD_A = networks.Discriminator(opt.input_nc, opt.ndf)
    netD_B = networks.Discriminator(opt.output_nc, opt.ndf)

    # Initialize the networks
    if opt.cuda:
        netG_A.cuda()
        netG_B.cuda()
        netD_A.cuda()
        netD_B.cuda()
    utils.init_weight(netG_A)
    utils.init_weight(netG_B)
    utils.init_weight(netD_A)
    utils.init_weight(netD_B)

    # Optionally resume from pretrained weights (overwrites init above)
    if opt.pretrained:
        netG_A.load_state_dict(torch.load('pretrained/netG_A.pth'))
        netG_B.load_state_dict(torch.load('pretrained/netG_B.pth'))
        netD_A.load_state_dict(torch.load('pretrained/netD_A.pth'))
        netD_B.load_state_dict(torch.load('pretrained/netD_B.pth'))

    # Define the loss functions
    criterion_GAN = utils.GANLoss()
    if opt.cuda:
        criterion_GAN.cuda()

    criterion_cycle = torch.nn.L1Loss()
    # Alternatively, can try MSE cycle consistency loss
    #criterion_cycle = torch.nn.MSELoss()
    criterion_identity = torch.nn.L1Loss()

    # Define the optimizers (both generators share one optimizer)
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A.parameters(),
                                                   netG_B.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, 0.999))

    # Create learning rate schedulers (linear decay after opt.n_epochs)
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)

    # Pre-allocated input buffers (moved to GPU once when opt.cuda)
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batch_size, opt.input_nc, opt.sizeh, opt.sizew)
    input_B = Tensor(opt.batch_size, opt.output_nc, opt.sizeh, opt.sizew)

    # Define two image pools to store generated images
    fake_A_pool = utils.ImagePool()
    fake_B_pool = utils.ImagePool()

    # Define the transform, and load the data
    transform = transforms.Compose([
        transforms.Resize((opt.sizeh, opt.sizew)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])
    dataloader = DataLoader(ImageDataset(opt.rootdir,
                                         transform=transform,
                                         mode='train'),
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.n_cpu)

    # numpy arrays to store the mean loss of each epoch
    loss_G_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_A_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_B_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)

    # Training
    for epoch in range(opt.epoch, opt.n_epochs + opt.n_epochs_decay):
        start = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " start time :", start)
        # Empty list to store the loss of each mini-batch
        loss_G_list = []
        loss_D_A_list = []
        loss_D_B_list = []

        for i, batch in enumerate(dataloader):
            if i % 50 == 1:
                print("current step: ", i)
                current = time.strftime("%H:%M:%S")
                print("current time :", current)
                print("last loss G:", loss_G_list[-1], "last loss D_A",
                      loss_D_A_list[-1], "last loss D_B", loss_D_B_list[-1])
            # Bug fix: the last batch of an epoch may be smaller than
            # opt.batch_size (DataLoader default drop_last=False), and
            # copy_ into the full buffer would raise a size-mismatch error.
            # Slice the buffer to the actual batch size instead.
            real_A = input_A[:batch['A'].size(0)].copy_(batch['A'])
            real_B = input_B[:batch['B'].size(0)].copy_(batch['B'])

            # Train the generator
            optimizer_G.zero_grad()

            # Compute fake images and reconstructed images
            fake_B = netG_A(real_A)
            fake_A = netG_B(real_B)

            if opt.identity_loss != 0:
                same_B = netG_A(real_B)
                same_A = netG_B(real_A)

            # discriminators require no gradients when optimizing generators
            utils.set_requires_grad([netD_A, netD_B], False)

            # Identity loss
            if opt.identity_loss != 0:
                loss_identity_A = criterion_identity(
                    same_A, real_A) * opt.identity_loss
                loss_identity_B = criterion_identity(
                    same_B, real_B) * opt.identity_loss

            # GAN loss: generators should fool the discriminators
            prediction_fake_B = netD_B(fake_B)
            loss_gan_B = criterion_GAN(prediction_fake_B, True)
            prediction_fake_A = netD_A(fake_A)
            loss_gan_A = criterion_GAN(prediction_fake_A, True)

            # Cycle consistent loss
            recA = netG_B(fake_B)
            recB = netG_A(fake_A)
            loss_cycle_A = criterion_cycle(recA, real_A) * opt.cycle_loss
            loss_cycle_B = criterion_cycle(recB, real_B) * opt.cycle_loss

            # total loss without the identity loss
            loss_G = loss_gan_B + loss_gan_A + loss_cycle_A + loss_cycle_B

            if opt.identity_loss != 0:
                loss_G += loss_identity_A + loss_identity_B

            loss_G_list.append(loss_G.item())
            loss_G.backward()
            optimizer_G.step()

            # Train the discriminator
            utils.set_requires_grad([netD_A, netD_B], True)

            # Train the discriminator D_A
            optimizer_D_A.zero_grad()
            # real images
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, True)

            # fake images, drawn from the history pool; detach so no
            # gradients flow into the generator
            fake_A = fake_A_pool.query(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)

            # total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A_list.append(loss_D_A.item())
            loss_D_A.backward()
            optimizer_D_A.step()

            # Train the discriminator D_B
            optimizer_D_B.zero_grad()
            # real images
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, True)

            # fake images
            fake_B = fake_B_pool.query(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)

            # total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B.item())
            loss_D_B.backward()
            optimizer_D_B.step()

        # Update the learning rate
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints
        torch.save(netG_A.state_dict(), 'model/netG_A.pth')
        torch.save(netG_B.state_dict(), 'model/netG_B.pth')
        torch.save(netD_A.state_dict(), 'model/netD_A.pth')
        torch.save(netD_B.state_dict(), 'model/netD_B.pth')

        # Save other checkpoint information
        checkpoint = {
            'epoch': epoch,
            'optimizer_G': optimizer_G.state_dict(),
            'optimizer_D_A': optimizer_D_A.state_dict(),
            'optimizer_D_B': optimizer_D_B.state_dict(),
            'lr_scheduler_G': lr_scheduler_G.state_dict(),
            'lr_scheduler_D_A': lr_scheduler_D_A.state_dict(),
            'lr_scheduler_D_B': lr_scheduler_D_B.state_dict()
        }
        torch.save(checkpoint, 'model/checkpoint.pth')

        # Update the numpy arrays that record the loss
        loss_G_array[epoch] = sum(loss_G_list) / len(loss_G_list)
        loss_D_A_array[epoch] = sum(loss_D_A_list) / len(loss_D_A_list)
        loss_D_B_array[epoch] = sum(loss_D_B_list) / len(loss_D_B_list)
        np.savetxt('model/loss_G.txt', loss_G_array)
        np.savetxt('model/loss_D_A.txt', loss_D_A_array)
        # NOTE(review): lowercase 'b' is inconsistent with loss_D_A.txt but
        # kept as-is — downstream tooling may read this exact path.
        np.savetxt('model/loss_D_b.txt', loss_D_B_array)

        # Periodic numbered snapshots every 10 epochs
        if epoch % 10 == 9:
            torch.save(netG_A.state_dict(),
                       'model/netG_A' + str(epoch) + '.pth')
            torch.save(netG_B.state_dict(),
                       'model/netG_B' + str(epoch) + '.pth')
            torch.save(netD_A.state_dict(),
                       'model/netD_A' + str(epoch) + '.pth')
            torch.save(netD_B.state_dict(),
                       'model/netD_B' + str(epoch) + '.pth')

        end = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " end time :", end)
        print("G loss :", loss_G_array[epoch], "D_A loss :",
              loss_D_A_array[epoch], "D_B loss :", loss_D_B_array[epoch])
Beispiel #18
0
def main(args):
    """Run ADDA domain adaptation for one target subject.

    Loads a frozen source model, clones its feature extractor as the
    trainable target extractor, and adversarially trains it against a
    domain discriminator. After each epoch the adapted classifier is
    evaluated on the target subject; the best model is saved and the
    per-epoch accuracies are dumped to a JSON file.
    """
    # Frozen source model: provides reference features and the classifier
    source_model = Net().to(device)
    source_model.load_state_dict(torch.load(args.MODEL_FILE))
    source_model.eval()
    set_requires_grad(source_model, requires_grad=False)

    clf = source_model
    source_model = source_model.feature_extractor

    # Trainable target extractor, initialized from the same weights
    target_model = Net().to(device)
    target_model.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_model.feature_extractor

    classifier = clf.classifier

    # Domain discriminator: features -> probability (source=1, target=0).
    # Note the final Sigmoid: the network outputs probabilities, not logits.
    discriminator = nn.Sequential(nn.Linear(EXTRACTED_FEATURE_DIM, 64),
                                  nn.ReLU(), nn.BatchNorm1d(64),
                                  nn.Linear(64, 1), nn.Sigmoid()).to(device)

    #half_batch = args.batch_size // 2

    batch_size = args.batch_size

    # X_source, y_source = preprocess_train()
    X_source, y_source = preprocess_train_single(1)
    source_dataset = torch.utils.data.TensorDataset(X_source, y_source)

    source_loader = DataLoader(source_dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=1,
                               pin_memory=True)

    X_target, y_target = preprocess_test(args.person)
    target_dataset = torch.utils.data.TensorDataset(X_target, y_target)
    target_loader = DataLoader(target_dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=1,
                               pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters(), lr=3e-6)
    # Bug fix: the discriminator already ends in nn.Sigmoid(), so using
    # BCEWithLogitsLoss applied sigmoid TWICE (wrong loss scale, vanishing
    # gradients). BCELoss matches the probability outputs and the
    # `preds >= 0.5` accuracy thresholds below.
    criterion = nn.BCELoss()
    criterion_class = nn.CrossEntropyLoss()

    # Baseline accuracy of the un-adapted source model on the target subject
    best_tar_acc = test(args, clf)
    final_accs = []

    for epoch in range(1, args.epochs + 1):
        # Re-wrap the datasets with shuffle=True for training epochs
        source_loader = DataLoader(source_loader.dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
        target_loader = DataLoader(target_loader.dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        adv_loss = 0
        total_accuracy = 0
        second_acc = 0
        total_class_loss = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator (target extractor frozen) #
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            discriminator.train()
            for _ in range(args.k_disc):
                (source_x, source_y), (target_x, _) = next(batch_iterator)
                source_y = source_y.to(device).view(-1)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Label source features 1, target features 0
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                total_accuracy += ((preds >= 0.5).long() == discriminator_y.
                                   long()).float().mean().item()

            # Train feature extractor (discriminator frozen), flipped labels #
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            target_model.train()
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # NOTE(review): source_x/source_y leak in from the last
                # discriminator mini-batch above; only used for the
                # (currently disabled) classification loss diagnostic.
                source_features = target_model(source_x).view(
                    source_x.shape[0], -1)
                source_pred = classifier(source_features)  # (batch_size, 4)

                # flipped labels: target should be classified as "source"
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze()
                second_acc += ((preds >= 0.5).long() == discriminator_y.long()
                               ).float().mean().item()

                loss_adv = criterion(preds, discriminator_y)
                adv_loss += loss_adv.item()
                loss_class = criterion_class(source_pred, source_y)
                total_class_loss += loss_class.item()
                loss = loss_adv  #+ 0.001*loss_class

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

        # Per-epoch diagnostics, averaged over all updates of the epoch
        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
        total_class_loss = total_class_loss / (args.iterations * args.k_clf)
        dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
        sec_acc = second_acc / (args.iterations * args.k_clf)
        clf.feature_extractor = target_model
        tar_accuarcy = test(args, clf)
        final_accs.append(tar_accuarcy)
        if tar_accuarcy > best_tar_acc:
            best_tar_acc = tar_accuarcy
            torch.save(clf.state_dict(), 'trained_models/adda.pt')

        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
            f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuarcy:.4f}, best_accuracy = {best_tar_acc:.4f}, '
            f'sec_acc = {sec_acc:.4f}, total_class_loss: {total_class_loss:.4f}'
        )

        # Create the full target model and save it
        clf.feature_extractor = target_model
        #torch.save(clf.state_dict(), 'trained_models/adda.pt')
    # Dump per-epoch target accuracies for later analysis
    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)
Beispiel #19
0
    def _do_epoch(self):
        """Run one training epoch over all source-domain loaders.

        Each iteration performs two updates with the shared ``self.optimizer``:
          1. a "self" classifier update on detached features (only
             ``c_model`` can receive gradient — the main model is frozen),
          2. a joint update of the main model, the domain discriminator and
             the auxiliary classifiers on the combined loss.
        Afterwards the main model is evaluated on every loader in
        ``self.test_loaders`` and results are appended to ``self.log_file``.
        """

        set_mode(self.main_model, "train")
        set_mode(self.dis_model, "train")
        set_mode(self.c_model, "train")
        set_mode(self.cp_model, "train")

        # Presumably configures each auxiliary head's loss weight (lambda)
        # from the CLI args — confirm against set_lambda's definition.
        set_lambda([self.dis_model], [self.args.lbd_d])
        set_lambda([self.c_model, self.cp_model],
                   [self.args.lbd_c, self.args.lbd_cp])
        loader_iter_list = []
        loader_size_list = []

        # During warmup, down-weight both the auxiliary and the main loss.
        if self.current_epoch < self.args.warmup_step:
            aux_weight = self.args.warmup_weight
            main_weight = self.args.warmup_weight
        else:
            aux_weight = 1
            main_weight = 1

        for loader in self.source_loader_list:
            loader_iter_list.append(enumerate(loader))
            loader_size_list.append(len(loader))

        # Run as many steps as the longest loader; shorter loaders are
        # restarted on StopIteration so every step sees every domain.
        for it in range(max(loader_size_list)):
            data = []
            labels = []
            domains = []
            for idx, iter_ in zip(range(self.num_domains), loader_iter_list):
                try:
                    item = iter_.__next__()
                except StopIteration:
                    # Loader exhausted: restart it and take its first batch.
                    loader_iter_list[idx] = enumerate(
                        self.source_loader_list[idx])
                    item = loader_iter_list[idx].__next__()
                data.append(item[1][0])
                labels.append(item[1][1])
                # One domain-id label per sample in this domain's batch.
                domains.append(torch.ones(labels[-1].size(0)).long() * idx)
            data = torch.cat(data, dim=0).to(self.device)
            labels = torch.cat(labels, dim=0).to(self.device)
            domains = torch.cat(domains, dim=0).to(self.device)

            # --- Step 1: update c_model alone. The main model is frozen and
            # the features are detached, so only c_model gets gradient.
            set_requires_grad(self.main_model, False)
            set_requires_grad(self.c_model, True)
            _, feature = self.main_model(data)
            c_loss_self = self._compute_cls_loss(
                self.c_model, feature.detach(), labels, domains,
                mode="self") * aux_weight
            self.optimizer.zero_grad()
            c_loss_self.backward()
            self.optimizer.step()

            # --- Step 2: joint update of all sub-models on the full loss.
            set_requires_grad(
                [self.main_model, self.dis_model, self.c_model, self.cp_model],
                True)
            class_logit, feature = self.main_model(data)
            main_loss = F.cross_entropy(class_logit, labels) * main_weight

            dis_loss = self._compute_dis_loss(feature, domains) * aux_weight

            # c_model is frozen for the "others" loss: only the feature
            # extractor receives gradient from it.
            set_requires_grad(self.c_model, False)
            c_loss_others = self._compute_cls_loss(
                self.c_model, feature, labels, domains,
                mode="others") * aux_weight

            cp_loss = self._compute_cls_loss(
                self.cp_model, feature, labels, domains,
                mode="self") * aux_weight

            loss = dis_loss + c_loss_others + cp_loss + main_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Reporting only: fold the step-1 loss into the printed total
            # (the optimizer steps above are already done).
            loss += c_loss_self

            message = "epoch %d iter %d: all %.6f main %.6f dis %.6f c_self %.6f c_others %.6f cp %.6f\n" % (
                self.current_epoch, it, loss.data, main_loss.data,
                dis_loss.data, c_loss_self.data, c_loss_others.data,
                cp_loss.data)
            with open(self.log_file, "a") as fid:
                fid.write(message)
            print(message)

            # Drop loss tensors promptly to release their graphs/memory.
            del loss, main_loss, dis_loss, c_loss_self, c_loss_others, cp_loss

        # --- Evaluation on every registered split, without gradients.
        self.main_model.eval()
        with torch.no_grad():
            with open(self.log_file, "a") as fid:
                for phase, loader in self.test_loaders.items():
                    class_correct, all_domains = self.do_test(loader)
                    class_correct = class_correct.float()
                    class_acc = class_correct.mean() * 100.0
                    self.results[phase][self.current_epoch] = class_acc
                    if phase == "val":
                        message = "epoch %d: val_all_acc %.5f" % (
                            self.current_epoch, class_acc)
                        # Also report per-source-domain validation accuracy.
                        for i in range(self.num_domains):
                            cc_i = class_correct[all_domains == i]
                            ca_i = cc_i.mean() * 100.0
                            message += " val_%s_acc %.5f" % (
                                self.args.source[i], ca_i)
                        message += "\n"
                        fid.write(message)
                        print(message)
                    elif phase == "test":
                        message = "epoch %d: test_acc %.5f\n" % (
                            self.current_epoch, class_acc)
                        fid.write(message)
                        print(message)
Beispiel #20
0
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--config-path", type=str, required=True)
    parser.add_argument("--save-dir-path", type=str, default=".")
    parser.add_argument("--tol", type=float, default=0)
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--distance", default="l1", choices=["l1", "l2"])
    args = parser.parse_args()

    cfg, G, lidar, device = utils.setup(
        args.model_path,
        args.config_path,
        ema=True,
        fix_noise=True,
    )

    utils.set_requires_grad(G, False)
    G = DP(G)

    # hyperparameters
    num_step = 1000
    perturb_latent = True
    noise_ratio = 0.75
    noise_sigma = 1.0
    lr_rampup_ratio = 0.05
    lr_rampdown_ratio = 0.25

    # prepare reference
    dataset = define_dataset(cfg.dataset, phase="test")
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
Beispiel #21
0
def main():
    """Train a conditional GAN (pix2pix-style).

    Generator ``netG_A`` maps domain A to domain B; discriminator ``netD_B``
    judges concatenated (B-image, A-image) pairs. After every epoch the
    weights, optimizer/scheduler state, and per-epoch mean losses are saved
    under ``model/``.
    """
    # Get training options
    opt = get_opt()

    device = torch.device("cuda") if opt.cuda else torch.device("cpu")

    # Define the networks
    # netG_A: used to transfer image from domain A to domain B
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf, opt.n_res, opt.dropout)
    if opt.u_net:
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)

    # netD_B: conditional discriminator — input channels cover a concatenated
    # (B, A) image pair, so it judges pairs rather than single images
    netD_B = networks.Discriminator(opt.input_nc + opt.output_nc, opt.ndf)

    # Initialize the networks
    if opt.cuda:
        netG_A.cuda()
        netD_B.cuda()
    utils.init_weight(netG_A)
    utils.init_weight(netD_B)

    # Optionally resume from pretrained weights (overwrites the init above)
    if opt.pretrained:
        netG_A.load_state_dict(torch.load('pretrained/netG_A.pth'))
        netD_B.load_state_dict(torch.load('pretrained/netD_B.pth'))


    # Define the loss functions
    criterion_GAN = utils.GANLoss()
    if opt.cuda:
        criterion_GAN.cuda()

    criterion_l1 = torch.nn.L1Loss()

    # Define the optimizers
    optimizer_G = torch.optim.Adam(netG_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    # Create learning rate schedulers (linear decay after opt.n_epochs)
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda = utils.Lambda_rule(opt.epoch, opt.n_epochs, opt.n_epochs_decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda = utils.Lambda_rule(opt.epoch, opt.n_epochs, opt.n_epochs_decay).step)


    # Define the transform, and load the data
    transform = transforms.Compose([transforms.Resize((opt.sizeh, opt.sizew)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,))])
    dataloader = DataLoader(PairedImage(opt.rootdir, transform = transform, mode = 'train'), batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)

    # numpy arrays to store the loss of epoch (indexed by absolute epoch,
    # so resuming via opt.epoch fills the right slots)
    loss_G_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_B_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)

    # Training
    for epoch in range(opt.epoch, opt.n_epochs + opt.n_epochs_decay):
        start = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " start time :", start)
        # Empty list to store the loss of each mini-batch
        loss_G_list = []
        loss_D_B_list = []

        for i, batch in enumerate(dataloader):
            # Progress log every 20 steps; i % 20 == 1 guarantees at least
            # one batch has run, so the loss lists are non-empty here.
            if i % 20 == 1:
                print("current step: ", i)
                current = time.strftime("%H:%M:%S")
                print("current time :", current)
                print("last loss G_A:", loss_G_list[-1],  "last loss D_B:", loss_D_B_list[-1])

            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            # Train the generator
            utils.set_requires_grad([netG_A], True)
            optimizer_G.zero_grad()

            # Compute fake images and reconstructed images
            fake_B = netG_A(real_A)


            # discriminators require no gradients when optimizing generators
            utils.set_requires_grad([netD_B], False)


            # GAN loss (conditional: discriminator sees the fake paired with
            # its source image)
            prediction_fake_B = netD_B(torch.cat((fake_B, real_A), dim=1))
            loss_gan = criterion_GAN(prediction_fake_B, True)

            # L1 loss between target and generated image, weighted by
            # opt.l1_loss (L1 is symmetric, so the argument order is harmless)
            loss_l1 = criterion_l1(real_B, fake_B) * opt.l1_loss

            # total loss without the identity loss
            loss_G = loss_gan + loss_l1

            loss_G_list.append(loss_G.item())
            loss_G.backward()
            optimizer_G.step()

            # Train the discriminator
            utils.set_requires_grad([netG_A], False)
            utils.set_requires_grad([netD_B], True)

            # Train the discriminator D_B
            optimizer_D_B.zero_grad()
            # real images
            pred_real = netD_B(torch.cat((real_B, real_A), dim=1))
            loss_D_real = criterion_GAN(pred_real, True)

            # fake images
            # NOTE(review): fake_B is recomputed under frozen netG_A instead
            # of reusing fake_B.detach(); with dropout enabled this re-samples
            # the generator output — presumably intentional, confirm.
            fake_B = netG_A(real_A)
            pred_fake = netD_B(torch.cat((fake_B, real_A), dim=1))
            loss_D_fake = criterion_GAN(pred_fake, False)

            # total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B.item())
            loss_D_B.backward()
            optimizer_D_B.step()

        # Update the learning rate
        lr_scheduler_G.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints (overwritten every epoch)
        torch.save(netG_A.state_dict(), 'model/netG_A_pix.pth')
        torch.save(netD_B.state_dict(), 'model/netD_B_pix.pth')

        # Save other checkpoint information
        checkpoint = {'epoch': epoch,
                      'optimizer_G': optimizer_G.state_dict(),
                      'optimizer_D_B': optimizer_D_B.state_dict(),
                      'lr_scheduler_G': lr_scheduler_G.state_dict(),
                      'lr_scheduler_D_B': lr_scheduler_D_B.state_dict()}
        torch.save(checkpoint, 'model/checkpoint.pth')



        # Update the numpy arrays that record the loss (per-epoch means)
        loss_G_array[epoch] = sum(loss_G_list) / len(loss_G_list)
        loss_D_B_array[epoch] = sum(loss_D_B_list) / len(loss_D_B_list)
        np.savetxt('model/loss_G.txt', loss_G_array)
        np.savetxt('model/loss_D_B.txt', loss_D_B_array)

        end = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " end time :", end)
        print("G loss :", loss_G_array[epoch], "D_B loss :", loss_D_B_array[epoch])
Beispiel #22
0
def main(args):
    """Per-subject ADDA adaptation with a voting ensemble.

    Loads ten source-trained models (one per source subject), freezes them,
    and for each one adversarially adapts a copy of the feature extractor to
    the target subject against a domain discriminator. After every epoch each
    adapted classifier is evaluated individually, the ensemble is evaluated
    by voting, and the per-epoch ensemble accuracies are dumped to
    ``<seed>/acc<person>.json``.

    Fix over the original: the discriminators end in ``nn.Sigmoid()`` and
    therefore output probabilities, so the adversarial criterion must be
    ``nn.BCELoss`` — the previous ``nn.BCEWithLogitsLoss`` applied a second
    sigmoid, distorting the loss and flattening its gradients.
    """
    final_accs = []
    # Frozen source models: they supply the source features and the (frozen)
    # classifier head reused at test time.
    source_models = [Net().to(device) for _ in range(10)]
    for idx in range(len(source_models)):
        source_models[idx].load_state_dict(torch.load(args.MODEL_FILE))
        source_models[idx].eval()
        set_requires_grad(source_models[idx], requires_grad=False)

    # NOTE: clfs[i] aliases source_models[i]; assigning clf.feature_extractor
    # below turns each clf into the adapted target classifier in place.
    clfs = [source_model for source_model in source_models]
    source_models = [
        source_model.feature_extractor for source_model in source_models
    ]

    # Target feature extractors start from the same source weights.
    target_models = [Net().to(device) for _ in range(10)]
    for idx in range(len(target_models)):
        target_models[idx].load_state_dict(torch.load(args.MODEL_FILE))
        target_models[idx] = target_models[idx].feature_extractor

    # Domain discriminators: feature -> P(source). The final Sigmoid means
    # outputs are probabilities in (0, 1).
    discriminators = [
        nn.Sequential(nn.Linear(EXTRACTED_FEATURE_DIM, 64), nn.ReLU(),
                      nn.BatchNorm1d(64), nn.Linear(64, 1),
                      nn.Sigmoid()).to(device) for _ in range(10)
    ]

    batch_size = args.batch_size
    discriminator_optims = [
        torch.optim.Adam(discriminators[idx].parameters(), lr=1e-5)
        for idx in range(10)
    ]
    target_optims = [
        torch.optim.Adam(target_models[idx].parameters(), lr=1e-5)
        for idx in range(10)
    ]
    # BCELoss matches the probability outputs of the Sigmoid-terminated
    # discriminators (BCEWithLogitsLoss would sigmoid them a second time).
    criterion = nn.BCELoss()

    source_loaders = []
    target_loaders = []
    for idx in range(10):
        X_source, y_source = preprocess_train_single(idx)
        source_dataset = torch.utils.data.TensorDataset(X_source, y_source)

        source_loader = DataLoader(source_dataset,
                                   batch_size=batch_size,
                                   shuffle=False,
                                   num_workers=1,
                                   pin_memory=True)
        source_loaders.append(source_loader)

        X_target, y_target = preprocess_test(args.person)
        target_dataset = torch.utils.data.TensorDataset(X_target, y_target)
        target_loader = DataLoader(target_dataset,
                                   batch_size=batch_size,
                                   shuffle=False,
                                   num_workers=1,
                                   pin_memory=True)
        target_loaders.append(target_loader)

    best_voting_acc = test_all(clfs)
    best_tar_accs = [0.0] * 10

    for epoch in range(1, args.epochs + 1):
        # Rebuild the loaders with shuffling so each epoch sees a fresh order.
        source_loaders = [
            DataLoader(source_loaders[idx].dataset,
                       batch_size=batch_size,
                       shuffle=True) for idx in range(10)
        ]
        target_loaders = [
            DataLoader(target_loaders[idx].dataset,
                       batch_size=batch_size,
                       shuffle=True) for idx in range(10)
        ]
        for idx in range(10):
            source_loader = source_loaders[idx]
            target_loader = target_loaders[idx]
            batch_iterator = zip(loop_iterable(source_loader),
                                 loop_iterable(target_loader))
            target_model = target_models[idx]
            discriminator = discriminators[idx]
            source_model = source_models[idx]
            clf = clfs[idx]
            total_loss = 0
            adv_loss = 0
            total_accuracy = 0
            second_acc = 0
            for _ in trange(args.iterations, leave=False):
                # --- Train discriminator: target model frozen, k_disc steps.
                set_requires_grad(target_model, requires_grad=False)
                set_requires_grad(discriminator, requires_grad=True)
                discriminator.train()
                for _ in range(args.k_disc):
                    (source_x, _), (target_x, _) = next(batch_iterator)
                    source_x, target_x = source_x.to(device), target_x.to(
                        device)

                    source_features = source_model(source_x).view(
                        source_x.shape[0], -1)
                    target_features = target_model(target_x).view(
                        target_x.shape[0], -1)

                    # Label source features 1, target features 0.
                    discriminator_x = torch.cat(
                        [source_features, target_features])
                    discriminator_y = torch.cat([
                        torch.ones(source_x.shape[0], device=device),
                        torch.zeros(target_x.shape[0], device=device)
                    ])

                    preds = discriminator(discriminator_x).squeeze()
                    loss = criterion(preds, discriminator_y)

                    discriminator_optims[idx].zero_grad()
                    loss.backward()
                    discriminator_optims[idx].step()

                    total_loss += loss.item()
                    total_accuracy += ((preds >= 0.5).long(
                    ) == discriminator_y.long()).float().mean().item()

                # --- Train target extractor: discriminator frozen, k_clf steps.
                set_requires_grad(target_model, requires_grad=True)
                set_requires_grad(discriminator, requires_grad=False)
                target_model.train()
                for _ in range(args.k_clf):
                    _, (target_x, _) = next(batch_iterator)
                    target_x = target_x.to(device)
                    target_features = target_model(target_x).view(
                        target_x.shape[0], -1)

                    # flipped labels: the extractor tries to make target
                    # features look like source (label 1)
                    discriminator_y = torch.ones(target_x.shape[0],
                                                 device=device)

                    preds = discriminator(target_features).squeeze()
                    second_acc += ((preds >= 0.5).long() == discriminator_y.
                                   long()).float().mean().item()

                    loss = criterion(preds, discriminator_y)
                    adv_loss += loss.item()

                    target_optims[idx].zero_grad()
                    loss.backward()
                    target_optims[idx].step()

            # Per-epoch means over all optimization steps.
            mean_loss = total_loss / (args.iterations * args.k_disc)
            mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
            dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
            sec_acc = second_acc / (args.iterations * args.k_clf)
            # Evaluate the adapted classifier; keep the best per-subject copy.
            clf.feature_extractor = target_model
            tar_accuarcy = test(args, clf)
            if tar_accuarcy > best_tar_accs[idx]:
                best_tar_accs[idx] = tar_accuarcy
                torch.save(clf.state_dict(),
                           'trained_models/adda' + str(idx) + '.pt')

            tqdm.write(
                f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
                f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuarcy:.4f}, best_accuracy = {best_tar_accs[idx]:.4f}, sec_acc = {sec_acc:.4f}'
            )

            # Create the full target model and save it
            clf.feature_extractor = target_model
            #torch.save(clf.state_dict(), 'trained_models/adda.pt')
        # Ensemble (voting) evaluation over all adapted classifiers.
        acc = test_all(clfs)
        final_accs.append(acc)
        if acc > best_voting_acc:
            best_voting_acc = acc
        print("In epoch %d, voting_acc: %.4f, best_voting_acc: %.4f" %
              (epoch, acc, best_voting_acc))
    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)
    def train_models(self):
        """Full training driver: resume from a checkpoint, then for every
        epoch run GAN training (generator vs. three discriminators), evaluate
        on the validation split, and update the checkpoints."""

        self._load_checkpoint()

        # loop over the dataset multiple times
        for self.epoch_id in range(self.epoch_to_start, self.max_num_epochs):

            ################## train #################
            ##########################################
            self._clear_cache()
            self.is_training = True
            self.net_G.train()  # Set model to training mode
            self.net_D1.train()  # Set model to training mode
            self.net_D2.train()  # Set model to training mode
            self.net_D3.train()  # Set model to training mode
            # Iterate over data.
            for self.batch_id, batch in enumerate(self.dataloaders['train'],
                                                  0):
                self._forward_pass(batch)
                # update D1 and D2 and D3 (discriminators unfrozen)
                utils.set_requires_grad(self.net_D1, True)
                utils.set_requires_grad(self.net_D2, True)
                utils.set_requires_grad(self.net_D3, True)
                self.optimizer_D1.zero_grad()
                self.optimizer_D2.zero_grad()
                self.optimizer_D3.zero_grad()
                self._backward_D()
                self.optimizer_D1.step()
                self.optimizer_D2.step()
                self.optimizer_D3.step()
                # update G — discriminators frozen so the generator loss
                # does not deposit gradients on them
                utils.set_requires_grad(self.net_D1, False)
                utils.set_requires_grad(self.net_D2, False)
                utils.set_requires_grad(self.net_D3, False)
                self.optimizer_G.zero_grad()
                self._backward_G()
                self.optimizer_G.step()
                self._collect_running_batch_states()
            self._collect_epoch_states()
            self._update_lr_schedulers()

            ################## Eval ##################
            ##########################################
            print('Begin evaluation...')
            self._clear_cache()
            self.is_training = False
            # Set model to evaluate mode
            self.net_G.eval()
            self.net_D1.eval()
            self.net_D2.eval()
            self.net_D3.eval()

            # Iterate over data.
            for self.batch_id, batch in enumerate(self.dataloaders['val'], 0):
                with torch.no_grad():
                    self._forward_pass(batch)
                self._collect_running_batch_states()
            self._collect_epoch_states()

            ########### Update_Checkpoints ###########
            ##########################################
            self._update_checkpoints()
Beispiel #24
0
	def _create_target_network(self):
		"""Build the frozen target network for the configured body type.

		Uses a feed-forward body when ``self._body_type`` is ``'ff'`` and a
		convolutional body otherwise, wrapped in the same ``Head``.
		"""
		body_cls = ffBody if self._body_type == 'ff' else convBody
		body = body_cls(self._obs_num_features_or_obs_in_channels,
		                self._fc_hidden_layer_size)
		self._target_net = Head(body, self._output_actions)
		# The target net is never trained directly, so freeze its parameters.
		set_requires_grad(self._target_net, False)
Beispiel #25
0
        opt.lr = opt.lr * 0.1
        for param_group in g_optimizer.param_groups:
            param_group['lr'] = opt.lr
        for param_group in d_optimizer.param_groups:
            param_group['lr'] = opt.lr

    for i, (X, Y) in enumerate(dataloader):
        X, Y = X.to(device), Y.to(device)

        # --------train Generator-------#
        X_fake = F(Y)
        X_rec = G(X_fake)
        Y_fake = G(X)
        Y_rec = F(Y_fake)

        set_requires_grad([D_X, D_Y], False)
        g_optimizer.zero_grad()
        G_ad_loss = torch.mean((D_X(X_fake) - 1)**2)
        F_ad_loss = torch.mean((D_Y(Y_fake) - 1)**2)
        G_cyc_loss = torch.mean((X.detach() - X_rec)**2)
        F_cyc_loss = torch.mean((Y.detach() - Y_rec)**2)

        if opt.lambda_idt > 0:
            G_idt_loss = torch.mean(torch.abs(G(X) - X)) * opt.lambda_idt
            F_idt_loss = torch.mean(torch.abs(F(Y) - Y)) * opt.lambda_idt
        else:
            G_idt_loss = 0
            F_idt_loss = 0

        gen_loss = G_ad_loss + F_ad_loss + opt.lambda_ * (
            G_cyc_loss + F_cyc_loss + G_idt_loss + F_idt_loss)
Beispiel #26
0
elif mode == "c":
    training_models = [c_model]
elif mode == "c-a":
    training_models = [c_model, actor]
elif mode == "a":
    training_models = [actor]
elif mode == "e":
    training_models = []
else:
    raise NotImplementedError

### Define Parameters ###
training_params = []
for m in all_models:
    if m not in training_models:
        set_requires_grad(m, False)
        m.eval()
    else:
        training_params += list(m.parameters())
solver = None
if len(training_params) > 0:
    solver = optim.Adam(training_params, lr=1e-3)

### Visual Planning & Acting ###
if mode == "e":
    from eval import execute

    execute(all_models, kwargs["env"], savepath, eval_hp=kwargs["eval_hp"])
else:
    train(all_models, training_models, solver, training_params, log_every,
          **kwargs)
def train():
    """Train U-GAT-IT for selfie-to-anime translation.

    Trains two generators (G_A2B, G_B2A) against global (D_A, D_B) and local
    (L_A, L_B) discriminators with adversarial, CAM, cycle and identity
    losses; saves sample images, periodic weights, and a loss plot.

    Fixes over the original:
      * fake images are detached before the discriminator update, so
        ``D_loss.backward()`` no longer deposits gradients on the generators
        (both optimizers are zeroed before the D step, so those stray
        gradients previously contaminated the subsequent generator step);
      * the identity mappings use the correct generators (A->A via G_B2A,
        B->B via G_A2B), matching the U-GAT-IT formulation.
    """

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights, and Plots Path #
    for path in [config.samples_path, config.weights_path, config.plots_path]:
        make_dirs(path)

    # Prepare Data Loader #
    train_loader_selfie, train_loader_anime = get_selfie2anime_loader(
        'train', config.batch_size)
    total_batch = max(len(train_loader_selfie), len(train_loader_anime))

    test_loader_selfie, test_loader_anime = get_selfie2anime_loader(
        'test', config.val_batch_size)

    # Prepare Networks: global (D_*) and local (L_*) discriminators #
    D_A = Discriminator(num_layers=7)
    D_B = Discriminator(num_layers=7)
    L_A = Discriminator(num_layers=5)
    L_B = Discriminator(num_layers=5)
    G_A2B = Generator(image_size=config.crop_size,
                      num_blocks=config.num_blocks)
    G_B2A = Generator(image_size=config.crop_size,
                      num_blocks=config.num_blocks)

    networks = [D_A, D_B, L_A, L_B, G_A2B, G_B2A]

    for network in networks:
        network.to(device)

    # Loss Function #
    Adversarial_loss = nn.MSELoss()
    Cycle_loss = nn.L1Loss()
    BCE_loss = nn.BCEWithLogitsLoss()

    # Optimizers: one over all four discriminators, one over both generators #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters(),
                                     L_A.parameters(), L_B.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.0001)
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.0001)

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Rho Clipper to constraint the value of rho in AdaILN and ILN #
    Rho_Clipper = RhoClipper(0, 1)

    # Lists #
    D_losses = []
    G_losses = []

    # Train #
    # NOTE(review): zip() iterates min(len(selfie), len(anime)) batches per
    # epoch, while total_batch is the max — total_batch is only used for the
    # progress message. Confirm the loaders have equal length.
    print("Training U-GAT-IT started with total epoch of {}.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):

        for i, (selfie, anime) in enumerate(
                zip(train_loader_selfie, train_loader_anime)):

            # Data Preparation #
            real_A = selfie.to(device)
            real_B = anime.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B, L_A, L_B], requires_grad=True)

            # Forward Data #
            fake_B, _, _ = G_A2B(real_A)
            fake_A, _, _ = G_B2A(real_B)

            G_real_A, G_real_A_cam, _ = D_A(real_A)
            L_real_A, L_real_A_cam, _ = L_A(real_A)
            G_real_B, G_real_B_cam, _ = D_B(real_B)
            L_real_B, L_real_B_cam, _ = L_B(real_B)

            # Detach the fakes: the discriminator loss must not backpropagate
            # into the generators (their gradients would otherwise survive
            # into the generator step below).
            G_fake_A, G_fake_A_cam, _ = D_A(fake_A.detach())
            L_fake_A, L_fake_A_cam, _ = L_A(fake_A.detach())
            G_fake_B, G_fake_B_cam, _ = D_B(fake_B.detach())
            L_fake_B, L_fake_B_cam, _ = L_B(fake_B.detach())

            # Adversarial Loss of Discriminator (LSGAN: MSE to 1/0 targets) #
            real_labels = torch.ones(G_real_A.shape).to(device)
            D_ad_real_loss_GA = Adversarial_loss(G_real_A, real_labels)

            fake_labels = torch.zeros(G_fake_A.shape).to(device)
            D_ad_fake_loss_GA = Adversarial_loss(G_fake_A, fake_labels)

            D_ad_loss_GA = D_ad_real_loss_GA + D_ad_fake_loss_GA

            real_labels = torch.ones(G_real_A_cam.shape).to(device)
            D_ad_cam_real_loss_GA = Adversarial_loss(G_real_A_cam, real_labels)

            fake_labels = torch.zeros(G_fake_A_cam.shape).to(device)
            D_ad_cam_fake_loss_GA = Adversarial_loss(G_fake_A_cam, fake_labels)

            D_ad_cam_loss_GA = D_ad_cam_real_loss_GA + D_ad_cam_fake_loss_GA

            real_labels = torch.ones(G_real_B.shape).to(device)
            D_ad_real_loss_GB = Adversarial_loss(G_real_B, real_labels)

            fake_labels = torch.zeros(G_fake_B.shape).to(device)
            D_ad_fake_loss_GB = Adversarial_loss(G_fake_B, fake_labels)

            D_ad_loss_GB = D_ad_real_loss_GB + D_ad_fake_loss_GB

            real_labels = torch.ones(G_real_B_cam.shape).to(device)
            D_ad_cam_real_loss_GB = Adversarial_loss(G_real_B_cam, real_labels)

            fake_labels = torch.zeros(G_fake_B_cam.shape).to(device)
            D_ad_cam_fake_loss_GB = Adversarial_loss(G_fake_B_cam, fake_labels)

            D_ad_cam_loss_GB = D_ad_cam_real_loss_GB + D_ad_cam_fake_loss_GB

            # Adversarial Loss of L #
            real_labels = torch.ones(L_real_A.shape).to(device)
            D_ad_real_loss_LA = Adversarial_loss(L_real_A, real_labels)

            fake_labels = torch.zeros(L_fake_A.shape).to(device)
            D_ad_fake_loss_LA = Adversarial_loss(L_fake_A, fake_labels)

            D_ad_loss_LA = D_ad_real_loss_LA + D_ad_fake_loss_LA

            real_labels = torch.ones(L_real_A_cam.shape).to(device)
            D_ad_cam_real_loss_LA = Adversarial_loss(L_real_A_cam, real_labels)

            fake_labels = torch.zeros(L_fake_A_cam.shape).to(device)
            D_ad_cam_fake_loss_LA = Adversarial_loss(L_fake_A_cam, fake_labels)

            D_ad_cam_loss_LA = D_ad_cam_real_loss_LA + D_ad_cam_fake_loss_LA

            real_labels = torch.ones(L_real_B.shape).to(device)
            D_ad_real_loss_LB = Adversarial_loss(L_real_B, real_labels)

            fake_labels = torch.zeros(L_fake_B.shape).to(device)
            D_ad_fake_loss_LB = Adversarial_loss(L_fake_B, fake_labels)

            D_ad_loss_LB = D_ad_real_loss_LB + D_ad_fake_loss_LB

            real_labels = torch.ones(L_real_B_cam.shape).to(device)
            D_ad_cam_real_loss_LB = Adversarial_loss(L_real_B_cam, real_labels)

            fake_labels = torch.zeros(L_fake_B_cam.shape).to(device)
            D_ad_cam_fake_loss_LB = Adversarial_loss(L_fake_B_cam, fake_labels)

            D_ad_cam_loss_LB = D_ad_cam_real_loss_LB + D_ad_cam_fake_loss_LB

            # Calculate Each Discriminator Loss #
            D_loss_A = D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA
            D_loss_B = D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            # Freeze discriminators so the generator loss does not update them
            set_requires_grad([D_A, D_B, L_A, L_B], requires_grad=False)

            # Forward Data #
            fake_B, fake_B_cam, _ = G_A2B(real_A)
            fake_A, fake_A_cam, _ = G_B2A(real_B)

            fake_ABA, _, _ = G_B2A(fake_B)
            fake_BAB, _, _ = G_A2B(fake_A)

            # Identity mappings: each generator applied to an image already in
            # its OUTPUT domain should reproduce it (A->A via G_B2A, B->B via
            # G_A2B). The original code used the opposite generators, which
            # compared G_A2B(real_A) (i.e. a fake B) against real_A.
            fake_A2A, fake_A2A_cam, _ = G_B2A(real_A)
            fake_B2B, fake_B2B_cam, _ = G_A2B(real_B)

            G_fake_A, G_fake_A_cam, _ = D_A(fake_A)
            L_fake_A, L_fake_A_cam, _ = L_A(fake_A)
            G_fake_B, G_fake_B_cam, _ = D_B(fake_B)
            L_fake_B, L_fake_B_cam, _ = L_B(fake_B)

            # Adversarial Loss of Generator (fool the discriminators) #
            real_labels = torch.ones(G_fake_A.shape).to(device)
            G_adv_fake_loss_A = Adversarial_loss(G_fake_A, real_labels)

            real_labels = torch.ones(G_fake_A_cam.shape).to(device)
            G_adv_cam_fake_loss_A = Adversarial_loss(G_fake_A_cam, real_labels)

            G_adv_loss_A = G_adv_fake_loss_A + G_adv_cam_fake_loss_A

            real_labels = torch.ones(G_fake_B.shape).to(device)
            G_adv_fake_loss_B = Adversarial_loss(G_fake_B, real_labels)

            real_labels = torch.ones(G_fake_B_cam.shape).to(device)
            G_adv_cam_fake_loss_B = Adversarial_loss(G_fake_B_cam, real_labels)

            G_adv_loss_B = G_adv_fake_loss_B + G_adv_cam_fake_loss_B

            # Adversarial Loss of L #
            real_labels = torch.ones(L_fake_A.shape).to(device)
            L_adv_fake_loss_A = Adversarial_loss(L_fake_A, real_labels)

            real_labels = torch.ones(L_fake_A_cam.shape).to(device)
            L_adv_cam_fake_loss_A = Adversarial_loss(L_fake_A_cam, real_labels)

            L_adv_loss_A = L_adv_fake_loss_A + L_adv_cam_fake_loss_A

            real_labels = torch.ones(L_fake_B.shape).to(device)
            L_adv_fake_loss_B = Adversarial_loss(L_fake_B, real_labels)

            real_labels = torch.ones(L_fake_B_cam.shape).to(device)
            L_adv_cam_fake_loss_B = Adversarial_loss(L_fake_B_cam, real_labels)

            L_adv_loss_B = L_adv_fake_loss_B + L_adv_cam_fake_loss_B

            # Cycle Consistency Loss (A->B->A and B->A->B reconstructions) #
            G_recon_loss_A = Cycle_loss(fake_ABA, real_A)
            G_recon_loss_B = Cycle_loss(fake_BAB, real_B)

            G_identity_loss_A = Cycle_loss(fake_A2A, real_A)
            G_identity_loss_B = Cycle_loss(fake_B2B, real_B)

            G_cycle_loss_A = G_recon_loss_A + G_identity_loss_A
            G_cycle_loss_B = G_recon_loss_B + G_identity_loss_B

            # CAM Loss: the attention head should fire on translation input
            # (label 1) and stay off on identity input (label 0) #
            real_labels = torch.ones(fake_A_cam.shape).to(device)
            G_cam_real_loss_A = BCE_loss(fake_A_cam, real_labels)

            fake_labels = torch.zeros(fake_A2A_cam.shape).to(device)
            G_cam_fake_loss_A = BCE_loss(fake_A2A_cam, fake_labels)

            G_cam_loss_A = G_cam_real_loss_A + G_cam_fake_loss_A

            real_labels = torch.ones(fake_B_cam.shape).to(device)
            G_cam_real_loss_B = BCE_loss(fake_B_cam, real_labels)

            fake_labels = torch.zeros(fake_B2B_cam.shape).to(device)
            G_cam_fake_loss_B = BCE_loss(fake_B2B_cam, fake_labels)

            G_cam_loss_B = G_cam_real_loss_B + G_cam_fake_loss_B

            # Calculate Each Generator Loss #
            G_loss_A = G_adv_loss_A + L_adv_loss_A + config.lambda_cycle * G_cycle_loss_A + config.lambda_cam * G_cam_loss_A
            G_loss_B = G_adv_loss_B + L_adv_loss_B + config.lambda_cycle * G_cycle_loss_B + config.lambda_cam * G_cam_loss_B

            # Calculate Total Generator Loss #
            G_loss = G_loss_A + G_loss_B

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Apply Rho Clipper to Generators #
            G_A2B.apply(Rho_Clipper)
            G_B2A.apply(Rho_Clipper)

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "U-GAT-IT | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                save_samples(test_loader_selfie, G_A2B, epoch,
                             config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                D_A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_D_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                D_B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_D_B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                L_A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_L_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                L_B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_L_B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_A2B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_G_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_G_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    # Make a GIF file #
    make_gifs_train('U-GAT-IT', config.samples_path)

    print("Training finished.")
def train_srgans(train_loader, val_loader, generator, discriminator, device,
                 args):
    """Adversarially train a super-resolution GAN (SRGAN or ESRGAN).

    Args:
        train_loader: yields (high, low) HR/LR image batch pairs.
        val_loader: held-out data used for periodic sample images and
            inference.
        generator: network that upscales LR images to fake HR images.
        discriminator: critic scoring HR images as real or fake.
        device: torch device that batches and losses are placed on.
        args: namespace providing model ('srgan' | 'esrgan'), lr,
            num_epochs, print_every, save_every, batch_size, the loss
            weights (lambda_adversarial, lambda_tv, lambda_bce,
            lambda_content) and output paths.

    Raises:
        ValueError: if ``args.model`` is neither 'srgan' nor 'esrgan'.
    """
    # Fail fast on an unsupported model name: every branch below assumes
    # one of these two models, and an unknown name would otherwise raise
    # UnboundLocalError on D_loss in the middle of the first iteration.
    if args.model not in ('srgan', 'esrgan'):
        raise ValueError("args.model must be 'srgan' or 'esrgan', got {!r}"
                         .format(args.model))

    # Loss Function #
    criterion_Perceptual = PerceptualLoss(args.model).to(device)

    # For SRGAN #
    criterion_MSE = nn.MSELoss()
    criterion_TV = TVLoss()

    # For ESRGAN #
    criterion_BCE = nn.BCEWithLogitsLoss()
    criterion_Content = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(discriminator.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999))
    G_optim = torch.optim.Adam(generator.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim, args)
    G_optim_scheduler = get_lr_scheduler(G_optim, args)

    # Running loss histories. They accumulate across ALL epochs, so the
    # printed np.average values are cumulative, not per-epoch. #
    D_losses, G_losses = list(), list()

    # Train #
    print("Training {} started with total epoch of {}.".format(
        str(args.model).upper(), args.num_epochs))

    for epoch in range(args.num_epochs):
        for i, (high, low) in enumerate(train_loader):

            discriminator.train()
            # NOTE(review): the generator is switched to train mode only
            # for SRGAN; for ESRGAN it keeps whatever mode it was left in
            # — confirm this is intentional.
            if args.model == "srgan":
                generator.train()

            # Data Preparation #
            high = high.to(device)
            low = low.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad(discriminator, requires_grad=True)

            # Generate Fake HR Images #
            fake_high = generator(low)

            if args.model == 'srgan':

                # Forward Data (detach so no gradient reaches G here) #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high.detach())

                # Total Discriminator Loss: push prob_real toward 1 and
                # prob_fake toward 0 #
                D_loss = 1 - prob_real.mean() + prob_fake.mean()

            else:  # 'esrgan' (validated above)

                # Forward Data #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high.detach())

                # Relativistic Discriminator: score each sample relative
                # to the mean score of the opposite class #
                diff_r2f = prob_real - prob_fake.mean()
                diff_f2r = prob_fake - prob_real.mean()

                # Labels #
                real_labels = torch.ones(diff_r2f.size()).to(device)
                fake_labels = torch.zeros(diff_f2r.size()).to(device)

                # Adversarial Loss (BCEWithLogits on raw score diffs) #
                D_loss_real = criterion_BCE(diff_r2f, real_labels)
                D_loss_fake = criterion_BCE(diff_f2r, fake_labels)

                # Calculate Total Discriminator Loss #
                D_loss = (D_loss_real + D_loss_fake).mean()

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            # Freeze D so the generator step does not accumulate
            # gradients into the discriminator parameters #
            set_requires_grad(discriminator, requires_grad=False)

            if args.model == 'srgan':

                # Adversarial Loss #
                prob_fake = discriminator(fake_high).mean()
                G_loss_adversarial = torch.mean(1 - prob_fake)
                G_loss_mse = criterion_MSE(fake_high, high)

                # Perceptual Loss #
                lambda_perceptual = 6e-3
                G_loss_perceptual = criterion_Perceptual(fake_high, high)

                # Total Variation Loss #
                G_loss_tv = criterion_TV(fake_high)

                # Calculate Total Generator Loss #
                G_loss = args.lambda_adversarial * G_loss_adversarial + G_loss_mse + lambda_perceptual * G_loss_perceptual + args.lambda_tv * G_loss_tv

            else:  # 'esrgan' (validated above)

                # Forward Data (D is frozen but still scores both sides
                # for the relativistic generator loss) #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high)

                # Relativistic Discriminator #
                diff_r2f = prob_real - prob_fake.mean()
                diff_f2r = prob_fake - prob_real.mean()

                # Labels #
                real_labels = torch.ones(diff_r2f.size()).to(device)
                fake_labels = torch.zeros(diff_f2r.size()).to(device)

                # Adversarial Loss: targets are swapped relative to the
                # discriminator step, as required for the generator #
                G_loss_bce_real = criterion_BCE(diff_f2r, real_labels)
                G_loss_bce_fake = criterion_BCE(diff_r2f, fake_labels)

                G_loss_bce = (G_loss_bce_real + G_loss_bce_fake).mean()

                # Perceptual Loss #
                lambda_perceptual = 1e-2
                G_loss_perceptual = criterion_Perceptual(fake_high, high)

                # Content Loss #
                G_loss_content = criterion_Content(fake_high, high)

                # Calculate Total Generator Loss #
                G_loss = args.lambda_bce * G_loss_bce + lambda_perceptual * G_loss_perceptual + args.lambda_content * G_loss_content

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % args.print_every == 0:
                print(
                    "{} | Epoch [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(
                        str(args.model).upper(), epoch + 1, args.num_epochs,
                        i + 1, len(train_loader), np.average(D_losses),
                        np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, args.batch_size, args.scale_factor,
                              generator, epoch, args.samples_path, device)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights and Inference #
        if (epoch + 1) % args.save_every == 0:
            torch.save(
                generator.state_dict(),
                os.path.join(
                    args.weights_path,
                    '{}_Epoch_{}.pkl'.format(generator.__class__.__name__,
                                             epoch + 1)))
            # NOTE(review): inference uses args.upscale_factor while
            # sample_images above uses args.scale_factor — confirm both
            # attributes exist on args.
            inference(val_loader, generator, args.upscale_factor, epoch,
                      args.inference_path, device)
Beispiel #29
0
    def train(self, verbose: bool = True):
        """Run the StyleGAN training loop for ``self.epochs`` epochs.

        Resumes from the newest checkpoint found in the ``saves7``
        directory (the checkpoint number is embedded in the file name),
        otherwise initializes the network weights from scratch. Each
        batch performs one discriminator update (BCE with smoothed
        labels 0.9/0.1 plus a gradient penalty) followed by one
        generator update (BCE plus, on every 10th batch, a path-length
        regularizer). Every 100 batches the losses and a sample image
        grid are written to TensorBoard and a new checkpoint is saved.

        Args:
            verbose: when True, print loss statistics every 100 batches.
        """

        # set_detect_anomaly(True)
        torch.cuda.empty_cache()

        # load last iteration if training was started but not finished
        if len(listdir("saves7")) > 0:
            # x[4:-3] strips a 4-char prefix and 3-char suffix from the
            # file name, leaving the checkpoint number used for sorting
            self.loadModel(
                sorted(listdir('saves7'),
                       key=lambda x: int(x[4:-3]))[-1][4:-3])
            # self.loadModel(120)
            self.checkpoint = int(
                sorted(listdir('saves7'),
                       key=lambda x: int(x[4:-3]))[-1][4:-3])
            print("Loading from checkpoint: ", self.checkpoint)
            # continue numbering after the loaded checkpoint
            self.checkpoint = self.checkpoint + 1
            print("New checkpoint starts at: ", self.checkpoint)
        else:
            self.StyleGan.init_weights()
            self.checkpoint = 0
        # self.loadModel(1050)
        # self.checkpoint = 0
        # print(self.StyleGan)

        # utils.init_weights(self.StyleGan)
        # utils.set_requires_grad(self.StyleGan, True)

        # training loop
        for epoch in range(0, self.epochs):
            for batch_num, batch in enumerate(self.dataLoader):

                # if batch_num % 50 == 0:
                #     # generated_images = self.StyleGan.generator(self.constant_style, self.constant_noise)
                #     img_grid = make_grid(generated_images)
                #     self.tensorboard_summary.add_image(f'generated_image{self.checkpoint}', img_grid)
                #     del generated_images
                #     del img_grid

                # replicate the image tensor to 3 channels and enable
                # gradients on the input (needed for the gradient penalty)
                batch = batch[0].expand(-1, 3, -1, -1).to(self.device)
                batch.requires_grad = True

                # print("OMG", batch.shape)
                # NOTE(review): hard-coded batch size; presumably this
                # should compare against self.batch_size — confirm.
                if batch.shape[0] != 128:
                    print("SKIPPING")
                    continue

                w_space = []
                # Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
                utils.set_requires_grad(self.StyleGan.discriminator, True)
                self.StyleGan.discriminator.train()
                self.StyleGan.generator.eval()
                utils.set_requires_grad(self.StyleGan.generator, False)

                # style mixing: with probability self.mixed_probability
                # draw a mixed style-noise list instead of a plain one
                if np.random.random() < self.mixed_probability:
                    style_noise = utils.createStyleMixedNoiseList(
                        self.batch_size, self.latent_dim, self.num_layers,
                        self.StyleGan.styleNetwork, self.device)
                else:
                    style_noise = utils.createStyleNoiseList(
                        self.batch_size, self.latent_dim, self.num_layers,
                        self.StyleGan.styleNetwork, self.device)
                image_noise = utils.create_image_noise(self.batch_size,
                                                       self.image_size,
                                                       self.device)

                # style_noise = style_noise.half()
                # image_noise = image_noise.half()

                self.StyleGan.discriminatorOptimizer.zero_grad()
                # one-sided label smoothing: real targets are 0.9, not 1.0
                real_labels = (torch.ones(self.batch_size) * 0.9).to(
                    self.device)

                # transpose to match size with label tensor size
                discriminator_real_output = self.StyleGan.discriminator(
                    batch).reshape(-1).to(self.device)
                # print("DIS real output", discriminator_real_output)

                discriminator_real_loss = self.loss_fn(
                    discriminator_real_output, real_labels).mean()
                # print("DIS real loss", discriminator_real_loss)
                # print("DIS real labelsL", real_labels)

                # free tensors eagerly to reduce GPU memory pressure
                del real_labels

                generated_images = self.StyleGan.generator(
                    style_noise.detach(), image_noise.detach()).to(self.device)
                del style_noise
                del image_noise
                # smoothed fake targets: 0.1 instead of 0.0
                fake_labels = (torch.ones(self.batch_size) * 0.1).to(
                    self.device)

                # transpose to match size with label tensor size;
                # detach() keeps generator gradients out of the D step
                discriminator_fake_output = self.StyleGan.discriminator(
                    generated_images.detach()).reshape(-1).to(self.device)
                # print("DIS fake output", discriminator_fake_output)
                discriminator_fake_loss = self.loss_fn(
                    discriminator_fake_output, fake_labels).mean()
                # print("DIS fake loss", discriminator_fake_loss)
                # print("DIS fake labels", fake_labels)

                del fake_labels
                del discriminator_fake_output

                discriminator_total_loss = discriminator_fake_loss + discriminator_real_loss
                # if batch_num % 100 == 0:
                #     print("d real loss", discriminator_real_loss)
                #     print("d fake loss", discriminator_fake_loss)
                #     print("d real output", discriminator_real_output)
                #     print("d fake output", discriminator_fake_output)
                #     print("real + fake loss", discriminator_fake_loss + discriminator_real_loss)
                #     print("total loss", discriminator_total_loss)
                #     print("\n\n")

                # discriminator_accuracy = 0

                # Gradient penalty is applied on EVERY batch here (the
                # every-4-steps variant below was disabled)
                # if batch_num % 4 == 0:
                # print("before gradien tpenalty", batch_num)
                discriminator_total_loss = discriminator_total_loss + utils.gradientPenalty(
                    batch, discriminator_real_output, self.device)

                del discriminator_real_output

                # bail out of this epoch if training diverged
                if isnan(discriminator_total_loss):
                    print("IS NAN discriminator")
                    break

                # apex AMP scales the loss before backward when available
                if self.apex_available:
                    with amp.scale_loss(discriminator_total_loss,
                                        self.StyleGan.discriminatorOptimizer
                                        ) as scaled_loss:
                        scaled_loss.backward()
                else:
                    discriminator_total_loss.backward()

                # clip the discriminator's gradient L2-norm at 5
                torch.nn.utils.clip_grad_norm_(
                    self.StyleGan.discriminator.parameters(), 5, norm_type=2)
                # for p in self.StyleGan.discriminator.parameters():
                #     p.data.clamp_(-0.01, 0.01)

                self.StyleGan.discriminatorOptimizer.step()
                # Train Generator: maximize log(D(G(z)))

                utils.set_requires_grad(self.StyleGan.discriminator, False)
                self.StyleGan.discriminator.eval()
                self.StyleGan.generator.train()
                utils.set_requires_grad(self.StyleGan.generator, True)
                # fresh style/image noise for the generator step
                if np.random.random() < self.mixed_probability:
                    style_noise = utils.createStyleMixedNoiseList(
                        self.batch_size, self.latent_dim, self.num_layers,
                        self.StyleGan.styleNetwork, self.device)
                else:
                    style_noise = utils.createStyleNoiseList(
                        self.batch_size, self.latent_dim, self.num_layers,
                        self.StyleGan.styleNetwork, self.device)
                image_noise = utils.create_image_noise(self.batch_size,
                                                       self.image_size,
                                                       self.device)

                self.StyleGan.generatorOptimizer.zero_grad()
                # generator wants D to output 1 ("real") for its fakes
                generator_labels = torch.ones(self.batch_size).to(self.device)

                # Generate images
                generated_images = self.StyleGan.generator(
                    style_noise, image_noise).to(self.device)
                del image_noise

                # utils.showImage(batch)
                # utils.showImage(generated_images)

                generator_output = self.StyleGan.discriminator(
                    generated_images).reshape(-1).to(self.device)
                # print("gen output", generator_output)
                generator_loss = self.loss_fn(generator_output,
                                              generator_labels).mean()
                # print("gen loss", generator_loss)
                # print("gen labels", generator_labels)

                # if batch_num % 100 == 0:
                #     print(generator_loss)

                if isnan(generator_loss):
                    print("isnan generator")
                    break

                del generator_output
                del generator_labels

                # keep the pre-regularization loss for logging
                generator_loss_no_pl = generator_loss

                # Apply Path Length Regularization every 10 batches
                # (the code uses % 10; an earlier comment said 16)
                if batch_num % 10 == 0:
                    num_pixels = generated_images.shape[
                        2] * generated_images.shape[3]
                    # perturb images with noise scaled by 1/sqrt(pixels)
                    noise_to_add = (torch.randn(generated_images.shape) /
                                    math.sqrt(num_pixels)).to(self.device)
                    outputs = (generated_images * noise_to_add)

                    # del generated_images
                    # gradient of the perturbed output w.r.t. the styles
                    pl_gradient = grad(outputs=outputs,
                                       inputs=style_noise,
                                       grad_outputs=torch.ones(
                                           outputs.shape).to(self.device),
                                       create_graph=True,
                                       retain_graph=True,
                                       only_inputs=True)[0]
                    del num_pixels
                    del noise_to_add
                    del outputs

                    pl_length = torch.sqrt(torch.sum(
                        torch.square(pl_gradient)))

                    # penalize deviation from the running mean path length
                    if self.average_pl_length is not None:
                        pl_regularizer = ((pl_length -
                                           self.average_pl_length)**2).mean()
                    else:
                        pl_regularizer = (pl_length**2).mean()

                    del pl_gradient

                    del style_noise

                    # print("PL LENGTH IS: ", pl_length)
                    # NOTE(review): `== None` should be `is None`
                    if self.average_pl_length == None:
                        self.average_pl_length = pl_length.detach().item()
                    else:
                        # exponential moving average with decay pl_beta
                        self.average_pl_length = self.average_pl_length * self.pl_beta + (
                            1 - self.pl_beta) * pl_length.detach().item()
                    # self.average_pl_length = pl_length

                    del pl_length

                    generator_loss = generator_loss + pl_regularizer

                    # print(self.average_pl_length)
                    # print(batch_num)

                # print(batch_num, "HI")
                if self.apex_available:
                    with amp.scale_loss(
                            generator_loss,
                            self.StyleGan.generatorOptimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    generator_loss.backward(retain_graph=True)

                torch.nn.utils.clip_grad_norm_(
                    self.StyleGan.generator.parameters(), 5, norm_type=2)
                # for p in self.StyleGan.generator.parameters():
                #     p.data.clamp_(-0.01, 0.01)
                # generator_accuracy = generator_loss.argmax == generator_labels  # TODO
                self.StyleGan.generatorOptimizer.step()

                # Update MappingNetwork weights
                # if self.apex_available:
                #     with amp.scale_loss(generator_loss, self.StyleGan.generatorOptimizer) as scaled_loss:
                #         scaled_loss.backward()
                # else:
                #     generator_loss.backward(retain_graph = True)

                # NOTE(review): pl_regularizer below is only defined on
                # %10 batches; %100 batches satisfy that, so this is safe
                # as long as both moduli stay in sync.
                if verbose == True:
                    if batch_num % 100 == 0 and batch_num != 0:
                        # print("average path length is: ", self.average_pl_length)
                        print("Checkpoint")
                        print("Batch: ", batch_num)
                        print("Path Length Mean: ", self.average_pl_length)
                        print("Discriminator Mean Real Loss: ",
                              discriminator_real_loss.item())
                        print("Discriminator Mean Fake Loss: ",
                              discriminator_fake_loss.item())
                        print("Discriminator Total Loss: ",
                              discriminator_total_loss.item())
                        # print("Discriminator Accuracy: ", discriminator_accuracy)
                        print("Generator Loss (no pl)", generator_loss_no_pl)
                        print("Generator Loss: ", generator_loss.item())
                        # print("Generator Accuracy: ", generator_accuracy)
                        print("PL difference:", pl_regularizer.item())

                # every 100 batches: log scalars + images and checkpoint
                if batch_num % 100 == 0 and batch_num != 0:
                    print("Current Checkpoint is: ", self.checkpoint)
                    img_grid = make_grid(generated_images)

                    self.tensorboard_summary.add_scalar(
                        'Path Length Mean', self.average_pl_length,
                        self.checkpoint)
                    self.tensorboard_summary.add_scalar(
                        'Discriminator Mean Real Loss ',
                        discriminator_real_loss, self.checkpoint)
                    self.tensorboard_summary.add_scalar(
                        'Discriminator Mean Fake Loss ',
                        discriminator_fake_loss, self.checkpoint)
                    self.tensorboard_summary.add_scalar(
                        'Discriminator Total Loss ',
                        discriminator_total_loss.item(), self.checkpoint)
                    self.tensorboard_summary.add_scalar(
                        'Generator Loss', generator_loss.item(),
                        self.checkpoint)
                    self.tensorboard_summary.add_scalar(
                        'Path Length Difference', pl_regularizer.item(),
                        self.checkpoint)
                    self.tensorboard_summary.add_scalar(
                        'Generator Loss (No PL)', generator_loss_no_pl.item(),
                        self.checkpoint)
                    self.tensorboard_summary.add_image(
                        f'generated_image{self.checkpoint}', img_grid)
                    # self.tensorboard_summary.add_scalar("D")
                    del generated_images
                    del img_grid
                    # self.tensorboard_summary.add_scalar('Generator Weight', self.StyleGan.generator.we, self.checkpoint)
                    # self.tensorboard_summary.add_scalar('Generator Weight', generator_loss_no_pl.item(), self.checkpoint)
                    del generator_loss_no_pl
                    del discriminator_total_loss
                    del generator_loss
                    del pl_regularizer
                    del discriminator_real_loss
                    del discriminator_fake_loss
                    self.saveModel(self.checkpoint)
                    self.checkpoint = self.checkpoint + 1

                # if steps > 20000:
                #     self.StyleGan.EMA(0.99)

            # Right now, an epoch is never achieved
            # # Create a checkpoint at the end of an epoch
            # print("Current Checkpoint is: ", self.checkpoint)
            # self.tensorboard_summary.add_scalar('Path Length Mean', self.average_pl_length, self.checkpoint)
            # self.tensorboard_summary.add_scalar('Discriminator Mean Real Loss ',
            #                                     discriminator_real_loss, self.checkpoint)
            # self.tensorboard_summary.add_scalar('Discriminator Mean Fake Loss ',
            #                                     discriminator_fake_loss, self.checkpoint)
            # self.tensorboard_summary.add_scalar('Discriminator Total Loss ', discriminator_total_loss.item(),
            #                                     self.checkpoint)
            # self.tensorboard_summary.add_scalar('Generator Loss', generator_loss.item(), self.checkpoint)
            # self.tensorboard_summary.add_scalar('Path Length Difference', pl_regularizer.item(), self.checkpoint)
            # self.tensorboard_summary.add_scalar('Generator Loss (No PL)', generator_loss_no_pl.item())
            # del generator_loss_no_pl
            # del discriminator_total_loss
            # del generator_loss
            # del pl_regularizer
            # del discriminator_fake_loss
            # del discriminator_real_loss
            # self.saveModel(self.checkpoint)
            # self.checkpoint = self.checkpoint + 1

        # Close TensorBoard at the end
        self.tensorboard_summary.close()
Beispiel #30
0
    def step(self, i):
        """Run one GAN training step: a discriminator update followed by
        a generator update, both under AMP autocast with gradient
        accumulation, then an EMA update of the generator.

        Args:
            i: iteration index. NOTE(review): not used in this body —
                presumably kept for interface compatibility; confirm.

        Returns:
            dict mapping scalar names to Python floats, averaged over
            the accumulation sub-batches and all-reduced (mean) across
            distributed workers.
        """

        self.G.train()

        # per-step logs; each entry is a list of per-sub-batch tensors
        scalars = defaultdict(list)
        # cache the D-phase real/fake batches and latents so the G phase
        # below reuses the same data per accumulation index j
        xs_real = []
        xs_fake = []
        zs = []

        #############################################################
        # train D
        #############################################################

        set_requires_grad(self.D, True)

        self.optim_D.zero_grad(set_to_none=True)
        for j in gradient_accumulation(self.cfg.solver.num_accumulation, True,
                                       (self.G, self.D)):

            # input data
            x_real, m_real = self.fetch_reals(next(self.loader))
            xs_real.append({"depth": x_real, "mask": m_real})
            B = x_real.shape[0]

            # sample z
            z = self.sample_latents(B)
            zs.append(z)

            loss_D = 0

            # discriminator loss
            with torch.cuda.amp.autocast(enabled=self.enable_amp):
                synth = self.G(latent=z)
                xs_fake.append(synth)

                # augment; detach real and re-enable grad so the R1
                # penalty below can differentiate w.r.t. the D input
                x_real_aug = self.A(x_real).detach().requires_grad_()
                x_fake_aug = self.A(synth["depth"]).detach()

                # forward D
                y_real = self.D(x_real_aug)
                y_fake = self.D(x_fake_aug)

                scalars["loss/D/output/real"].append(y_real.mean().detach())
                scalars["loss/D/output/fake"].append(y_fake.mean().detach())

                # adversarial loss
                loss_GAN = self.criterion["gan"](y_real, y_fake, "D")
                loss_D += self.loss_weight["gan"] * loss_GAN

                scalars["loss/D/adversarial"].append(loss_GAN.detach())

            # r1 gradient penalty
            if "gp" in self.criterion:

                # differentiate the AMP-scaled sum to keep the backward
                # numerically safe under fp16, then unscale below
                (grads, ) = torch.autograd.grad(
                    outputs=self.scaler.scale(y_real.sum()),
                    inputs=[x_real_aug],
                    create_graph=True,
                    only_inputs=True,
                )

                # unscale
                grads = grads / self.scaler.get_scale()

                with torch.cuda.amp.autocast(enabled=self.enable_amp):
                    r1_penalty = (grads**2).sum(dim=[1, 2, 3]).mean()
                    scalars["loss/D/gradient_penalty"].append(
                        r1_penalty.detach())
                    loss_D += (self.loss_weight["gp"] / 2) * r1_penalty
                    # zero-weight term keeps y_real attached to the loss
                    # graph — presumably so distributed unused-parameter
                    # checks don't fire; confirm
                    loss_D += 0.0 * y_real.squeeze()[0]

            # average the loss over the accumulation sub-batches
            loss_D /= float(self.cfg.solver.num_accumulation)
            self.scaler.scale(loss_D).backward()

        # update D parameters
        self.scaler.step(self.optim_D)

        #############################################################
        # train G
        #############################################################

        set_requires_grad(self.D, False)

        self.optim_G.zero_grad(set_to_none=True)
        for j in gradient_accumulation(self.cfg.solver.num_accumulation, True,
                                       (self.G, self.D)):
            loss_G = 0

            # generator loss
            with torch.cuda.amp.autocast(enabled=self.enable_amp):
                # augment (reusing the cached sub-batch j from the D phase;
                # fakes stay attached so gradients reach G)
                x_real_aug = self.A(xs_real[j]["depth"]).detach()
                x_fake_aug = self.A(xs_fake[j]["depth"])

                # forward D
                y_real = self.D(x_real_aug)
                y_fake = self.D(x_fake_aug)

                # adversarial loss
                loss_GAN = self.criterion["gan"](y_real, y_fake, "G")
                loss_G += self.loss_weight["gan"] * loss_GAN

                scalars["loss/G/adversarial"].append(loss_GAN.detach())

            # path length regularization
            if "pl" in self.criterion:
                # forward G with smaller batch (half size)
                B_pl = len(xs_real[j]["depth"]) // 2
                z_pl = self.sample_latents(B_pl).requires_grad_()

                # perturb images with noise scaled by 1/sqrt(num pixels)
                with torch.cuda.amp.autocast(enabled=self.enable_amp):
                    synth_pl = self.G(latent=z_pl)
                    x_pl = synth_pl["depth"]
                    noise_pl = torch.randn_like(x_pl)
                    noise_pl /= np.sqrt(np.prod(x_pl.shape[2:]))
                    outputs = (x_pl * noise_pl).sum()

                # same scale-then-unscale pattern as the R1 penalty
                (grads, ) = torch.autograd.grad(
                    outputs=self.scaler.scale(outputs),
                    inputs=[z_pl],
                    create_graph=True,
                    only_inputs=True,
                )

                # unscale
                grads = grads / self.scaler.get_scale()

                with torch.cuda.amp.autocast(enabled=self.enable_amp):
                    # compute |J*y|
                    pl_lengths = grads.pow(2).sum(dim=-1)
                    pl_lengths = torch.sqrt(pl_lengths)
                    # ema of |J*y| (decay 0.01)
                    pl_ema = self.pl_ema.lerp(pl_lengths.mean(), 0.01)
                    self.pl_ema.copy_(pl_ema.detach())
                    # calculate (|J*y|-a)^2
                    pl_penalty = (pl_lengths - pl_ema).pow(2).mean()

                    scalars["loss/G/path_length/baseline"].append(
                        self.pl_ema.detach())
                    scalars["loss/G/path_length"].append(pl_penalty.detach())

                    loss_G += self.loss_weight["pl"] * pl_penalty
                    # zero-weight term keeps x_pl in the loss graph —
                    # presumably for distributed unused-parameter checks;
                    # confirm
                    loss_G += 0.0 * x_pl[0, 0, 0, 0]

            # average the loss over the accumulation sub-batches
            loss_G /= float(self.cfg.solver.num_accumulation)
            self.scaler.scale(loss_G).backward()

        # update G parameters
        self.scaler.step(self.optim_G)

        # single scaler update after both optimizer steps
        self.scaler.update()

        # exponential moving average of G's weights into G_ema
        ema_inplace(self.G_ema, self.G.module, self.ema_decay)

        # gather scalars from all devices
        for key, scalar_list in scalars.items():
            scalar = torch.mean(torch.stack(scalar_list))
            dist.all_reduce(scalar)  # sum over gpus
            scalar /= dist.get_world_size()
            scalars[key] = scalar.item()

        return scalars