コード例 #1
0
class ALADTrainer:
    """Trainer for ALAD (Adversarially Learned Anomaly Detection).

    Jointly trains a generator G, an encoder E and three discriminators:
    Dxz on joint (image, latent) pairs, Dzz on latent pairs (cycle
    consistency in z-space) and Dxx on image pairs (cycle consistency in
    image space).
    """

    def __init__(self, args, data, device):
        """Store config, unpack the training loader and build all networks.

        Args:
            args: config namespace (lr, latent_dim, num_epochs, spec_norm,
                pretrained, normal_class, ...).
            data: pair of loaders; only the first (training) one is used.
            device: torch device all models and tensors are placed on.
        """
        self.args = args
        self.train_loader, _ = data
        self.device = device
        self.build_models()

    def train(self):
        """Training loop for ALAD."""

        if self.args.pretrained:
            self.load_weights()

        # G and E share one optimizer; the three discriminators share another.
        optimizer_ge = optim.Adam(list(self.G.parameters()) +
                                  list(self.E.parameters()),
                                  lr=self.args.lr,
                                  betas=(0.5, 0.999))
        params_ = list(self.Dxz.parameters()) \
                + list(self.Dzz.parameters()) \
                + list(self.Dxx.parameters())
        optimizer_d = optim.Adam(params_, lr=self.args.lr, betas=(0.5, 0.999))

        # Fixed latent batch, used only to dump sample images every 10 epochs.
        fixed_z = Variable(torch.randn((16, self.args.latent_dim, 1, 1)),
                           requires_grad=False).to(self.device)
        criterion = nn.BCELoss()
        for epoch in range(self.args.num_epochs + 1):
            ge_losses = 0
            d_losses = 0
            for x, _ in Bar(self.train_loader):
                # Target labels for the BCE objectives.
                y_true = Variable(torch.ones((x.size(0), 1)).to(self.device))
                y_fake = Variable(torch.zeros((x.size(0), 1)).to(self.device))

                # Cleaning gradients of both optimizers.
                optimizer_d.zero_grad()
                optimizer_ge.zero_grad()

                # Generator: sample latents and decode them to images.
                z_real = Variable(torch.randn(
                    (x.size(0), self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_gen = self.G(z_real)

                # Encoder: embed the real batch.
                x_real = x.float().to(self.device)
                z_gen = self.E(x_real)

                # Dxz: joint (image, latent) discriminator.
                out_truexz, _ = self.Dxz(x_real, z_gen)
                out_fakexz, _ = self.Dxz(x_gen, z_real)

                # Dzz: latent-cycle discriminator, (z, z) vs (z, E(G(z))).
                out_truezz, _ = self.Dzz(z_real, z_real)
                out_fakezz, _ = self.Dzz(z_real, self.E(self.G(z_real)))

                # Dxx: image-cycle discriminator, (x, x) vs (x, G(E(x))).
                out_truexx, _ = self.Dxx(x_real, x_real)
                out_fakexx, _ = self.Dxx(x_real, self.G(self.E(x_real)))

                # Discriminator losses: real pairs -> 1, generated -> 0.
                loss_dxz = criterion(out_truexz, y_true) + criterion(
                    out_fakexz, y_fake)
                loss_dzz = criterion(out_truezz, y_true) + criterion(
                    out_fakezz, y_fake)
                loss_dxx = criterion(out_truexx, y_true) + criterion(
                    out_fakexx, y_fake)
                loss_d = loss_dxz + loss_dzz + loss_dxx

                # Generator/encoder losses with flipped labels. The zz and
                # xx terms already play the role of cycle-consistency
                # penalties, so no extra term is added.
                loss_gexz = criterion(out_fakexz, y_true) + criterion(
                    out_truexz, y_fake)
                loss_gezz = criterion(out_fakezz, y_true) + criterion(
                    out_truezz, y_fake)
                loss_gexx = criterion(out_fakexx, y_true) + criterion(
                    out_truexx, y_fake)
                loss_ge = loss_gexz + loss_gezz + loss_gexx

                # Both losses reuse the same graph, hence retain_graph on
                # the first backward; both steps run after both backwards.
                loss_d.backward(retain_graph=True)
                loss_ge.backward()
                optimizer_d.step()
                optimizer_ge.step()

                d_losses += loss_d.item()

                ge_losses += loss_ge.item()

            if epoch % 10 == 0:
                # G outputs are in [-1, 1]; rescale to [0, 1] for saving.
                vutils.save_image((self.G(fixed_z).data + 1) / 2.,
                                  './images/{}_fake.png'.format(epoch))

            print(
                "Training... Epoch: {}, Discriminator Loss: {:.3f}, Generator Loss: {:.3f}"
                .format(epoch, d_losses / len(self.train_loader),
                        ge_losses / len(self.train_loader)))
        self.save_weights()

    def build_models(self):
        """Instantiate all five networks on the device and init weights."""
        self.G = Generator(self.args.latent_dim).to(self.device)
        self.E = Encoder(self.args.latent_dim,
                         self.args.spec_norm).to(self.device)
        self.Dxz = Discriminatorxz(self.args.latent_dim,
                                   self.args.spec_norm).to(self.device)
        self.Dxx = Discriminatorxx(self.args.spec_norm).to(self.device)
        self.Dzz = Discriminatorzz(self.args.latent_dim,
                                   self.args.spec_norm).to(self.device)
        self.G.apply(weights_init_normal)
        self.E.apply(weights_init_normal)
        self.Dxz.apply(weights_init_normal)
        self.Dxx.apply(weights_init_normal)
        self.Dzz.apply(weights_init_normal)

    def save_weights(self):
        """Save weights (one per-normal-class checkpoint file)."""
        state_dict_Dxz = self.Dxz.state_dict()
        state_dict_Dxx = self.Dxx.state_dict()
        state_dict_Dzz = self.Dzz.state_dict()
        state_dict_E = self.E.state_dict()
        state_dict_G = self.G.state_dict()
        torch.save(
            {
                'Generator': state_dict_G,
                'Encoder': state_dict_E,
                'Discriminatorxz': state_dict_Dxz,
                'Discriminatorxx': state_dict_Dxx,
                'Discriminatorzz': state_dict_Dzz
            },
            'weights/model_parameters_{}.pth'.format(self.args.normal_class))

    def load_weights(self):
        """Load weights.

        BUG FIX: load from the same per-class path that save_weights()
        writes; the original loaded 'weights/model_parameters.pth', a file
        this trainer never produces.
        """
        state_dict = torch.load('weights/model_parameters_{}.pth'.format(
            self.args.normal_class))

        self.Dxz.load_state_dict(state_dict['Discriminatorxz'])
        self.Dxx.load_state_dict(state_dict['Discriminatorxx'])
        self.Dzz.load_state_dict(state_dict['Discriminatorzz'])
        self.G.load_state_dict(state_dict['Generator'])
        self.E.load_state_dict(state_dict['Encoder'])
コード例 #2
0
ファイル: train.py プロジェクト: ameypatil10/CheXpert-CXR
def train(resume=False):
    """Train a Deep SVDD-style one-class anomaly detector on chest X-rays.

    Optionally pretrains an encoder/decoder autoencoder for
    ``hparams.pretrain_epoch`` epochs, then switches to one-class training
    of the encoder around a fixed hypersphere center. Validation AUC and
    losses are logged to TensorBoard and best checkpoints are saved.

    Args:
        resume: kept for interface compatibility; currently unused —
            restoring weights is governed by ``hparams.load_model``.
    """
    writer = SummaryWriter('../runs/' + hparams.exp_name)

    # Record every hyper-parameter in TensorBoard for reproducibility.
    for k in hparams.__dict__.keys():
        writer.add_text(str(k), str(hparams.__dict__[k]))

    train_dataset = ChestData(
        data_csv=hparams.train_csv,
        data_dir=hparams.train_dir,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize((0.485), (0.229))
        ]))

    validation_dataset = ChestData(
        data_csv=hparams.valid_csv,
        data_dir=hparams.valid_dir,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize((0.485), (0.229))
        ]))

    train_loader = DataLoader(train_dataset,
                              batch_size=hparams.batch_size,
                              shuffle=True,
                              num_workers=0)

    validation_loader = DataLoader(validation_dataset,
                                   batch_size=hparams.batch_size,
                                   shuffle=True,
                                   num_workers=0)

    print('loaded train data of length : {}'.format(len(train_dataset)))

    def validation(encoder_, decoder_=None, send_stats=False, epoch=0):
        """Score the validation set and return (ROC-AUC, mean loss).

        With a decoder: anomaly score is the per-image squared
        reconstruction error (autoencoder phase). Without: the squared
        distance to the hypersphere center (one-class phase).
        """
        encoder_ = encoder_.eval()
        if decoder_:
            decoder_ = decoder_.eval()
        with torch.no_grad():
            scores_list = []
            labels_list = []
            val_loss = 0
            for (img, labels, imgs_names) in validation_loader:
                img = Variable(img.float(), requires_grad=False)
                labels = Variable(labels.float(), requires_grad=False)
                scores = None
                if hparams.cuda:
                    img = img.cuda(hparams.gpu_device)
                    labels = labels.cuda(hparams.gpu_device)

                z = encoder_(img)

                if decoder_:
                    outputs = decoder_(z)
                    # Sum of squared errors over all non-batch dims.
                    scores = torch.sum(
                        (outputs - img)**2, dim=tuple(range(
                            1, outputs.dim())))
                    val_loss += torch.sum(scores)
                    save_image(img,
                               'tmp/img_{}.png'.format(epoch),
                               normalize=True)
                    save_image(outputs,
                               'tmp/reconstructed_{}.png'.format(epoch),
                               normalize=True)

                else:
                    # Squared distance to the SVDD center.
                    dist = torch.sum((z - encoder.center)**2, dim=1)
                    if hparams.objective == 'soft-boundary':
                        scores = dist - encoder.radius**2
                        val_loss += (1 / hparams.nu) * torch.sum(
                            torch.max(torch.zeros_like(scores), scores))
                    else:
                        scores = dist
                        val_loss += torch.sum(dist)

                scores_list.append(scores)
                labels_list.append(labels)

            scores = torch.cat(scores_list, dim=0)
            labels = torch.cat(labels_list, dim=0)

            val_loss /= len(validation_dataset)
            # BUG FIX: the R^2 term belongs to the soft-boundary one-class
            # objective (no decoder) — mirroring the training loss below —
            # but the original condition `decoder_` added it during the
            # autoencoder phase instead.
            val_loss += encoder_.radius**2 if not decoder_ and hparams.objective == 'soft-boundary' else 0

            if hparams.cuda:
                labels = labels.cpu()
                scores = scores.cpu()

            labels = labels.view(-1).numpy()
            scores = scores.view(-1).detach().numpy()

            auc = roc_auc_score(labels, scores)

        return auc, val_loss

    ### validation function ends.

    if hparams.cuda:
        encoder = Encoder().cuda(hparams.gpu_device)
        decoder = Decoder().cuda(hparams.gpu_device)
    else:
        encoder = Encoder()
        decoder = Decoder()

    params_count = 0
    for param in encoder.parameters():
        params_count += np.prod(param.size())
    for param in decoder.parameters():
        params_count += np.prod(param.size())
    print('Model has {0} trainable parameters'.format(params_count))

    if not hparams.load_model:
        encoder.apply(weights_init_normal)
        decoder.apply(weights_init_normal)

    # The one-class optimizer captures the encoder params only; the
    # decoder params are appended afterwards so that only the pretrain
    # optimizer sees them.
    optim_params = list(encoder.parameters())
    optimizer_train = optim.Adam(optim_params,
                                 lr=hparams.train_lr,
                                 weight_decay=hparams.weight_decay,
                                 amsgrad=hparams.optimizer == 'amsgrad')

    if hparams.pretrain:
        optim_params += list(decoder.parameters())
        optimizer_pre = optim.Adam(optim_params,
                                   lr=hparams.pretrain_lr,
                                   weight_decay=hparams.ae_weight_decay,
                                   amsgrad=hparams.optimizer == 'amsgrad')
        scheduler_pre = MultiStepLR(optimizer_pre,
                                    milestones=hparams.lr_milestones,
                                    gamma=0.1)

    scheduler_train = MultiStepLR(optimizer_train,
                                  milestones=hparams.lr_milestones,
                                  gamma=0.1)

    print('Starting training.. (log saved in:{})'.format(hparams.exp_name))
    start_time = time.time()

    mode = 'pretrain' if hparams.pretrain else 'train'
    best_valid_loss = 100000000000000000
    best_valid_auc = 0
    encoder = init_center(encoder, train_loader)

    for epoch in range(hparams.num_epochs):
        if mode == 'pretrain' and epoch == hparams.pretrain_epoch:
            # Autoencoder warm-up finished: reset the bests and
            # re-estimate the hypersphere center before one-class training.
            print('Pretraining done.')
            mode = 'train'
            best_valid_loss = 100000000000000000
            best_valid_auc = 0
            encoder = init_center(encoder, train_loader)
        for batch, (imgs, labels, _) in enumerate(train_loader):

            if hparams.cuda:
                imgs = imgs.cuda(hparams.gpu_device)

            if mode == 'pretrain':
                # Autoencoder phase: minimize mean reconstruction error.
                optimizer_pre.zero_grad()
                z = encoder(imgs)
                outputs = decoder(z)
                scores = torch.sum((outputs - imgs)**2,
                                   dim=tuple(range(1, outputs.dim())))
                loss = torch.mean(scores)
                loss.backward()
                optimizer_pre.step()
                writer.add_scalar('pretrain_loss',
                                  loss.item(),
                                  global_step=batch +
                                  len(train_loader) * epoch)

            else:
                # One-class phase: pull embeddings towards the center.
                optimizer_train.zero_grad()

                z = encoder(imgs)
                dist = torch.sum((z - encoder.center)**2, dim=1)
                if hparams.objective == 'soft-boundary':
                    scores = dist - encoder.radius**2
                    loss = encoder.radius**2 + (1 / hparams.nu) * torch.mean(
                        torch.max(torch.zeros_like(scores), scores))
                else:
                    loss = torch.mean(dist)

                loss.backward()
                optimizer_train.step()

                if hparams.objective == 'soft-boundary' and epoch >= hparams.warmup_epochs:
                    # After warm-up, set the radius R to the (1 - nu)
                    # quantile of the embedding distances.
                    R = np.quantile(np.sqrt(dist.clone().data.cpu().numpy()),
                                    1 - hparams.nu)
                    encoder.radius = torch.tensor(R)
                    if hparams.cuda:
                        encoder.radius = encoder.radius.cuda(
                            hparams.gpu_device)
                    writer.add_scalar('radius',
                                      encoder.radius.item(),
                                      global_step=batch +
                                      len(train_loader) * epoch)
                writer.add_scalar('train_loss',
                                  loss.item(),
                                  global_step=batch +
                                  len(train_loader) * epoch)

            if batch % hparams.print_interval == 0:
                print('[Epoch - {0:.1f}, batch - {1:.3f}, loss - {2:.6f}]'.\
                format(1.0*epoch, 100.0*batch/len(train_loader), loss.item()))

        # Validate on deep copies so that .eval() never leaks into the
        # networks being trained.
        if mode == 'pretrain':
            val_auc, rec_loss = validation(copy.deepcopy(encoder),
                                           copy.deepcopy(decoder),
                                           epoch=epoch)
        else:
            val_auc, val_loss = validation(copy.deepcopy(encoder), epoch=epoch)

        writer.add_scalar('val_auc', val_auc, global_step=epoch)

        if mode == 'pretrain':
            best_valid_auc = max(best_valid_auc, val_auc)
            scheduler_pre.step()
            writer.add_scalar('rec_loss', rec_loss, global_step=epoch)
            writer.add_scalar('pretrain_lr',
                              optimizer_pre.param_groups[0]['lr'],
                              global_step=epoch)
            # Always save the latest pretrain checkpoint ...
            torch.save(
                {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict(),
                    'decoder_state_dict': decoder.state_dict(),
                    'optimizer_pre_state_dict': optimizer_pre.state_dict(),
                }, hparams.model + '.pre')
            # ... and additionally the best-by-reconstruction-loss one.
            if best_valid_loss >= rec_loss:
                best_valid_loss = rec_loss
                torch.save(
                    {
                        'epoch': epoch,
                        'encoder_state_dict': encoder.state_dict(),
                        'decoder_state_dict': decoder.state_dict(),
                        'optimizer_pre_state_dict': optimizer_pre.state_dict(),
                    }, hparams.model + '.pre.best')
                print('best model on validation set saved.')
            print('[Epoch - {0:.1f} ---> rec_loss - {1:.4f}, current_lr - {2:.6f}, val_auc - {3:.4f}, best_valid_auc - {4:.4f}] - time - {5:.1f}'\
                .format(1.0*epoch, rec_loss, optimizer_pre.param_groups[0]['lr'], val_auc, best_valid_auc, time.time()-start_time))

        else:
            scheduler_train.step()
            writer.add_scalar('val_loss', val_loss, global_step=epoch)
            writer.add_scalar('train_lr',
                              optimizer_train.param_groups[0]['lr'],
                              global_step=epoch)
            # Latest one-class checkpoint (center and radius included so
            # scoring can be reproduced at test time).
            torch.save(
                {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict(),
                    'center': encoder.center,
                    'radius': encoder.radius,
                    'optimizer_train_state_dict': optimizer_train.state_dict(),
                }, hparams.model + '.train')
            # Best-by-validation-loss checkpoint.
            if best_valid_loss >= val_loss:
                best_valid_loss = val_loss
                torch.save(
                    {
                        'epoch': epoch,
                        'encoder_state_dict': encoder.state_dict(),
                        'center': encoder.center,
                        'radius': encoder.radius,
                        'optimizer_train_state_dict':
                        optimizer_train.state_dict(),
                    }, hparams.model + '.train.best')
                print('best model on validation set saved.')
            # Best-by-AUC checkpoint.
            if best_valid_auc <= val_auc:
                best_valid_auc = val_auc
                torch.save(
                    {
                        'epoch': epoch,
                        'encoder_state_dict': encoder.state_dict(),
                        'center': encoder.center,
                        'radius': encoder.radius,
                        'optimizer_train_state_dict':
                        optimizer_train.state_dict(),
                    }, hparams.model + '.train.auc')
                print('best model on validation set saved.')
            print('[Epoch - {0:.1f} ---> val_loss - {1:.4f}, current_lr - {2:.6f}, val_auc - {3:.4f}, best_valid_auc - {4:.4f}] - time - {5:.1f}'\
                .format(1.0*epoch, val_loss, optimizer_train.param_groups[0]['lr'], val_auc, best_valid_auc, time.time()-start_time))

        start_time = time.time()
コード例 #3
0
class EGBADTrainer:
    """Adversarial trainer pairing a generator/encoder against a joint
    (image, latent) discriminator, with annealed instance noise."""

    def __init__(self, args, data, device):
        """Keep the config, unpack the training loader and build G, E, D."""
        self.args = args
        self.train_loader, _ = data
        self.device = device
        self.build_models()

    def train(self):
        """Run the adversarial training loop, then persist the weights."""
        if self.args.pretrained:
            self.load_weights()

        ge_params = list(self.G.parameters()) + list(self.E.parameters())
        optimizer_ge = optim.Adam(ge_params, lr=self.args.lr)
        optimizer_d = optim.Adam(self.D.parameters(), lr=self.args.lr)

        # Fixed latents used only for the periodic sample-image dump.
        fixed_z = Variable(torch.randn((16, self.args.latent_dim, 1, 1)),
                           requires_grad=False).to(self.device)
        criterion = nn.BCELoss()
        for epoch in range(self.args.num_epochs + 1):
            running_ge = 0
            running_d = 0
            for x, _ in Bar(self.train_loader):
                batch = x.size(0)

                # BCE targets: 1 for "real", 0 for "fake".
                ones = Variable(torch.ones((batch, 1)).to(self.device))
                zeros = Variable(torch.zeros((batch, 1)).to(self.device))

                # Instance noise whose std decays linearly to 0 over
                # training; stabilizes the discriminator.
                sigma = 0.1 * (self.args.num_epochs -
                               epoch) / self.args.num_epochs
                noise1 = Variable(torch.Tensor(x.size()).normal_(0, sigma),
                                  requires_grad=False).to(self.device)
                noise2 = Variable(torch.Tensor(x.size()).normal_(0, sigma),
                                  requires_grad=False).to(self.device)

                # Reset gradients on both optimizers.
                optimizer_d.zero_grad()
                optimizer_ge.zero_grad()

                # Generator path: latent sample -> fake image.
                z_fake = Variable(torch.randn(
                    (batch, self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_fake = self.G(z_fake)

                # Encoder path: real image -> latent.
                x_true = x.float().to(self.device)
                z_true = self.E(x_true)

                # Joint discriminator on noisy (image, latent) pairs.
                out_true = self.D(x_true + noise1, z_true)
                out_fake = self.D(x_fake + noise2, z_fake)

                # D pushes real->1 / fake->0; G and E push the opposite.
                loss_d = criterion(out_true, ones) + criterion(
                    out_fake, zeros)
                loss_ge = criterion(out_fake, ones) + criterion(
                    out_true, zeros)

                # Same graph is reused for both backwards.
                loss_d.backward(retain_graph=True)
                optimizer_d.step()

                loss_ge.backward()
                optimizer_ge.step()

                running_ge += loss_ge.item()
                running_d += loss_d.item()

            if epoch % 10 == 0:
                vutils.save_image((self.G(fixed_z).data + 1) / 2.,
                                  './images/{}_fake.png'.format(epoch))

            print(
                "Training... Epoch: {}, Discrimiantor Loss: {:.3f}, Generator Loss: {:.3f}"
                .format(epoch, running_d / len(self.train_loader),
                        running_ge / len(self.train_loader)))
        self.save_weights()

    def build_models(self):
        """Instantiate G, E and D on the device and init their weights."""
        latent = self.args.latent_dim
        self.G = Generator(latent).to(self.device)
        self.E = Encoder(latent).to(self.device)
        self.D = Discriminator(latent).to(self.device)
        for net in (self.G, self.E, self.D):
            net.apply(weights_init_normal)

    def save_weights(self):
        """Persist all three networks' state dicts in one checkpoint."""
        checkpoint = {
            'Generator': self.G.state_dict(),
            'Encoder': self.E.state_dict(),
            'Discriminator': self.D.state_dict()
        }
        torch.save(checkpoint, 'weights/model_parameters.pth')

    def load_weights(self):
        """Restore all three networks from the saved checkpoint."""
        checkpoint = torch.load('weights/model_parameters.pth')

        self.D.load_state_dict(checkpoint['Discriminator'])
        self.G.load_state_dict(checkpoint['Generator'])
        self.E.load_state_dict(checkpoint['Encoder'])
コード例 #4
0
class TrainerBiGAN:
    """BiGAN trainer supporting both the BCE and the Wasserstein
    (RMSprop + weight clipping) objectives."""

    def __init__(self, args, data, device):
        """Keep the config, the training loader and the target device."""
        self.args = args
        self.train_loader = data
        self.device = device

    def train(self):
        """Build G, E, D and run the BiGAN training loop."""
        self.G = Generator(self.args.latent_dim).to(self.device)
        self.E = Encoder(self.args.latent_dim).to(self.device)
        self.D = Discriminator(self.args.latent_dim,
                               self.args.wasserstein).to(self.device)

        for net in (self.G, self.E, self.D):
            net.apply(weights_init_normal)

        ge_params = list(self.G.parameters()) + list(self.E.parameters())
        if self.args.wasserstein:
            # WGAN setting: RMSprop instead of Adam.
            optimizer_ge = optim.RMSprop(ge_params, lr=self.args.lr_rmsprop)
            optimizer_d = optim.RMSprop(self.D.parameters(),
                                        lr=self.args.lr_rmsprop)
        else:
            optimizer_ge = optim.Adam(ge_params, lr=self.args.lr_adam)
            optimizer_d = optim.Adam(self.D.parameters(), lr=self.args.lr_adam)

        # Fixed latents used only for the periodic sample-image dump.
        fixed_z = Variable(torch.randn((16, self.args.latent_dim, 1, 1)),
                           requires_grad=False).to(self.device)
        criterion = nn.BCELoss()
        for epoch in range(self.args.num_epochs + 1):
            running_ge = 0
            running_d = 0
            for x, _ in Bar(self.train_loader):
                batch = x.size(0)

                # BCE targets: 1 for "real", 0 for "fake".
                ones = Variable(torch.ones((batch, 1)).to(self.device))
                zeros = Variable(torch.zeros((batch, 1)).to(self.device))

                # Instance noise whose std decays linearly to 0 over
                # training; stabilizes the discriminator.
                sigma = 0.1 * (self.args.num_epochs -
                               epoch) / self.args.num_epochs
                noise1 = Variable(torch.Tensor(x.size()).normal_(0, sigma),
                                  requires_grad=False).to(self.device)
                noise2 = Variable(torch.Tensor(x.size()).normal_(0, sigma),
                                  requires_grad=False).to(self.device)

                # Reset gradients on both optimizers.
                optimizer_d.zero_grad()
                optimizer_ge.zero_grad()

                # Generator path: latent sample -> fake image.
                z_fake = Variable(torch.randn(
                    (batch, self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_fake = self.G(z_fake)

                # Encoder path: real image -> latent.
                x_true = x.float().to(self.device)
                z_true = self.E(x_true)

                # Joint discriminator on noisy (image, latent) pairs.
                out_true = self.D(x_true + noise1, z_true)
                out_fake = self.D(x_fake + noise2, z_fake)

                if self.args.wasserstein:
                    # Critic widens the score gap; G/E shrink it.
                    loss_d = -torch.mean(out_true) + torch.mean(out_fake)
                    loss_ge = -torch.mean(out_fake) + torch.mean(out_true)
                else:
                    loss_d = criterion(out_true, ones) + criterion(
                        out_fake, zeros)
                    loss_ge = criterion(out_fake, ones) + criterion(
                        out_true, zeros)

                # Same graph is reused for both backwards.
                loss_d.backward(retain_graph=True)
                optimizer_d.step()

                loss_ge.backward()
                optimizer_ge.step()

                if self.args.wasserstein:
                    # Weight clipping approximates the Lipschitz constraint.
                    for p in self.D.parameters():
                        p.data.clamp_(-self.args.clamp, self.args.clamp)

                running_ge += loss_ge.item()
                running_d += loss_d.item()

            if epoch % 50 == 0:
                vutils.save_image(
                    self.G(fixed_z).data, './images/{}_fake.png'.format(epoch))

            print(
                "Training... Epoch: {}, Discrimiantor Loss: {:.3f}, Generator Loss: {:.3f}"
                .format(epoch, running_d / len(self.train_loader),
                        running_ge / len(self.train_loader)))
コード例 #5
0
class Trainer(object):
    """BicycleGAN-style trainer with a cVAE-GAN and a cLR-GAN branch.

    Every mini-batch is split in half: the first half feeds the cVAE-GAN
    (image-encoded latent z) and the second half feeds the cLR-GAN
    (randomly sampled latent z).
    """

    def __init__(self, args):
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        self.batch_size = args.batch_size
        self.half_size = self.batch_size // 2
        # The batch is split into two equal halves, so it must be even.
        assert self.batch_size % 2 == 0, '[!] batch_size must be even'
        self.nz = args.nz

        # Weights of the KL, image-reconstruction and z-reconstruction terms.
        self.lambda_kl = args.lambda_kl
        self.lambda_img = args.lambda_img
        self.lambda_z = args.lambda_z

        if args.img_size == 128:
            d_n_blocks = 2
            g_n_blocks = 7
            e_n_blocks = 4
        elif args.img_size == 256:
            d_n_blocks = 3
            g_n_blocks = 8
            e_n_blocks = 5
        else:
            # Fail fast: previously an unsupported size fell through and
            # crashed later with a NameError on d_n_blocks.
            raise ValueError('[!] img_size must be 128 or 256, got {}'.format(
                args.img_size))

        # Discriminator for cVAE-GAN (encoded vector z)
        self.D_cVAE = Discriminator(args.input_nc + args.output_nc,
                                    args.ndf,
                                    n_blocks=d_n_blocks).to(self.device)
        self.D_cVAE.apply(weights_init)
        # Discriminator for cLR-GAN (random vector z)
        self.D_cLR = Discriminator(args.input_nc + args.output_nc,
                                   args.ndf,
                                   n_blocks=d_n_blocks).to(self.device)
        self.D_cLR.apply(weights_init)

        self.G = Generator(args.input_nc,
                           args.output_nc,
                           args.ngf,
                           args.nz,
                           n_blocks=g_n_blocks).to(self.device)
        self.G.apply(weights_init)

        self.E = Encoder(args.input_nc, args.nz, args.nef,
                         n_blocks=e_n_blocks).to(self.device)
        self.E.apply(weights_init)

        # One optimizer per network so each training phase can step only
        # the networks it is supposed to update.
        self.optim_D_cVAE = optim.Adam(self.D_cVAE.parameters(),
                                       lr=args.lr,
                                       betas=(args.beta1, args.beta2))
        self.optim_D_cLR = optim.Adam(self.D_cLR.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, args.beta2))
        self.optim_G = optim.Adam(self.G.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))
        self.optim_E = optim.Adam(self.E.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))

        time_str = time.strftime("%Y%m%d-%H%M%S")
        self.writer = SummaryWriter('{}/{}-{}'.format(args.log_dir,
                                                      args.dataset_name,
                                                      time_str))

    def __del__(self):
        # Flush pending TensorBoard events when the trainer goes away.
        self.writer.close()

    def all_zero_grad(self):
        """Zero the gradients held by every optimizer."""
        self.optim_D_cVAE.zero_grad()
        self.optim_D_cLR.zero_grad()
        self.optim_G.zero_grad()
        self.optim_E.zero_grad()

    def save_weights(self, save_dir, global_step):
        """Checkpoint all four networks under ``save_dir``, tagged by step."""
        d_cVAE_name = 'D_cVAE_{}.pth'.format(global_step)
        d_cLR_name = 'D_cLR_{}.pth'.format(global_step)
        g_name = 'G_{}.pth'.format(global_step)
        e_name = 'E_{}.pth'.format(global_step)

        torch.save(self.D_cVAE.state_dict(),
                   os.path.join(save_dir, d_cVAE_name))
        torch.save(self.D_cLR.state_dict(), os.path.join(save_dir, d_cLR_name))
        torch.save(self.G.state_dict(), os.path.join(save_dir, g_name))
        torch.save(self.E.state_dict(), os.path.join(save_dir, e_name))

    def optimize(self, A, B, global_step):
        """Run one training step on paired batches A (input) and B (target).

        Trains, in order: both discriminators, then G & E jointly
        (GAN + KL + image reconstruction), then G alone
        (latent-code reconstruction).
        """
        if A.size(0) <= 1:
            # Need at least one sample per half-batch.
            return

        A = A.to(self.device)
        B = B.to(self.device)

        # First half of the batch -> cVAE-GAN, second half -> cLR-GAN.
        cVAE_data = {'A': A[0:self.half_size], 'B': B[0:self.half_size]}
        cLR_data = {'A': A[self.half_size:], 'B': B[self.half_size:]}

        # Logging the input images
        log_imgs = torch.cat([cVAE_data['A'], cVAE_data['B']], 0)
        log_imgs = torchvision.utils.make_grid(log_imgs)
        log_imgs = denormalize(log_imgs)
        self.writer.add_image('cVAE_input', log_imgs, global_step)

        log_imgs = torch.cat([cLR_data['A'], cLR_data['B']], 0)
        log_imgs = torchvision.utils.make_grid(log_imgs)
        log_imgs = denormalize(log_imgs)
        self.writer.add_image('cLR_input', log_imgs, global_step)

        # ----------------------------------------------------------------
        # 1. Train D
        # ----------------------------------------------------------------

        # -----------------------------
        # Optimize D in cVAE-GAN
        # -----------------------------
        # Generate encoded latent vector (reparameterization trick).
        mu, logvar = self.E(cVAE_data['B'])
        std = torch.exp(logvar / 2)
        random_z = sample_z(self.half_size, self.nz, 'gauss').to(self.device)
        encoded_z = (random_z * std) + mu

        # Generate fake image
        fake_img_cVAE = self.G(cVAE_data['A'], encoded_z)
        log_imgs = torchvision.utils.make_grid(fake_img_cVAE)
        log_imgs = denormalize(log_imgs)
        self.writer.add_image('cVAE_fake_encoded', log_imgs, global_step)

        real_pair_cVAE = torch.cat([cVAE_data['A'], cVAE_data['B']], dim=1)
        fake_pair_cVAE = torch.cat([cVAE_data['A'], fake_img_cVAE], dim=1)

        real_D_cVAE_1, real_D_cVAE_2 = self.D_cVAE(real_pair_cVAE)
        fake_D_cVAE_1, fake_D_cVAE_2 = self.D_cVAE(fake_pair_cVAE.detach())

        # The loss for small patch & big patch
        loss_D_cVAE_1 = mse_loss(real_D_cVAE_1, target=1) + mse_loss(
            fake_D_cVAE_1, target=0)
        loss_D_cVAE_2 = mse_loss(real_D_cVAE_2, target=1) + mse_loss(
            fake_D_cVAE_2, target=0)

        self.writer.add_scalar('loss/loss_D_cVAE_1', loss_D_cVAE_1.item(),
                               global_step)
        self.writer.add_scalar('loss/loss_D_cVAE_2', loss_D_cVAE_2.item(),
                               global_step)

        # -----------------------------
        # Optimize D in cLR-GAN
        # -----------------------------
        # Generate fake image
        fake_img_cLR = self.G(cLR_data['A'], random_z)
        log_imgs = torchvision.utils.make_grid(fake_img_cLR)
        log_imgs = denormalize(log_imgs)
        self.writer.add_image('cLR_fake_random', log_imgs, global_step)

        real_pair_cLR = torch.cat([cLR_data['A'], cLR_data['B']], dim=1)
        # BUG FIX: the fake pair must use the cLR input images, not the
        # cVAE half of the batch.
        fake_pair_cLR = torch.cat([cLR_data['A'], fake_img_cLR], dim=1)

        real_D_cLR_1, real_D_cLR_2 = self.D_cLR(real_pair_cLR)
        fake_D_cLR_1, fake_D_cLR_2 = self.D_cLR(fake_pair_cLR.detach())

        # Loss for small patch & big patch
        loss_D_cLR_1 = mse_loss(real_D_cLR_1, target=1) + mse_loss(
            fake_D_cLR_1, target=0)
        loss_D_cLR_2 = mse_loss(real_D_cLR_2, target=1) + mse_loss(
            fake_D_cLR_2, target=0)

        # BUG FIX: log the cLR losses here (the cVAE losses were being
        # logged a second time under the cVAE tags).
        self.writer.add_scalar('loss/loss_D_cLR_1', loss_D_cLR_1.item(),
                               global_step)
        self.writer.add_scalar('loss/loss_D_cLR_2', loss_D_cLR_2.item(),
                               global_step)

        loss_D = loss_D_cVAE_1 + loss_D_cVAE_2 + loss_D_cLR_1 + loss_D_cLR_2
        self.writer.add_scalar('loss/loss_D', loss_D.item(), global_step)

        # -----------------------------
        # Update D
        # -----------------------------
        self.all_zero_grad()
        loss_D.backward()
        self.optim_D_cVAE.step()
        self.optim_D_cLR.step()

        # ----------------------------------------------------------------
        # 2. Train G & E
        # ----------------------------------------------------------------

        # -----------------------------
        # GAN loss
        # -----------------------------
        # Generate encoded latent vector
        mu, logvar = self.E(cVAE_data['B'])
        std = torch.exp(logvar / 2)
        random_z = sample_z(self.half_size, self.nz, 'gauss').to(self.device)
        encoded_z = (random_z * std) + mu

        # Generate fake image
        fake_img_cVAE = self.G(cVAE_data['A'], encoded_z)
        fake_pair_cVAE = torch.cat([cVAE_data['A'], fake_img_cVAE], dim=1)

        # Fool D_cVAE
        fake_D_cVAE_1, fake_D_cVAE_2 = self.D_cVAE(fake_pair_cVAE)

        # Loss for small patch & big patch
        loss_G_cVAE_1 = mse_loss(fake_D_cVAE_1, target=1)
        loss_G_cVAE_2 = mse_loss(fake_D_cVAE_2, target=1)

        # Random latent vector and generate fake image
        random_z = sample_z(self.half_size, self.nz, 'gauss').to(self.device)
        fake_img_cLR = self.G(cLR_data['A'], random_z)
        fake_pair_cLR = torch.cat([cLR_data['A'], fake_img_cLR], dim=1)

        # Fool D_cLR
        fake_D_cLR_1, fake_D_cLR_2 = self.D_cLR(fake_pair_cLR)

        # Loss for small patch & big patch
        loss_G_cLR_1 = mse_loss(fake_D_cLR_1, target=1)
        loss_G_cLR_2 = mse_loss(fake_D_cLR_2, target=1)

        loss_G = loss_G_cVAE_1 + loss_G_cVAE_2 + loss_G_cLR_1 + loss_G_cLR_2
        self.writer.add_scalar('loss/loss_G', loss_G.item(), global_step)

        # -----------------------------
        # KL-divergence (cVAE-GAN)
        # -----------------------------
        kl_div = torch.sum(
            0.5 * (mu**2 + torch.exp(logvar) - logvar - 1)) * self.lambda_kl
        self.writer.add_scalar('loss/kl_div', kl_div.item(), global_step)

        # -----------------------------
        # Reconstruction of image B (|G(A, z) - B|) (cVAE-GAN)
        # -----------------------------
        loss_img_recon = l1_loss(fake_img_cVAE,
                                 cVAE_data['B']) * self.lambda_img
        self.writer.add_scalar('loss/loss_img_recon', loss_img_recon.item(),
                               global_step)

        loss_E_G = loss_G + kl_div + loss_img_recon
        self.writer.add_scalar('loss/loss_E_G', loss_E_G.item(), global_step)

        # -----------------------------
        # Update E & G
        # -----------------------------
        # retain_graph=True: step 3 below backprops through fake_img_cLR,
        # which belongs to this graph.
        self.all_zero_grad()
        loss_E_G.backward(retain_graph=True)
        self.optim_E.step()
        self.optim_G.step()

        # ----------------------------------------------------------------
        # 3. Train only G
        # ----------------------------------------------------------------

        # -----------------------------
        # Reconstruction of random latent code (|E(G(A, z)) - z|) (cLR-GAN)
        # -----------------------------
        # This step should update only G.
        # See https://github.com/junyanz/BicycleGAN/issues/5 for details.
        mu, logvar = self.E(fake_img_cLR)

        loss_z_recon = l1_loss(mu, random_z) * self.lambda_z
        self.writer.add_scalar('loss/loss_z_recon', loss_z_recon.item(),
                               global_step)

        # -----------------------------
        # Update G
        # -----------------------------
        self.all_zero_grad()
        loss_z_recon.backward()
        self.optim_G.step()
コード例 #6
0
ファイル: train.py プロジェクト: AlexFrontXQ/GAN-Pytorch
def train():
    """Train a UNIT-style unpaired image translation model.

    Builds two encoder/generator pairs sharing a latent residual block,
    plus per-domain discriminators, then alternates generator and
    discriminator updates over the training set, saving samples and
    checkpoints at the configured intervals.
    """
    opt = parse_args()

    os.makedirs("images/%s" % (opt.dataset), exist_ok=True)
    os.makedirs("checkpoints/%s" % (opt.dataset), exist_ok=True)

    cuda = True if torch.cuda.is_available() else False
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # get dataloader
    train_loader = commic2human_loader(opt, mode='train')
    test_loader = commic2human_loader(opt, mode='test')

    # Dimensionality
    input_shape = (opt.channels, opt.img_height, opt.img_width)
    shared_dim = opt.dim * (2**opt.n_downsample)

    # Initialize generator and discriminator.  The residual blocks are
    # shared between the two domains (the core UNIT assumption).
    shared_E = ResidualBlock(in_channels=shared_dim)
    E1 = Encoder(dim=opt.dim,
                 n_downsample=opt.n_downsample,
                 shared_block=shared_E)
    E2 = Encoder(dim=opt.dim,
                 n_downsample=opt.n_downsample,
                 shared_block=shared_E)

    shared_G = ResidualBlock(in_channels=shared_dim)
    G1 = Generator(dim=opt.dim,
                   n_upsample=opt.n_upsample,
                   shared_block=shared_G)
    G2 = Generator(dim=opt.dim,
                   n_upsample=opt.n_upsample,
                   shared_block=shared_G)

    D1 = Discriminator(input_shape)
    D2 = Discriminator(input_shape)

    # Initialize weights
    E1.apply(weights_init_normal)
    E2.apply(weights_init_normal)
    G1.apply(weights_init_normal)
    G2.apply(weights_init_normal)
    D1.apply(weights_init_normal)
    D2.apply(weights_init_normal)

    # Loss function
    adversarial_loss = torch.nn.MSELoss()
    pixel_loss = torch.nn.L1Loss()

    if cuda:
        E1 = E1.cuda()
        E2 = E2.cuda()
        G1 = G1.cuda()
        G2 = G2.cuda()
        D1 = D1.cuda()
        D2 = D2.cuda()
        adversarial_loss = adversarial_loss.cuda()
        pixel_loss = pixel_loss.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(E1.parameters(),
                                                   E2.parameters(),
                                                   G1.parameters(),
                                                   G2.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D1 = torch.optim.Adam(D1.parameters(),
                                    lr=opt.lr,
                                    betas=(opt.b1, opt.b2))
    optimizer_D2 = torch.optim.Adam(D2.parameters(),
                                    lr=opt.lr,
                                    betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(opt.epochs, 0, opt.decay_epoch).step)
    lr_scheduler_D1 = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D1, lr_lambda=LambdaLR(opt.epochs, 0, opt.decay_epoch).step)
    lr_scheduler_D2 = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D2, lr_lambda=LambdaLR(opt.epochs, 0, opt.decay_epoch).step)

    prev_time = time.time()
    for epoch in range(opt.epochs):
        for i, (img_A, img_B) in enumerate(train_loader):

            # Model inputs
            X1 = Variable(img_A.type(FloatTensor))
            X2 = Variable(img_B.type(FloatTensor))

            # Adversarial ground truths
            valid = Variable(FloatTensor(img_A.shape[0],
                                         *D1.output_shape).fill_(1.0),
                             requires_grad=False)
            fake = Variable(FloatTensor(img_A.shape[0],
                                        *D1.output_shape).fill_(0.0),
                            requires_grad=False)

            # -----------------------------
            # Train Encoders and Generators
            # -----------------------------

            # BUG FIX: without this, encoder/generator gradients accumulate
            # across iterations (the discriminators already zero theirs).
            optimizer_G.zero_grad()

            # Get shared latent representation
            mu1, Z1 = E1(X1)
            mu2, Z2 = E2(X2)

            # Reconstruct images
            recon_X1 = G1(Z1)
            recon_X2 = G2(Z2)

            # Translate images
            fake_X1 = G1(Z2)
            fake_X2 = G2(Z1)

            # Cycle translation
            mu1_, Z1_ = E1(fake_X1)
            mu2_, Z2_ = E2(fake_X2)
            cycle_X1 = G1(Z2_)
            cycle_X2 = G2(Z1_)

            # Losses for encoder and generator
            id_loss_1 = opt.lambda_id * pixel_loss(recon_X1, X1)
            id_loss_2 = opt.lambda_id * pixel_loss(recon_X2, X2)

            adv_loss_1 = opt.lambda_adv * adversarial_loss(D1(fake_X1), valid)
            adv_loss_2 = opt.lambda_adv * adversarial_loss(D2(fake_X2), valid)

            cyc_loss_1 = opt.lambda_cyc * pixel_loss(cycle_X1, X1)
            cyc_loss_2 = opt.lambda_cyc * pixel_loss(cycle_X2, X2)

            KL_loss_1 = opt.lambda_KL1 * compute_KL(mu1)
            KL_loss_2 = opt.lambda_KL1 * compute_KL(mu2)
            KL_loss_1_ = opt.lambda_KL2 * compute_KL(mu1_)
            KL_loss_2_ = opt.lambda_KL2 * compute_KL(mu2_)

            # total loss for encoder and generator
            G_loss = id_loss_1 + id_loss_2 \
                     + adv_loss_1 + adv_loss_2 \
                     + cyc_loss_1 + cyc_loss_2 + \
                     KL_loss_1 + KL_loss_2 + KL_loss_1_ + KL_loss_2_

            G_loss.backward()
            optimizer_G.step()

            # ----------------------
            # Train Discriminator 1
            # ----------------------

            optimizer_D1.zero_grad()

            D1_loss = adversarial_loss(D1(X1), valid) + adversarial_loss(
                D1(fake_X1.detach()), fake)
            D1_loss.backward()

            optimizer_D1.step()

            # ----------------------
            # Train Discriminator 2
            # ----------------------

            optimizer_D2.zero_grad()

            D2_loss = adversarial_loss(D2(X2), valid) + adversarial_loss(
                D2(fake_X2.detach()), fake)
            D2_loss.backward()

            optimizer_D2.step()

            # ------------------
            # Log Information
            # ------------------

            batches_done = epoch * len(train_loader) + i
            batches_left = opt.epochs * len(train_loader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
                % (epoch, opt.epochs, i, len(train_loader),
                   (D1_loss + D2_loss).item(), G_loss.item(), time_left))

            if batches_done % opt.sample_interval == 0:
                save_sample(opt.dataset, test_loader, batches_done, E1, E2, G1,
                            G2, FloatTensor)

            if batches_done % opt.checkpoint_interval == 0:
                torch.save(E1.state_dict(),
                           "checkpoints/%s/E1_%d.pth" % (opt.dataset, epoch))
                torch.save(E2.state_dict(),
                           "checkpoints/%s/E2_%d.pth" % (opt.dataset, epoch))
                torch.save(G1.state_dict(),
                           "checkpoints/%s/G1_%d.pth" % (opt.dataset, epoch))
                torch.save(G2.state_dict(),
                           "checkpoints/%s/G2_%d.pth" % (opt.dataset, epoch))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D1.step()
        lr_scheduler_D2.step()

    torch.save(shared_E.state_dict(),
               "checkpoints/%s/shared_E_done.pth" % opt.dataset)
    torch.save(shared_G.state_dict(),
               "checkpoints/%s/shared_G_done.pth" % opt.dataset)
    torch.save(E1.state_dict(), "checkpoints/%s/E1_done.pth" % opt.dataset)
    torch.save(E2.state_dict(), "checkpoints/%s/E2_done.pth" % opt.dataset)
    torch.save(G1.state_dict(), "checkpoints/%s/G1_done.pth" % opt.dataset)
    torch.save(G2.state_dict(), "checkpoints/%s/G2_done.pth" % opt.dataset)
    print("Training Process has been Done!")
コード例 #7
0
###### Definition of variables ######
# Networks.  These are module-level objects used by the training loop below.
encoder = Encoder(opt.input_nc)
decoder = Decoder()
# netD = Discriminator(opt.input_nc)
# Single-scale (num_D=1) patch discriminator without a final sigmoid
# (use_sigmoid=False), i.e. suited to the LSGAN-style MSE adversarial
# loss (criterion_GAN below) rather than BCE.
netD = MultiscaleDiscriminator(opt.input_nc, opt.ndf, opt.n_layers_D, norm_layer=nn.InstanceNorm2d, use_sigmoid=False, num_D=1, getIntermFeat=False)   
# transformer=transformer_block()


if opt.cuda:
    encoder.cuda()
    decoder.cuda()
    netD.cuda()
    # transformer.cuda()

encoder.apply(weights_init_normal)
decoder.apply(weights_init_normal)
netD.apply(weights_init_normal)


# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_l1 = torch.nn.L1Loss()
criterion_feat = torch.nn.MSELoss()
criterion_VGG= VGGLoss()

# Optimizers & LR schedulers.
# Encoder and decoder get separate Adam optimizers with identical settings;
# the discriminator has its own.
optimizer_encoder = torch.optim.Adam(encoder.parameters(),lr=opt.lr, betas=(0.5, 0.999))
optimizer_decoder = torch.optim.Adam(decoder.parameters(),lr=opt.lr, betas=(0.5, 0.999))

optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(0.5, 0.999))
コード例 #8
0
def main():
    """Train a rotation-equivariant MNIST encoder.

    For each batch, rotates the input digits, maps both original and
    rotated images through the encoder to 2-D features, applies the
    corresponding rotation in feature space, and minimizes the squared
    distance between the transformed and target features.  Losses and
    learning curves are saved under ``./output``.
    """
    import os  # local import: only needed to create the output directory

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument(
        '--test-batch-size',
        type=int,
        default=10000,
        metavar='N',
        help='input batch size for reconstruction testing (default: 10,000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=20,
                        metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument(
        '--store-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before storing training loss')

    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Set up dataloaders
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor()])),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)

    train_loader_eval = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=True,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **{})

    # Init model and optimizer
    model = Encoder(device).to(device)

    # Output directory for losses/plots; create it up front so the final
    # np.save calls cannot fail on a missing directory.
    path = "./output"
    os.makedirs(path, exist_ok=True)

    # Initialise weights
    model.apply(weights_init)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Loss histories, appended to every `store_interval` batches.
    rotation_test_loss = []
    train_loss = []
    test_loss = []

    # Training loop
    for epoch in range(1, args.epochs + 1):
        for batch_idx, (data, targets) in enumerate(train_loader):
            model.train()
            # Rotate each image by a random angle; keep the angles so the
            # same rotation can be applied in feature space.
            targets, angles = rotate_tensor(data.numpy())
            targets = torch.from_numpy(targets).to(device)
            angles = torch.from_numpy(angles).to(device)
            angles = angles.view(angles.size(0), 1)

            # Forward passes
            data = data.to(device)
            optimizer.zero_grad()
            f_data = model(data)  # [N,2,1,1]
            f_targets = model(targets)  # [N,2,1,1]

            # Apply the rotation matrix to f_data with the feature transformer
            f_data_trasformed = feature_transformer(f_data, angles, device)

            # Loss: squared Euclidean distance between rotated features
            # and the features of the rotated images.
            forb_distance = torch.nn.PairwiseDistance()
            loss = (forb_distance(f_data_trasformed.view(-1, 2),
                                  f_targets.view(-1, 2))**2).sum()

            # Backprop
            loss.backward()
            optimizer.step()

            # Log progress
            if batch_idx % args.log_interval == 0:
                sys.stdout.write(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\r'.format(
                        epoch, batch_idx * len(data),
                        len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss.item()))
                sys.stdout.flush()

            # Store training and test loss
            if batch_idx % args.store_interval == 0:
                # Train loss
                train_loss.append(
                    evaluate_model(model, device, train_loader_eval))

                # Test loss
                test_loss.append(evaluate_model(model, device, test_loader))

                # Rotation loss
                rotation_test_loss.append(
                    rotation_test(model, device, test_loader))

    # Save model
    save_model(args, model)
    # Save losses
    train_loss = np.array(train_loss)
    test_loss = np.array(test_loss)
    rotation_test_loss = np.array(rotation_test_loss)

    np.save(path + '/training_loss', train_loss)
    np.save(path + '/test_loss', test_loss)
    np.save(path + '/rotation_test_loss', rotation_test_loss)

    plot_learning_curve(args, train_loss, test_loss, rotation_test_loss)
コード例 #9
0
def main():
    # Training settings
    list_of_choices = ['forbenius', 'cosine_squared', 'cosine_abs']

    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch  rotation test (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=20,
                        metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        metavar='LR',
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument(
        '--store-interval',
        type=int,
        default=50,
        metavar='N',
        help='how many batches to wait before storing training loss')
    parser.add_argument(
        '--name',
        type=str,
        default='',
        help='name of the run that is added to the output directory')
    parser.add_argument(
        "--loss",
        dest='loss',
        default='forbenius',
        choices=list_of_choices,
        help=
        'Decide type of loss, (forbenius) norm, difference of (cosine), (default=forbenius)'
    )
    parser.add_argument(
        '--init-rot-range',
        type=float,
        default=360,
        help=
        'Upper bound of range in degrees of initial random rotation of digits, (Default=360)'
    )
    parser.add_argument('--relative-rot-range',
                        type=float,
                        default=90,
                        metavar='theta',
                        help='Relative rotation range (-theta, theta)')
    parser.add_argument('--eval-batch-size',
                        type=int,
                        default=200,
                        metavar='N',
                        help='batch-size for evaluation')

    args = parser.parse_args()

    #Print arguments
    for arg in vars(args):
        sys.stdout.write('{} = {} \n'.format(arg, getattr(args, arg)))
        sys.stdout.flush()

    sys.stdout.write('Random torch seed:{}\n'.format(torch.initial_seed()))
    sys.stdout.flush()

    args.init_rot_range = args.init_rot_range * np.pi / 180
    args.relative_rot_range = args.relative_rot_range * np.pi / 180
    # Create save path

    path = "./output_" + args.name
    if not os.path.exists(path):
        os.makedirs(path)

    sys.stdout.write('Start training\n')
    sys.stdout.flush()

    use_cuda = torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")

    writer = SummaryWriter(path, comment='Encoder atan2 MNIST')
    # Set up dataloaders
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    train_loader_eval = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=True,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    # Init model and optimizer
    model = Encoder(device).to(device)

    #Initialise weights
    model.apply(weights_init)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    #Init losses log

    prediction_mean_error = []  #Average  rotation prediction error in degrees
    prediction_error_std = []  #Std of error for rotation prediciton
    train_loss = []

    #Train
    n_iter = 0
    for epoch in range(1, args.epochs + 1):
        sys.stdout.write('Epoch {}/{} \n '.format(epoch, args.epochs))
        sys.stdout.flush()

        for batch_idx, (data, targets) in enumerate(train_loader):
            model.train()
            # Reshape data
            data, targets, angles = rotate_tensor(data.numpy(),
                                                  args.init_rot_range,
                                                  args.relative_rot_range)
            data = torch.from_numpy(data).to(device)
            targets = torch.from_numpy(targets).to(device)
            angles = torch.from_numpy(angles).to(device)
            angles = angles.view(angles.size(0), 1)

            # Forward passes
            optimizer.zero_grad()
            f_data = model(data)  # [N,2,1,1]
            f_targets = model(targets)  #[N,2,1,1]

            #Apply rotatin matrix to f_data with feature transformer
            f_data_trasformed = feature_transformer(f_data, angles, device)

            #Define loss

            loss = define_loss(args, f_data_trasformed, f_targets)

            # Backprop
            loss.backward()
            optimizer.step()

            #Log progress
            if batch_idx % args.log_interval == 0:
                sys.stdout.write(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\r'.format(
                        epoch, batch_idx * len(data),
                        len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss))
                sys.stdout.flush()

                writer.add_scalar('Training Loss', loss, n_iter)

            #Store training and test loss
            if batch_idx % args.store_interval == 0:
                #Train Loss
                train_loss.append(
                    evaluate_model(args, model, device, train_loader_eval))

                #Rotation loss in trainign set
                mean, std = rotation_test(args, model, device,
                                          train_loader_eval)
                prediction_mean_error.append(mean)
                writer.add_scalar('Mean test error', mean, n_iter)

                prediction_error_std.append(std)

            n_iter += 1

        save_model(args, model)

    #Save model

    #Save losses
    train_loss = np.array(train_loss)
    prediction_mean_error = np.array(prediction_mean_error)
    prediction_error_std = np.array(prediction_error_std)

    np.save(path + '/training_loss', train_loss)
    np.save(path + '/prediction_mean_error', prediction_mean_error)
    np.save(path + '/prediction_error_std', prediction_error_std)

    plot_learning_curve(args, train_loss, prediction_mean_error,
                        prediction_error_std, path)

    #Get diagnostics per digit
    get_error_per_digit(args, model, device)
コード例 #10
0
ファイル: enc_train.py プロジェクト: skasai5296/IcGAN
def train(args):
    """Train the IcGAN encoders: enc_y (image -> attributes) and enc_z (image -> latent).

    Loads a pretrained, frozen Generator and alternately optimizes:
      * enc_y with MSE between predicted and ground-truth attributes on
        real CelebA images;
      * enc_z with MSE between the latent recovered from a generated image
        and the latent that produced it.

    Both encoders are checkpointed every epoch; every ``args.recon_every``
    epochs an original/reconstruction image grid is saved for inspection.

    Args:
        args: parsed CLI namespace. Uses (among others) learning_rate,
            batch_size, nz, image_size, dataset paths, opt_method,
            log_every, recon_every, num_epoch, use_tensorboard.

    Raises:
        ValueError: if ``args.opt_method`` is neither 'Adam' nor 'SGD'.
    """
    if args.use_tensorboard:
        log_name = "enc_lr{}".format(args.learning_rate)
        writer = SummaryWriter(log_dir=os.path.join('runs', log_name))

    device = torch.device('cuda' if args.enable_cuda and torch.cuda.is_available() else 'cpu')

    # Augmentation + normalization to [-1, 1] (matches Normalize(0.5, 0.5)).
    transform = transforms.Compose([
                        transforms.Resize((args.image_size, args.image_size)),
                        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # Datasets/dataloaders; drop_last keeps the batch dimension fixed so the
    # pre-sampled `noise` tensor below always matches the generated batch.
    # Each sample is a dict: {'image': tensor, 'attributes': (bs, fsize)}.
    train_dataset = CelebA(args.root_dir, args.img_dir, args.ann_dir, transform=transform)
    test_dataset = CelebA(args.root_dir, args.img_dir, args.ann_dir, transform=transform, train=False)
    fsize = train_dataset.feature_size
    trainloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                             collate_fn=collate_fn, drop_last=True, num_workers=4)
    testloader = DataLoader(test_dataset, batch_size=args.show_size, shuffle=True,
                            collate_fn=collate_fn, drop_last=True, num_workers=4)

    # Pretrained generator, loaded from checkpoint and kept frozen (eval mode,
    # and all generator forward passes run under torch.no_grad()).
    gen = Generator(in_c=args.nz + fsize)
    gen = gen.to(device)
    modeldir = os.path.join(args.model_path, "{}/res_{}".format(args.dataset, args.residual))
    MODELPATH = os.path.join(modeldir, 'gen_epoch_{}.ckpt'.format(args.model_ep))
    gen.load_state_dict(torch.load(MODELPATH))

    enc_y = Encoder(fsize).to(device)                 # image -> attribute vector
    enc_z = Encoder(args.nz, for_y=False).to(device)  # image -> latent code

    # Initialize encoder weights.
    enc_y.apply(init_weights)
    enc_z.apply(init_weights)

    if args.opt_method == 'Adam':
        enc_y_optim = optim.Adam(enc_y.parameters(), lr=args.learning_rate, betas=args.betas)
        enc_z_optim = optim.Adam(enc_z.parameters(), lr=args.learning_rate, betas=args.betas)
    elif args.opt_method == 'SGD':
        enc_y_optim = optim.SGD(enc_y.parameters(), lr=args.learning_rate, momentum=args.momentum)
        enc_z_optim = optim.SGD(enc_z.parameters(), lr=args.learning_rate, momentum=args.momentum)
        # NOTE(review): these schedulers are created but .step() is never
        # called anywhere in this function — confirm whether LR decay was
        # actually intended for the SGD path.
        enc_y_scheduler = optim.lr_scheduler.ReduceLROnPlateau(enc_y_optim, patience=args.patience)
        enc_z_scheduler = optim.lr_scheduler.ReduceLROnPlateau(enc_z_optim, patience=args.patience)
    else:
        # Previously an unrecognized method fell through silently and the
        # script crashed later with a NameError; fail fast instead.
        raise ValueError("unsupported opt_method: {}".format(args.opt_method))
    criterion = nn.MSELoss()

    # NOTE(review): this latent batch is sampled once and reused for every
    # training step, so enc_z only ever sees images generated from the same
    # z values — presumably fresh noise per iteration was intended; confirm.
    noise = torch.randn((args.batch_size, args.nz)).to(device)

    if args.use_tensorboard:
        writer.add_text("Text", "begin training, lr={}".format(args.learning_rate))
    print("begin training, lr={}".format(args.learning_rate), flush=True)
    stepcnt = 0
    gen.eval()
    enc_y.train()
    enc_z.train()

    for ep in range(args.num_epoch):

        YLoss = 0
        ZLoss = 0

        ittime = time.time()
        run_time = 0
        run_ittime = 0

        for it, sample in enumerate(trainloader):

            elapsed = time.time()

            x = sample['image'].to(device)
            y = sample['attributes'].to(device)

            # --- attribute encoder: real images -> ground-truth attributes ---
            enc_y.zero_grad()
            out = enc_y(x)
            l2loss = criterion(out, y)
            l2loss.backward()
            enc_y_optim.step()
            loss_y = l2loss.detach().cpu().item()

            # --- identity encoder: recover the latent from a generated image.
            # Attributes are resampled from the training set so the frozen
            # generator sees a realistic attribute distribution.
            enc_z.zero_grad()
            y_sample = randomsample(train_dataset, args.batch_size).to(device)
            with torch.no_grad():
                x_fake = gen(noise, y_sample)
            z_recon = enc_z(x_fake)
            l2loss2 = criterion(z_recon, noise)
            l2loss2.backward()
            enc_z_optim.step()
            loss_z = l2loss2.detach().cpu().item()

            YLoss += loss_y
            ZLoss += loss_z

            ittime_a = time.time() - ittime
            run_time += time.time() - elapsed
            run_ittime += ittime_a

            # Periodic logging of per-step losses and timings.
            if it % args.log_every == (args.log_every - 1):
                if args.use_tensorboard:
                    writer.add_scalar('y loss', loss_y, stepcnt+1)
                    writer.add_scalar('z loss', loss_z, stepcnt+1)

                print("{}th iter\ty loss: {:.5f}\tz loss: {:.5f}\t{:.4f}s per step, {:.4f}s per iter".format(it+1, loss_y, loss_z, run_time / args.log_every, run_ittime / args.log_every), flush=True)
                run_time = 0
                run_ittime = 0
            ittime = time.time()

            stepcnt += 1

        # Epoch summary (the original strings had a stray trailing ']').
        epoch_msg = "epoch [{}/{}] done | y loss: {:.6f} \t z loss: {:.6f}".format(ep+1, args.num_epoch, YLoss, ZLoss)
        print(epoch_msg, flush=True)
        if args.use_tensorboard:
            writer.add_text("epoch loss", epoch_msg, ep+1)

        # Checkpoint both encoders every epoch; saving is best-effort.
        savepath = os.path.join(args.model_path, "{}/res_{}".format(args.dataset, args.residual))
        try:
            torch.save(enc_y.state_dict(), os.path.join(savepath, "enc_y_epoch_{}.ckpt".format(ep+1)))
            torch.save(enc_z.state_dict(), os.path.join(savepath, "enc_z_epoch_{}.ckpt".format(ep+1)))
            print("saved encoder model at {}".format(savepath))
        except OSError:
            print("failed to save model for epoch {}".format(ep+1))

        if ep % args.recon_every == (args.recon_every - 1):
            # Save an original/reconstruction grid from one test batch.
            outpath = os.path.join(args.output_path, "{}/res_{}".format(args.dataset, args.residual))
            SAVEPATH = os.path.join(outpath, 'enc_epoch_{}'.format(ep+1))
            # makedirs(exist_ok=True) also creates missing parent directories
            # and avoids the exists()/mkdir() race of the original code.
            os.makedirs(SAVEPATH, exist_ok=True)

            with torch.no_grad():
                for sample in testloader:
                    im = sample['image']
                    grid = vutils.make_grid(im, normalize=True)
                    vutils.save_image(grid, os.path.join(SAVEPATH, 'original.png'))
                    im = im.to(device)
                    y = enc_y(im)
                    z = enc_z(im)
                    im = gen(z, y)
                    recon = im.cpu()
                    grid = vutils.make_grid(recon, normalize=True)
                    vutils.save_image(grid, os.path.join(SAVEPATH, 'recon.png'))
                    break

    print("end training")
    if args.use_tensorboard:
        writer.add_text("Text", "end training")
        writer.close()