    def train(self, epochs, load=False):
        print('TRAINING MODEL:'
              ' BATCH_SIZE = ' + str(Trainer.BATCH_SIZE) + ', PARTICLE_DIM: ' +
              str(PARTICLE_DIM) + ', EPOCHS: ' + str(epochs) +
              ', PRTCL_LATENT_SPACE_SIZE: ' +
              str(self.PRTCL_LATENT_SPACE_SIZE))

        encoder, decoder, discriminator = self.create_model()
        enc_optim = torch.optim.Adam(encoder.parameters(), lr=self.LR)
        dec_optim = torch.optim.Adam(decoder.parameters(), lr=self.LR)
        dis_optim = torch.optim.Adam(discriminator.parameters(), lr=self.LR)

        embedder = self.create_embedder()
        deembedder = self.create_deembedder()

        if load:
            print('LOADING MODEL STATES...')
            encoder.load_state_dict(torch.load(Trainer.ENCODER_SAVE_PATH))
            decoder.load_state_dict(torch.load(Trainer.DECODER_SAVE_PATH))
            discriminator.load_state_dict(
                torch.load(Trainer.DISCRIMINATOR_SAVE_PATH))
            embedder.load_state_dict(torch.load(Trainer.PDG_EMBED_SAVE_PATH))
            deembedder.load_state_dict(
                torch.load(Trainer.PDG_DEEMBED_SAVE_PATH))

        print('AUTOENCODER')
        print(encoder)
        print(decoder)
        print(discriminator)
        print('EMBEDDER')
        print(embedder)
        print('DEEMBEDDER')
        print(deembedder)

        _data = load_data()
        data_train, data_valid = self.prep_data(_data,
                                                batch_size=Trainer.BATCH_SIZE,
                                                valid=0.1)

        particles = torch.tensor(particle_idxs(), device=self.device)
        particles.requires_grad = False

        for epoch in range(epochs):

            for n_batch, batch in enumerate(data_train):
                encoder.zero_grad()
                decoder.zero_grad()
                discriminator.zero_grad()

                real_data: torch.Tensor = batch.to(self.device)
                emb_data = self.embed_data(real_data, [embedder]).detach()

                batch_size = len(batch)

                zeros = torch.zeros(batch_size,
                                    device=self.device,
                                    requires_grad=False)
                ones = torch.ones(batch_size,
                                  device=self.device,
                                  requires_grad=False)

                # ======== Train Discriminator ======== #
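                # Adversarial-autoencoder regularization: the discriminator
                # learns to assign label 0 to latent vectors drawn from the
                # N(0, 1) prior and label 1 to latent codes produced by the
                # encoder.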
                decoder.freeze(True)
                encoder.freeze(True)
                discriminator.freeze(False)

                lat_fake = torch.randn(batch_size,
                                       self.PRTCL_LATENT_SPACE_SIZE,
                                       device=self.device)
                disc_fake = discriminator(lat_fake)

                lat_real = encoder(emb_data)
                disc_real = discriminator(lat_real)

                loss_fake = MSELoss()(disc_fake, zeros)
                loss_real = MSELoss()(disc_real, ones)

                loss_fake.backward()
                loss_real.backward()

                dis_optim.step()

                # ======== Train Generator ======== #
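                # The encoder/decoder pair is updated with a reconstruction
                # loss on the embedded data plus an adversarial term that
                # pushes encoder outputs towards the prior label (zeros), i.e.
                # towards being indistinguishable from prior samples.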
                decoder.freeze(False)
                encoder.freeze(False)
                discriminator.freeze(True)

                lat_real = encoder(emb_data)
                recon_data = decoder(lat_real)
                d_real = discriminator(encoder(emb_data))

                recon_loss = MSELoss()(emb_data, recon_data)
                d_loss = MSELoss()(d_real, zeros)

                recon_loss.backward()
                d_loss.backward()

                enc_optim.step()
                dec_optim.step()

                # Co-train the PDG de-embedder against the current embedder
                # on the particle-type indices.
                self.train_deembeders(
                    [(particles, embedder, deembedder)], epochs=2)

                if n_batch % 100 == 0:
                    self.print_deemb_quality(particles, embedder, deembedder)

                    self.show_heatmaps(emb_data[:30, :],
                                       recon_data[:30, :],
                                       reprod=False,
                                       save=True,
                                       epoch=epoch,
                                       batch=n_batch)
                    err_kld, err_wass = self.gen_show_comp_hists(
                        decoder,
                        _data,
                        attr_idxs=[
                            FEATURES - 8, FEATURES - 7, FEATURES - 6,
                            FEATURES - 5
                        ],
                        embedders=[embedder],
                        emb=False,
                        deembedder=deembedder,
                        save=True,
                        epoch=epoch,
                        batch=n_batch)

                    self.errs_kld.append(err_kld)
                    self.errs_wass.append(err_wass)

                    valid_loss = self._valid_loss(encoder, decoder, embedder,
                                                  data_valid)

                    print(
                        f'Epoch: {epoch}/{epochs} :: '
                        f'Batch: {n_batch}/{len(data_train)} :: '
                        f'train loss: {recon_loss.item():.6f} :: '
                        f'valid loss: {valid_loss:.6f} :: '
                        f'err kld: {err_kld:.6f} :: '
                        f'err wass: {err_wass:.6f}')

            self._save_models(encoder, decoder, discriminator, embedder,
                              deembedder)

            with open(self.ERRS_SAVE_PATH, 'wb') as handle:
                pickle.dump((self.errs_kld, self.errs_wass),
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        return encoder, decoder, discriminator, embedder, deembedder
    def load_trans_data(self):
        return load_data()
import torch

from common.consts import FEATURES
from common.load_event_data import sgnlog_rev
from i_trainer.load_data import load_data, sgnlog
from single_prtcl_generator_wgan_pytorch.trainer import Trainer

### TRAINING
trainer = Trainer()
#generator, discriminator, embedder, deembedder = trainer.train(epochs=15, load=False)
wgan_gen, wgan_disc = trainer.create_model()
wgan_gen.load_state_dict(torch.load(trainer.GENERATOR_SAVE_PATH))

embedder = trainer.create_embedder()
deembedder = trainer.create_deembedder()

_data = load_data()

_data = _data[:10000, :]

trainer.gen_show_comp_hists(
    wgan_gen,
    _data,
    attr_idxs=[7 - 4, 8 - 4, 9 - 4, FEATURES - 5],
    embedders=[embedder],
    emb=False,
    sample_cnt=1000,
    deembedder=deembedder,
)
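
# A minimal sampling sketch (an assumption, not part of the original script):
# it presumes that wgan_gen maps latent vectors of size
# trainer.PRTCL_LATENT_SPACE_SIZE to embedded particle features, and that the
# trainer exposes a `device` attribute, as in the training loops above.
with torch.no_grad():
    noise = torch.randn(1000,
                        trainer.PRTCL_LATENT_SPACE_SIZE,
                        device=trainer.device)
    generated = wgan_gen(noise)
print('generated sample shape:', tuple(generated.shape))
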
Example #4
    def train(self, epochs, load=False):
        print('TRAINING MODEL:'
              ' BATCH_SIZE = ' + str(self.BATCH_SIZE) + ', PARTICLE_DIM: ' +
              str(PARTICLE_DIM) + ', EPOCHS: ' + str(epochs) +
              ', PRTCL_LATENT_SPACE_SIZE: ' +
              str(self.PRTCL_LATENT_SPACE_SIZE))

        generator, discriminator = self.create_model()
        gen_optim = Adam(generator.parameters(), lr=self.LR, betas=(.1, .9))
        dis_optim = Adam(discriminator.parameters(),
                         lr=self.LR,
                         betas=(.1, .9))

        embedder = self.create_embedder()
        deembedder = self.create_deembedder()

        if load:
            print('LOADING MODEL STATES...')
            try:
                generator.load_state_dict(
                    torch.load(Trainer.GENERATOR_SAVE_PATH))
            except Exception:
                print('Problem loading generator!')
                pass

            try:
                discriminator.load_state_dict(
                    torch.load(Trainer.CRITIC_SAVE_PATH))
            except Exception:
                print('Problem loading critic!')
                pass

            try:
                embedder.load_state_dict(
                    torch.load(Trainer.PDG_EMBED_SAVE_PATH))
            except Exception:
                print('Problem loading embedder!')
                pass

            try:
                deembedder.load_state_dict(
                    torch.load(Trainer.PDG_DEEMBED_SAVE_PATH))
            except Exception:
                print('Problem loading deembedder!')
                pass

        print('GENERATOR')
        print(generator)
        print('DISCRIMINATOR')
        print(discriminator)
        print('EMBEDDER')
        print(embedder)
        print('DEEMBEDDER')
        print(deembedder)

        _data = load_data()
        data_train, data_valid = self.prep_data(_data,
                                                batch_size=self.BATCH_SIZE,
                                                valid=0.1)

        particles = torch.tensor(particle_idxs(), device=self.device)
        particles.requires_grad = False

        for epoch in range(epochs):

            for n_batch, batch in enumerate(data_train):

                real_data: torch.Tensor = batch.to(self.device)
                emb_data = self.embed_data(real_data, [embedder]).detach()

                batch_size = len(batch)

                # ======== Train Generator ======== #
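                # WGAN generator update: minimize -E[D(G(z))], i.e. raise the
                # critic's score on generated samples.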

                gen_optim.zero_grad()

                for p in discriminator.parameters():
                    p.requires_grad = False  # to avoid computation

                # Sample noise as generator input
                lat_fake = torch.randn(batch_size,
                                       self.PRTCL_LATENT_SPACE_SIZE,
                                       device=self.device)
                gen_data = generator(lat_fake)

                output = discriminator(gen_data)
                gen_loss = -torch.mean(output)

                gen_loss.backward()
                gen_optim.step()

                for p in discriminator.parameters():
                    p.requires_grad = True  # re-enable gradients for the critic update

                # ======== Train Discriminator ======== #
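                # WGAN-GP critic update, repeated CRITIC_ITERATIONS times:
                # minimize -(E[D(real)] - E[D(fake)]) plus the gradient
                # penalty term gp returned by calc_gradient_penalty.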
                for _ in range(self.CRITIC_ITERATIONS):
                    critic_real = discriminator(emb_data)
                    critic_fake = discriminator(gen_data.detach())
                    gp = self.calc_gradient_penalty(discriminator, emb_data,
                                                    gen_data, batch_size)
                    critic_loss = -(torch.mean(critic_real) -
                                    torch.mean(critic_fake)) + gp

                    critic_loss.backward()
                    dis_optim.step()

                    dis_optim.zero_grad()

                self.train_deembeders([
                    (particles, embedder, deembedder),
                ],
                                      epochs=2)

                if n_batch % 100 == 0:
                    self.print_deemb_quality(particles, embedder, deembedder)

                    self.show_heatmaps(emb_data[:30, :],
                                       gen_data[:30, :],
                                       reprod=False,
                                       save=True,
                                       epoch=epoch,
                                       batch=n_batch)
                    err_kld, err_wass = self.gen_show_comp_hists(
                        generator,
                        _data,
                        attr_idxs=[
                            FEATURES - 8, FEATURES - 7, FEATURES - 6,
                            FEATURES - 5
                        ],
                        embedders=[embedder],
                        emb=False,
                        deembedder=deembedder,
                        save=True,
                        epoch=epoch,
                        batch=n_batch)

                    self.errs_kld.append(err_kld)
                    self.errs_wass.append(err_wass)

                    print(
                        f'Epoch: {epoch}/{epochs} :: '
                        f'Batch: {n_batch}/{len(data_train)} :: '
                        f'generator loss: {gen_loss.item():.6f} :: '
                        f'critic loss: {critic_loss.item():.6f} :: '
                        f'err kld: {err_kld:.6f} :: '
                        f'err wass: {err_wass:.6f}')

            self._save_models(generator, discriminator, embedder, deembedder)

            with open(self.ERRS_SAVE_PATH, 'wb') as handle:
                pickle.dump((self.errs_kld, self.errs_wass),
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        return generator, discriminator, embedder, deembedder
Example #5
    def train(self, epochs, load=False):
        print('TRAINING MODEL:'
              ' BATCH_SIZE = ' + str(self.BATCH_SIZE) + ', PARTICLE_DIM: ' +
              str(PARTICLE_DIM) + ', EPOCHS: ' + str(epochs) +
              ', PRTCL_LATENT_SPACE_SIZE: ' +
              str(self.PRTCL_LATENT_SPACE_SIZE))

        generator, discriminator = self.create_model()
        gen_optim = torch.optim.Adam(generator.parameters(),
                                     lr=self.LR,
                                     betas=(0, .9))
        dis_optim = torch.optim.Adam(discriminator.parameters(),
                                     lr=self.LR,
                                     betas=(0, .9))

        embedder = self.create_embedder()
        deembedder = self.create_deembedder()

        particles = torch.tensor(particle_idxs(), device=self.device)

        if load:
            print('LOADING MODEL STATES...')
            generator.load_state_dict(torch.load(Trainer.GENERATOR_SAVE_PATH))
            discriminator.load_state_dict(
                torch.load(Trainer.DISCRIMINATOR_SAVE_PATH))
            embedder.load_state_dict(torch.load(Trainer.PDG_EMBED_SAVE_PATH))
            deembedder.load_state_dict(
                torch.load(Trainer.PDG_DEEMBED_SAVE_PATH))

        print('GENERATOR')
        print(generator)
        print('DISCRIMINATOR')
        print(discriminator)
        print('EMBEDDER')
        print(embedder)
        print('DEEMBEDDER')
        print(deembedder)

        _data = load_data()
        data_train, data_valid = self.prep_data(_data,
                                                batch_size=self.BATCH_SIZE,
                                                valid=0.1)

        for epoch in range(epochs):

            for n_batch, batch in enumerate(data_train):

                real_data: torch.Tensor = batch.to(self.device)
                emb_data = self.embed_data(real_data, [embedder]).detach()

                batch_size = len(batch)

                valid = torch.ones(batch_size,
                                   device=self.device,
                                   requires_grad=False)
                fake = torch.zeros(batch_size,
                                   device=self.device,
                                   requires_grad=False)

                # ======== Train Generator ======== #
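                # Least-squares adversarial loss: the generator is trained so
                # that the discriminator scores its samples as "valid" (ones).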
                gen_optim.zero_grad()

                # Sample noise as generator input
                lat_fake = torch.randn(batch_size,
                                       self.PRTCL_LATENT_SPACE_SIZE,
                                       device=self.device)
                # Generate a batch of images
                gen_data = generator(lat_fake)

                # Loss measures generator's ability to fool the discriminator
                g_loss = MSELoss()(discriminator(gen_data), valid)

                g_loss.backward()
                gen_optim.step()

                # ======== Train Discriminator ======== #
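                # The discriminator is fit with MSE: "valid" targets for real
                # (embedded) data, "fake" targets for detached generator
                # samples; the two terms are averaged.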
                dis_optim.zero_grad()

                real_loss = MSELoss()(discriminator(emb_data), valid)
                fake_loss = MSELoss()(discriminator(gen_data.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2

                d_loss.backward()
                dis_optim.step()

                self.train_deembeders([
                    (particles, embedder, deembedder),
                ],
                                      epochs=2)

                if n_batch % 100 == 0:
                    self.print_deemb_quality(particles, embedder, deembedder)

                    self.show_heatmaps(emb_data[:30, :],
                                       gen_data[:30, :],
                                       reprod=False,
                                       save=True,
                                       epoch=epoch,
                                       batch=n_batch)
                    err_kld, err_wass = self.gen_show_comp_hists(
                        generator,
                        _data,
                        attr_idxs=[
                            FEATURES - 8, FEATURES - 7, FEATURES - 6,
                            FEATURES - 5
                        ],
                        embedders=[embedder],
                        emb=False,
                        deembedder=deembedder,
                        save=True,
                        epoch=epoch,
                        batch=n_batch)

                    self.errs_kld.append(err_kld)
                    self.errs_wass.append(err_wass)

                    print(
                        f'Epoch: {epoch}/{epochs} :: '
                        f'Batch: {n_batch}/{len(data_train)} :: '
                        f'generator loss: {g_loss.item():.6f} :: '
                        f'discriminator loss: {d_loss.item():.6f} :: '
                        f'err kld: {err_kld:.6f} :: '
                        f'err wass: {err_wass:.6f}')

            self._save_models(generator, discriminator, embedder, deembedder)

            with open(self.ERRS_SAVE_PATH, 'wb') as handle:
                pickle.dump((self.errs_kld, self.errs_wass),
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        return generator, discriminator, embedder, deembedder