示例#1
0
class ProjectorTrainer(object):
    """Train a projector ``P`` against a pre-trained generator/discriminator.

    Every batch is pushed through ``P`` to obtain a latent code ``z``, decoded
    by ``G`` and scored by ``D``. The optimised loss is
    ``(1 - lossRatio) * dissimilarityLoss + lossRatio * realismLoss`` and only
    the parameters of ``P`` are handed to the optimiser.

    Files used inside ``logDir``:
        * ``generator.pth`` / ``discrim.pth`` -- always loaded.
        * ``projector.pth`` / ``optimP.pth`` / ``recent-projector.log`` --
          written by :meth:`save`, read by :meth:`load` when resuming.
        * ``tensorboard/`` -- SummaryWriter output when ``useTensorboard``.
    """

    # Lower bound applied before taking logarithms in the realism loss, so a
    # saturated discriminator output cannot produce inf (and inf * 0 = NaN
    # when lossRatio is 0).
    _LOG_EPS = 1e-8
    # The original setup parallelised over GPUs 0 and 1; keep that upper
    # bound but do not require that both devices exist.
    _MAX_GPUS = 2

    def __init__(self,
                 logDir,
                 printEvery=1,
                 resume=False,
                 lossRatio=0.0,
                 useTensorboard=True):
        """Build the models, the optimiser and (optionally) the writer.

        Args:
            logDir: directory holding checkpoints and logs.
            printEvery: log every ``printEvery``-th batch.
            resume: also restore projector, optimiser and counters.
            lossRatio: weight ``a`` in ``(1-a)*dissimLoss + a*realismLoss``.
            useTensorboard: create a SummaryWriter under ``logDir``.
        """
        super().__init__()

        self.printEvery = printEvery
        self.logDir = logDir
        self.lossRatio = lossRatio  # (1-a)*dissimLoss + a*realismLoss
        self.currentEpoch = 0
        self.totalBatches = 0

        self.P = Projector()
        # pre-trained G and D !
        self.G = Generator()
        self.D = Discriminator()

        # once hook is attached, activations will be pushed to self.activations
        self.D.attachLayerHook(self.D.layer3)

        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')

            self.G = self.G.to(self.device)
            self.D = self.D.to(self.device)
            self.P = self.P.to(self.device)

            # parallelize models over the available devices (at most
            # _MAX_GPUS), splitting input on batch dimension
            deviceIds = list(
                range(min(self._MAX_GPUS, torch.cuda.device_count())))
            self.G = torch.nn.DataParallel(self.G, device_ids=deviceIds)
            self.D = torch.nn.DataParallel(self.D, device_ids=deviceIds)
            self.P = torch.nn.DataParallel(self.P, device_ids=deviceIds)

        self.optim = torch.optim.Adam(self.P.parameters(),
                                      lr=0.0005,
                                      betas=(0.5, 0.999))

        self.load(resume=resume)

        self.useTensorboard = useTensorboard
        self.tensorGraphInitialized = False
        self.writer = None
        if useTensorboard:
            self.writer = SummaryWriter(
                os.path.join(self.logDir, 'tensorboard'))

    @staticmethod
    def _unwrap(model):
        """Return the module inside a DataParallel wrapper, or ``model``."""
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    def _loadStateInto(self, model, fileName):
        """Load ``logDir/fileName`` into ``model``.

        Tensors are mapped onto ``self.device`` so a checkpoint written on a
        GPU can be read on a CPU-only host. If the stored keys differ from
        the model's keys only by the ``module.`` prefix that DataParallel
        adds, the prefix is stripped or added to match.
        """
        state = torch.load(os.path.join(self.logDir, fileName),
                           map_location=self.device)

        expected = set(model.state_dict().keys())
        if set(state.keys()) != expected:
            prefix = 'module.'
            stripped = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state.items()
            }
            prefixed = {prefix + k: v for k, v in state.items()}
            if set(stripped.keys()) == expected:
                state = stripped
            elif set(prefixed.keys()) == expected:
                state = prefixed
            # otherwise fall through: load_state_dict reports the mismatch

        model.load_state_dict(state)

    def load(self, resume=False):
        """Load generator and discriminator and (maybe) projector weights.

        Args:
            resume: when true, also restore the projector, the optimiser
                state and the epoch/batch counters.
        """
        self._loadStateInto(self.G, 'generator.pth')
        self._loadStateInto(self.D, 'discrim.pth')

        if resume:
            self._loadStateInto(self.P, 'projector.pth')
            self.optim.load_state_dict(
                torch.load(os.path.join(self.logDir, 'optimP.pth'),
                           map_location=self.device))

            with open(os.path.join(self.logDir, 'recent-projector.log'),
                      'r') as f:
                runData = json.load(f)

            self.currentEpoch = runData['epoch']
            self.totalBatches = runData['totalBatches']

    def save(self):
        """Write projector weights, optimiser state and counters to logDir."""
        logTable = {
            'epoch': self.currentEpoch,
            'totalBatches': self.totalBatches
        }

        torch.save(self.P.state_dict(),
                   os.path.join(self.logDir, 'projector.pth'))
        torch.save(self.optim.state_dict(),
                   os.path.join(self.logDir, 'optimP.pth'))

        with open(os.path.join(self.logDir, 'recent-projector.log'), 'w') as f:
            f.write(json.dumps(logTable))

        tqdm.write('======== SAVED RECENT MODEL ========')

    def train(self, trainData):
        """Run one epoch over ``trainData`` and update the projector.

        Args:
            trainData: iterable of samples; each sample provides
                ``sample['data']['62']``, which is copied into the interior
                of a zero-padded ``(batch, 64, 64, 64)`` grid.

        Raises:
            RuntimeError: if the discriminator hook did not record a matching
                set of activations for the generated and the real pass.
        """
        numBatches = 0
        self.P.train()
        discrim = self._unwrap(self.D)

        for i, sample in enumerate(tqdm(trainData)):
            data = sample['data']

            # The optimiser owns exactly P's parameters, so this one call
            # clears every gradient that the step below will use.
            self.optim.zero_grad()

            voxels = data['62'].to(self.device)
            inputVoxels = torch.zeros(voxels.shape[0],
                                      64,
                                      64,
                                      64,
                                      device=self.device)
            inputVoxels[:, 1:-1, 1:-1, 1:-1] = voxels

            #TODO: randomly drop 50% of voxels.

            z = self.P(inputVoxels)

            outputVoxels = self.G(z)

            # Drop anything recorded by an earlier forward pass (e.g. the
            # graph export below) so only this batch's activations remain.
            discrim.activations = []

            ### Realism loss: run through the GAN, use output as "Real" vs Not
            genD = self.D(outputVoxels)
            realD = self.D(inputVoxels)
            realismLoss = -torch.mean(
                torch.log(torch.clamp(realD, min=self._LOG_EPS)) +
                torch.log(torch.clamp(1. - genD, min=self._LOG_EPS)))

            ### Dissimilarity loss

            # Fetch layer 3 activations ("We specifically
            # select the output of the 256 × 8 × 8 × 8 layer")
            acts = discrim.activations
            discrim.activations = []

            if not acts or len(acts) % 2 != 0:
                raise RuntimeError(
                    'expected an even, non-zero number of hooked '
                    f'activations, got {len(acts)}')

            # First half: generated pass, second half: real pass. Each half
            # holds one tensor per replica, for any number of devices.
            # NOTE(review): with several replicas the order inside one half
            # depends on how the hook appends -- confirm it follows the
            # batch split.
            half = len(acts) // 2
            actGen = torch.cat([a.to(self.device) for a in acts[:half]],
                               dim=0)
            actReal = torch.cat([a.to(self.device) for a in acts[half:]],
                                dim=0)

            dissimilarityLoss = torch.mean(torch.abs(actGen - actReal))

            loss = dissimilarityLoss * (1.0 - self.lossRatio) + realismLoss * (
                self.lossRatio)

            loss.backward()
            self.optim.step()

            #log
            numBatches += 1
            if i % self.printEvery == 0:
                tqdm.write(
                    f'[TRAIN] Epoch {self.currentEpoch:03d}, Batch {i:03d}: '
                    f'Dissim Loss: {float(dissimilarityLoss.item()):2.3f}, Realism = {float(realismLoss.item()):2.3f}'
                )

                if (self.useTensorboard):
                    step = numBatches + self.totalBatches
                    self.writer.add_scalar('DissimLoss/train',
                                           dissimilarityLoss.item(), step)
                    self.writer.add_scalar('RealismLoss/train',
                                           realismLoss.item(), step)
                    self.writer.add_scalar('Loss/train', loss.item(), step)
                    self.writer.flush()

                    if not self.tensorGraphInitialized:
                        # ones_like keeps shape, dtype and device of the real
                        # tensors, so the traced modules see valid input.
                        self.writer.add_graph(self._unwrap(self.P),
                                              torch.ones_like(inputVoxels))
                        self.writer.flush()

                        self.writer.add_graph(self._unwrap(self.G),
                                              torch.ones_like(z))
                        self.writer.flush()

                        self.writer.add_graph(discrim,
                                              torch.ones_like(outputVoxels))
                        self.writer.flush()

                        # Tracing D ran a forward pass; discard what the
                        # hook recorded during it.
                        discrim.activations = []

                        self.tensorGraphInitialized = True

        self.currentEpoch += 1
        self.totalBatches += numBatches
示例#2
0
def train(args):
    """Train a generator/discriminator pair for ``args.iters`` iterations.

    Reads from ``args``: ``nz`` (latent size), ``data_path``, ``normalized``,
    ``out`` (output directory), ``bs`` (batch size), ``lr`` and ``iters``.

    Every 5000 iterations the current losses are printed, both state dicts
    are written to ``args.out`` and a generated/real sample pair is rendered
    with ``display_noise``.
    """
    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_str)

    gen = Generator(args.nz, 800)
    gen = gen.to(device)
    gen.apply(weights_init)

    discriminator = Discriminator(800)
    discriminator = discriminator.to(device)
    discriminator.apply(weights_init)

    bce = nn.BCELoss()
    bce = bce.to(device)

    galaxy_dataset = GalaxySet(args.data_path,
                               normalized=args.normalized,
                               out=args.out)
    loader = DataLoader(galaxy_dataset,
                        batch_size=args.bs,
                        shuffle=True,
                        num_workers=2,
                        drop_last=True)
    loader_iter = iter(loader)

    d_optimizer = Adam(discriminator.parameters(),
                       betas=(0.5, 0.999),
                       lr=args.lr)
    g_optimizer = Adam(gen.parameters(), betas=(0.5, 0.999), lr=args.lr)

    real_labels = to_var(torch.ones(args.bs), device_str)
    fake_labels = to_var(torch.zeros(args.bs), device_str)
    # single fixed latent vector, reused for every preview image
    fixed_noise = to_var(torch.randn(1, args.nz), device_str)

    for i in tqdm(range(args.iters)):
        # Use the builtin next(): the iterator protocol has no .next()
        # method in Python 3, and current DataLoader iterators do not
        # provide one either.
        try:
            batch_data = next(loader_iter)
        except StopIteration:
            # loader exhausted: start another pass over the dataset
            loader_iter = iter(loader)
            batch_data = next(loader_iter)

        batch_data = to_var(batch_data, device).unsqueeze(1)

        # keep every second sample of the first 1600, flattened to 800 values
        batch_data = batch_data[:, :, :1600:2]
        batch_data = batch_data.view(-1, 800)

        ### Train Discriminator ###

        d_optimizer.zero_grad()

        # train Infer with real
        pred_real = discriminator(batch_data)
        d_loss = bce(pred_real, real_labels)

        # train infer with fakes
        z = to_var(torch.randn((args.bs, args.nz)), device)
        fakes = gen(z)
        # detach: the discriminator step must not backpropagate into gen
        pred_fake = discriminator(fakes.detach())
        d_loss += bce(pred_fake, fake_labels)

        d_loss.backward()

        d_optimizer.step()

        ### Train Gen ###

        g_optimizer.zero_grad()

        z = to_var(torch.randn((args.bs, args.nz)), device)
        fakes = gen(z)
        pred_fake = discriminator(fakes)
        # generator is rewarded when its fakes are classified as real
        gen_loss = bce(pred_fake, real_labels)

        gen_loss.backward()
        g_optimizer.step()

        if i % 5000 == 0:
            print("Iteration %d >> g_loss: %.4f., d_loss: %.4f." %
                  (i, gen_loss.item(), d_loss.item()))
            # The file index is fixed to 0, so each checkpoint overwrites
            # the previous one.
            torch.save(gen.state_dict(),
                       os.path.join(args.out, 'gen_%d.pkl' % 0))
            torch.save(discriminator.state_dict(),
                       os.path.join(args.out, 'disc_%d.pkl' % 0))
            gen.eval()
            # preview only: no autograd graph is needed for this pass
            with torch.no_grad():
                fixed_fake = gen(fixed_noise).cpu().numpy()
            real_data = batch_data[0].detach().cpu().numpy()
            gen.train()
            display_noise(fixed_fake.squeeze(),
                          os.path.join(args.out, "gen_sample_%d.png" % i))
            display_noise(real_data.squeeze(),
                          os.path.join(args.out, "real_%d.png" % 0))
示例#3
0
class GAN3DTrainer(object):
    """Adversarially train a voxel generator ``G`` and discriminator ``D``.

    Files used inside ``logDir``: ``generator.pth``, ``discrim.pth``,
    ``optimG.pth``, ``optimD.pth``, ``recent.log`` (epoch/batch counters as
    JSON), ``trainStats.pkl`` and, when enabled, ``tensorboard/``.
    """

    # Lower bound applied before taking logarithms in the losses so that a
    # saturated discriminator output cannot yield inf/NaN.
    _LOG_EPS = 1e-8
    # The original setup parallelised over GPUs 0 and 1; keep that upper
    # bound but do not require that both devices exist.
    _MAX_GPUS = 2

    def __init__(self,
                 logDir,
                 printEvery=1,
                 resume=False,
                 useTensorboard=True):
        """Build the models, both optimisers and (optionally) the writer.

        Args:
            logDir: directory holding checkpoints and logs.
            printEvery: log every ``printEvery``-th batch.
            resume: restore models, optimisers, stats and counters.
            useTensorboard: create a SummaryWriter under ``logDir``.
        """
        super(GAN3DTrainer, self).__init__()

        self.logDir = logDir

        self.currentEpoch = 0
        self.totalBatches = 0

        self.trainStats = {'lossG': [], 'lossD': [], 'accG': [], 'accD': []}

        self.printEvery = printEvery

        self.G = Generator()
        self.D = Discriminator()

        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')

            self.G = self.G.to(self.device)
            self.D = self.D.to(self.device)

            # parallelize models over the available devices (at most
            # _MAX_GPUS), splitting input on batch dimension
            deviceIds = list(
                range(min(self._MAX_GPUS, torch.cuda.device_count())))
            self.G = torch.nn.DataParallel(self.G, device_ids=deviceIds)
            self.D = torch.nn.DataParallel(self.D, device_ids=deviceIds)

        # optim params direct from paper
        self.optimG = torch.optim.Adam(self.G.parameters(),
                                       lr=0.0025,
                                       betas=(0.5, 0.999))

        self.optimD = torch.optim.Adam(self.D.parameters(),
                                       lr=0.00005,
                                       betas=(0.5, 0.999))

        if resume:
            self.load()

        self.useTensorboard = useTensorboard
        self.tensorGraphInitialized = False
        self.writer = None
        if useTensorboard:
            self.writer = SummaryWriter(
                os.path.join(self.logDir, 'tensorboard'))

    @staticmethod
    def _unwrap(model):
        """Return the module inside a DataParallel wrapper, or ``model``."""
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    def _sampleLatent(self, batchSize):
        """Draw a ``(batchSize, 200)`` latent batch from N(0, 0.33)."""
        noise = torch.normal(torch.zeros(batchSize, 200),
                             torch.ones(batchSize, 200) * 0.33)
        return noise.to(self.device)

    def _loadStateInto(self, model, fileName):
        """Load ``logDir/fileName`` into ``model``.

        Tensors are mapped onto ``self.device`` so a checkpoint written on a
        GPU can be read on a CPU-only host. If the stored keys differ from
        the model's keys only by the ``module.`` prefix that DataParallel
        adds, the prefix is stripped or added to match.
        """
        state = torch.load(os.path.join(self.logDir, fileName),
                           map_location=self.device)

        expected = set(model.state_dict().keys())
        if set(state.keys()) != expected:
            prefix = 'module.'
            stripped = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state.items()
            }
            prefixed = {prefix + k: v for k, v in state.items()}
            if set(stripped.keys()) == expected:
                state = stripped
            elif set(prefixed.keys()) == expected:
                state = prefixed
            # otherwise fall through: load_state_dict reports the mismatch

        model.load_state_dict(state)

    def train(self, trainData: torch.utils.data.DataLoader):
        """Run one epoch of alternating discriminator/generator updates.

        The discriminator is only stepped while its batch accuracy is below
        0.8; the generator is stepped on every batch.

        Note: ``self.trainStats`` is persisted by ``save``/``load`` but is
        not appended to here.

        Args:
            trainData: loader whose samples provide
                ``sample['data']['62']``, which is copied into the interior
                of a zero-padded ``(batch, 64, 64, 64)`` grid.
        """
        numBatches = 0

        self.G.train()
        self.D.train()

        for i, sample in enumerate(tqdm(trainData)):
            data = sample['data']
            batchSize = data['62'].shape[0]

            # Each optimiser owns all parameters of its model, so these two
            # calls clear every gradient used by the steps below.
            self.optimG.zero_grad()
            self.optimD.zero_grad()

            realVoxels = torch.zeros(batchSize,
                                     64,
                                     64,
                                     64,
                                     device=self.device)
            realVoxels[:, 1:-1, 1:-1, 1:-1] = data['62'].to(self.device)

            # discriminator train
            z = self._sampleLatent(batchSize)

            # The discriminator update needs no gradient w.r.t. G, so the
            # fakes are produced without building G's autograd graph.
            with torch.no_grad():
                fakeVoxels = self.G(z)
            fakeD = self.D(fakeVoxels)
            realD = self.D(realVoxels)

            lossD = -torch.mean(
                torch.log(torch.clamp(realD, min=self._LOG_EPS)) +
                torch.log(torch.clamp(1. - fakeD, min=self._LOG_EPS)))
            accD = ((realD >= .5).float().mean() +
                    (fakeD < .5).float().mean()) / 2.
            accG = (fakeD > .5).float().mean()

            # only train if Disc wrong enough :)
            if accD < .8:
                lossD.backward()
                self.optimD.step()

            # gen train
            z = self._sampleLatent(batchSize)

            fakeVoxels = self.G(z)
            fakeD = self.D(fakeVoxels)

            # https://arxiv.org/pdf/1706.05170.pdf (IV. Methods, A. Training the gen model)
            lossG = -torch.mean(
                torch.log(torch.clamp(fakeD, min=self._LOG_EPS)))

            # G's gradients are still zero here (the discriminator pass
            # above never touched G); gradients this adds to D are cleared
            # at the top of the next iteration.
            lossG.backward()
            self.optimG.step()

            #log
            numBatches += 1
            if i % self.printEvery == 0:
                tqdm.write(
                    f'[TRAIN] Epoch {self.currentEpoch:03d}, Batch {i:03d}: '
                    f'gen: {float(accG.item()):2.3f}, dis = {float(accD.item()):2.3f}'
                )

                if (self.useTensorboard):
                    step = numBatches + self.totalBatches
                    self.writer.add_scalar('GenLoss/train', lossG.item(),
                                           step)
                    self.writer.add_scalar('DisLoss/train', lossD.item(),
                                           step)
                    self.writer.add_scalar('GenAcc/train', accG.item(), step)
                    self.writer.add_scalar('DisAcc/train', accD.item(), step)
                    self.writer.flush()

                    if not self.tensorGraphInitialized:
                        # Trace with inputs that match what the modules get
                        # during training: same shape and same device as z
                        # and fakeVoxels.
                        self.writer.add_graph(self._unwrap(self.G),
                                              torch.rand_like(z))
                        self.writer.flush()

                        self.writer.add_graph(self._unwrap(self.D),
                                              fakeVoxels.detach())
                        self.writer.flush()

                        self.tensorGraphInitialized = True

        self.currentEpoch += 1
        self.totalBatches += numBatches

    def save(self):
        """Write models, optimisers, counters and trainStats to logDir."""
        logTable = {
            'epoch': self.currentEpoch,
            'totalBatches': self.totalBatches
        }

        torch.save(self.G.state_dict(),
                   os.path.join(self.logDir, 'generator.pth'))
        torch.save(self.D.state_dict(), os.path.join(self.logDir,
                                                     'discrim.pth'))
        torch.save(self.optimG.state_dict(),
                   os.path.join(self.logDir, 'optimG.pth'))
        torch.save(self.optimD.state_dict(),
                   os.path.join(self.logDir, 'optimD.pth'))

        with open(os.path.join(self.logDir, 'recent.log'), 'w') as f:
            f.write(json.dumps(logTable))

        # context manager so the file is flushed and closed deterministically
        with open(os.path.join(self.logDir, 'trainStats.pkl'), 'wb') as f:
            pickle.dump(self.trainStats, f)

        tqdm.write('======== SAVED RECENT MODEL ========')

    def load(self):
        """Restore models, optimisers, counters and trainStats from logDir."""
        self._loadStateInto(self.G, 'generator.pth')
        self._loadStateInto(self.D, 'discrim.pth')
        self.optimG.load_state_dict(
            torch.load(os.path.join(self.logDir, 'optimG.pth'),
                       map_location=self.device))
        self.optimD.load_state_dict(
            torch.load(os.path.join(self.logDir, 'optimD.pth'),
                       map_location=self.device))

        with open(os.path.join(self.logDir, 'recent.log'), 'r') as f:
            runData = json.load(f)

        # SECURITY: pickle can execute arbitrary code while loading; only
        # point logDir at directories written by this trainer.
        with open(os.path.join(self.logDir, 'trainStats.pkl'), 'rb') as f:
            self.trainStats = pickle.load(f)

        self.currentEpoch = runData['epoch']
        self.totalBatches = runData['totalBatches']