def train(self, epochs, load=False):
    print('TRAINING MODEL:'
          ' BATCH_SIZE = ' + str(Trainer.BATCH_SIZE) +
          ', PARTICLE_DIM: ' + str(PARTICLE_DIM) +
          ', EPOCHS: ' + str(epochs) +
          ', PRTCL_LATENT_SPACE_SIZE: ' + str(self.PRTCL_LATENT_SPACE_SIZE))

    encoder, decoder, discriminator = self.create_model()
    enc_optim = torch.optim.Adam(encoder.parameters(), lr=self.LR)
    dec_optim = torch.optim.Adam(decoder.parameters(), lr=self.LR)
    dis_optim = torch.optim.Adam(discriminator.parameters(), lr=self.LR)

    embedder = self.create_embedder()
    deembedder = self.create_deembedder()

    if load:
        print('LOADING MODEL STATES...')
        encoder.load_state_dict(torch.load(Trainer.ENCODER_SAVE_PATH))
        decoder.load_state_dict(torch.load(Trainer.DECODER_SAVE_PATH))
        discriminator.load_state_dict(
            torch.load(Trainer.DISCRIMINATOR_SAVE_PATH))
        embedder.load_state_dict(torch.load(Trainer.PDG_EMBED_SAVE_PATH))
        deembedder.load_state_dict(
            torch.load(Trainer.PDG_DEEMBED_SAVE_PATH))

    print('AUTOENCODER')
    print(encoder)
    print(decoder)
    print(discriminator)
    print('EMBEDDER')
    print(embedder)
    print('DEEMBEDDER')
    print(deembedder)

    _data = load_data()
    data_train, data_valid = self.prep_data(
        _data, batch_size=Trainer.BATCH_SIZE, valid=0.1)

    particles = torch.tensor(particle_idxs(), device=self.device)
    particles.requires_grad = False

    for epoch in range(epochs):
        for n_batch, batch in enumerate(data_train):
            encoder.zero_grad()
            decoder.zero_grad()
            discriminator.zero_grad()

            real_data: torch.Tensor = batch.to(self.device)
            emb_data = self.embed_data(real_data, [embedder]).detach()
            batch_size = len(batch)

            zeros = torch.zeros(batch_size, device=self.device,
                                requires_grad=False)
            ones = torch.ones(batch_size, device=self.device,
                              requires_grad=False)

            # ======== Train Discriminator ======== #
            # The latent critic is trained to score samples drawn from the
            # prior as 0 and encoded (real) samples as 1.
            decoder.freeze(True)
            encoder.freeze(True)
            discriminator.freeze(False)

            lat_fake = torch.randn(batch_size, self.PRTCL_LATENT_SPACE_SIZE,
                                   device=self.device)
            disc_fake = discriminator(lat_fake)

            lat_real = encoder(emb_data)
            disc_real = discriminator(lat_real)

            loss_fake = MSELoss()(disc_fake, zeros)
            loss_real = MSELoss()(disc_real, ones)

            loss_fake.backward()
            loss_real.backward()
            dis_optim.step()

            # ======== Train Generator ======== #
            # The encoder/decoder pair minimises the reconstruction error,
            # while the encoder additionally tries to make its codes look
            # like prior samples to the (now frozen) critic.
            decoder.freeze(False)
            encoder.freeze(False)
            discriminator.freeze(True)

            lat_real = encoder(emb_data)
            recon_data = decoder(lat_real)
            d_real = discriminator(encoder(emb_data))

            recon_loss = MSELoss()(emb_data, recon_data)
            d_loss = MSELoss()(d_real, zeros)

            recon_loss.backward()
            d_loss.backward()
            enc_optim.step()
            dec_optim.step()

            self.train_deembeders([
                (particles, embedder, deembedder),
            ], epochs=2)

            if n_batch % 100 == 0:
                self.print_deemb_quality(particles, embedder, deembedder)
                self.show_heatmaps(emb_data[:30, :], recon_data[:30, :],
                                   reprod=False, save=True,
                                   epoch=epoch, batch=n_batch)

                err_kld, err_wass = self.gen_show_comp_hists(
                    decoder, _data,
                    attr_idxs=[FEATURES - 8, FEATURES - 7,
                               FEATURES - 6, FEATURES - 5],
                    embedders=[embedder],
                    emb=False,
                    deembedder=deembedder,
                    save=True, epoch=epoch, batch=n_batch)

                self.errs_kld.append(err_kld)
                self.errs_wass.append(err_wass)

                valid_loss = self._valid_loss(encoder, decoder, embedder,
                                              data_valid)

                print(
                    f'Epoch: {epoch}/{epochs} :: '
                    f'Batch: {n_batch}/{len(data_train)} :: '
                    f'train loss: {recon_loss.item():.6f} :: '
                    f'valid loss: {valid_loss:.6f} :: '
                    f'err kld: {err_kld:.6f} :: '
                    f'err wass: {err_wass:.6f}')

        self._save_models(encoder, decoder, discriminator,
                          embedder, deembedder)

        with open(self.ERRS_SAVE_PATH, 'wb') as handle:
            pickle.dump((self.errs_kld, self.errs_wass), handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    return encoder, decoder, discriminator, embedder, deembedder
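# `_valid_loss` used in the logging branch above is a helper of this trainer
# that is not shown in the excerpt. The function below is only a sketch of
# what it is assumed to compute (mean reconstruction MSE over the validation
# loader); the actual helper may differ. `embed_data_fn` stands in for the
# trainer's own `embed_data` method and is a hypothetical parameter.
def _valid_loss_sketch(encoder, decoder, embedder, data_valid,
                       embed_data_fn, device):
    criterion = MSELoss()
    total = 0.0
    with torch.no_grad():
        for batch in data_valid:
            emb = embed_data_fn(batch.to(device), [embedder])
            recon = decoder(encoder(emb))
            total += criterion(emb, recon).item()
    # Average reconstruction error over all validation batches.
    return total / len(data_valid)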
def load_trans_data(self):
    return load_data()
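# The adversarial-autoencoder loop above calls `freeze(...)` on the encoder,
# decoder and discriminator. That method lives on the network classes, which
# are not part of this excerpt; the class below is only an illustration of
# the assumed behaviour (toggling requires_grad on every parameter so the
# other optimisers do not update the frozen sub-network).
class FreezableModuleSketch(torch.nn.Module):

    def freeze(self, frozen: bool = True):
        # Disable (or re-enable) gradients for all parameters of this module.
        for param in self.parameters():
            param.requires_grad = not frozen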
import torch

from common.consts import FEATURES
from common.load_event_data import sgnlog_rev
from i_trainer.load_data import load_data, sgnlog
from single_prtcl_generator_wgan_pytorch.trainer import Trainer

### TRAINING
trainer = Trainer()
# generator, discriminator, embedder, deembedder = trainer.train(epochs=15, load=False)

wgan_gen, wgan_disc = trainer.create_model()
wgan_gen.load_state_dict(torch.load(trainer.GENERATOR_SAVE_PATH))

embedder = trainer.create_embedder()
deembedder = trainer.create_deembedder()

_data = load_data()
_data = _data[:10000, :]

trainer.gen_show_comp_hists(
    wgan_gen, _data,
    attr_idxs=[7 - 4, 8 - 4, 9 - 4, FEATURES - 5],
    embedders=[embedder],
    emb=False,
    sample_cnt=1000,
    deembedder=deembedder,
)

'''
import time
def train(self, epochs, load=False):
    print('TRAINING MODEL:'
          ' BATCH_SIZE = ' + str(self.BATCH_SIZE) +
          ', PARTICLE_DIM: ' + str(PARTICLE_DIM) +
          ', EPOCHS: ' + str(epochs) +
          ', PRTCL_LATENT_SPACE_SIZE: ' + str(self.PRTCL_LATENT_SPACE_SIZE))

    generator, discriminator = self.create_model()
    gen_optim = Adam(generator.parameters(), lr=self.LR, betas=(.1, .9))
    dis_optim = Adam(discriminator.parameters(), lr=self.LR, betas=(.1, .9))

    embedder = self.create_embedder()
    deembedder = self.create_deembedder()

    if load:
        print('LOADING MODEL STATES...')
        try:
            generator.load_state_dict(
                torch.load(Trainer.GENERATOR_SAVE_PATH))
        except Exception:
            print('Problem loading generator!')
        try:
            discriminator.load_state_dict(
                torch.load(Trainer.CRITIC_SAVE_PATH))
        except Exception:
            print('Problem loading critic!')
        try:
            embedder.load_state_dict(
                torch.load(Trainer.PDG_EMBED_SAVE_PATH))
        except Exception:
            print('Problem loading embedder!')
        try:
            deembedder.load_state_dict(
                torch.load(Trainer.PDG_DEEMBED_SAVE_PATH))
        except Exception:
            print('Problem loading deembedder!')

    print('GENERATOR')
    print(generator)
    print('DISCRIMINATOR')
    print(discriminator)
    print('EMBEDDER')
    print(embedder)
    print('DEEMBEDDER')
    print(deembedder)

    _data = load_data()
    data_train, data_valid = self.prep_data(
        _data, batch_size=self.BATCH_SIZE, valid=0.1)

    particles = torch.tensor(particle_idxs(), device=self.device)
    particles.requires_grad = False

    for epoch in range(epochs):
        for n_batch, batch in enumerate(data_train):
            real_data: torch.Tensor = batch.to(self.device)
            emb_data = self.embed_data(real_data, [embedder]).detach()
            batch_size = len(batch)

            # ======== Train Generator ======== #
            gen_optim.zero_grad()

            # Freeze the critic to avoid unnecessary gradient computation
            # during the generator update.
            for p in discriminator.parameters():
                p.requires_grad = False

            # Sample noise as generator input
            lat_fake = torch.randn(batch_size, self.PRTCL_LATENT_SPACE_SIZE,
                                   device=self.device)

            gen_data = generator(lat_fake)
            output = discriminator(gen_data)

            gen_loss = -torch.mean(output)
            gen_loss.backward()
            gen_optim.step()

            # Unfreeze the critic for its own update.
            for p in discriminator.parameters():
                p.requires_grad = True

            # ======== Train Discriminator ======== #
            for _ in range(self.CRITIC_ITERATIONS):
                critic_real = discriminator(emb_data)
                critic_fake = discriminator(gen_data.detach())

                gp = self.calc_gradient_penalty(discriminator, emb_data,
                                                gen_data, batch_size)
                critic_loss = -(torch.mean(critic_real)
                                - torch.mean(critic_fake)) + gp

                critic_loss.backward()
                dis_optim.step()
                dis_optim.zero_grad()

            self.train_deembeders([
                (particles, embedder, deembedder),
            ], epochs=2)

            if n_batch % 100 == 0:
                self.print_deemb_quality(particles, embedder, deembedder)
                self.show_heatmaps(emb_data[:30, :], gen_data[:30, :],
                                   reprod=False, save=True,
                                   epoch=epoch, batch=n_batch)

                err_kld, err_wass = self.gen_show_comp_hists(
                    generator, _data,
                    attr_idxs=[FEATURES - 8, FEATURES - 7,
                               FEATURES - 6, FEATURES - 5],
                    embedders=[embedder],
                    emb=False,
                    deembedder=deembedder,
                    save=True, epoch=epoch, batch=n_batch)

                self.errs_kld.append(err_kld)
                self.errs_wass.append(err_wass)

                print(
                    f'Epoch: {epoch}/{epochs} :: '
                    f'Batch: {n_batch}/{len(data_train)} :: '
                    f'generator loss: {gen_loss.item():.6f} :: '
                    f'critic loss: {critic_loss.item():.6f} :: '
                    f'err kld: {err_kld:.6f} :: '
                    f'err wass: {err_wass:.6f}')

        self._save_models(generator, discriminator, embedder, deembedder)

        with open(self.ERRS_SAVE_PATH, 'wb') as handle:
            pickle.dump((self.errs_kld, self.errs_wass), handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    return generator, discriminator, embedder, deembedder
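# `calc_gradient_penalty` above is a Trainer method that is not included in
# this excerpt. The function below is a standard WGAN-GP penalty
# (Gulrajani et al., 2017) given as a sketch of what it is assumed to do;
# the penalty weight `lambda_gp` and the interpolation details are
# assumptions and may differ from the repo's own implementation.
def calc_gradient_penalty_sketch(critic, real_data, fake_data, batch_size,
                                 device, lambda_gp=10.0):
    # One interpolation coefficient per sample, broadcast across features.
    alpha = torch.rand(batch_size, 1, device=device).expand_as(real_data)

    # Points on straight lines between real and generated samples.
    interpolates = (alpha * real_data
                    + (1 - alpha) * fake_data.detach()).requires_grad_(True)

    critic_out = critic(interpolates)

    # Gradient of the critic output w.r.t. the interpolated inputs.
    gradients = torch.autograd.grad(
        outputs=critic_out,
        inputs=interpolates,
        grad_outputs=torch.ones_like(critic_out),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    # Penalise deviations of the gradient norm from 1.
    return lambda_gp * ((gradients.norm(2, dim=1) - 1) ** 2).mean()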
def train(self, epochs, load=False):
    print('TRAINING MODEL:'
          ' BATCH_SIZE = ' + str(self.BATCH_SIZE) +
          ', PARTICLE_DIM: ' + str(PARTICLE_DIM) +
          ', EPOCHS: ' + str(epochs) +
          ', PRTCL_LATENT_SPACE_SIZE: ' + str(self.PRTCL_LATENT_SPACE_SIZE))

    generator, discriminator = self.create_model()
    gen_optim = torch.optim.Adam(generator.parameters(), lr=self.LR,
                                 betas=(0, .9))
    dis_optim = torch.optim.Adam(discriminator.parameters(), lr=self.LR,
                                 betas=(0, .9))

    embedder = self.create_embedder()
    deembedder = self.create_deembedder()

    particles = torch.tensor(particle_idxs(), device=self.device)

    if load:
        print('LOADING MODEL STATES...')
        generator.load_state_dict(torch.load(Trainer.GENERATOR_SAVE_PATH))
        discriminator.load_state_dict(
            torch.load(Trainer.DISCRIMINATOR_SAVE_PATH))
        embedder.load_state_dict(torch.load(Trainer.PDG_EMBED_SAVE_PATH))
        deembedder.load_state_dict(
            torch.load(Trainer.PDG_DEEMBED_SAVE_PATH))

    print('GENERATOR')
    print(generator)
    print('DISCRIMINATOR')
    print(discriminator)
    print('EMBEDDER')
    print(embedder)
    print('DEEMBEDDER')
    print(deembedder)

    _data = load_data()
    data_train, data_valid = self.prep_data(
        _data, batch_size=self.BATCH_SIZE, valid=0.1)

    for epoch in range(epochs):
        for n_batch, batch in enumerate(data_train):
            real_data: torch.Tensor = batch.to(self.device)
            emb_data = self.embed_data(real_data, [embedder]).detach()
            batch_size = len(batch)

            # Least-squares GAN targets: 1 for real data, 0 for generated.
            valid = torch.ones(batch_size, device=self.device,
                               requires_grad=False)
            fake = torch.zeros(batch_size, device=self.device,
                               requires_grad=False)

            # ======== Train Generator ======== #
            gen_optim.zero_grad()

            # Sample noise as generator input
            lat_fake = torch.randn(batch_size, self.PRTCL_LATENT_SPACE_SIZE,
                                   device=self.device)

            # Generate a batch of particles
            gen_data = generator(lat_fake)

            # Loss measures generator's ability to fool the discriminator
            g_loss = MSELoss()(discriminator(gen_data), valid)

            g_loss.backward()
            gen_optim.step()

            # ======== Train Discriminator ======== #
            dis_optim.zero_grad()

            real_loss = MSELoss()(discriminator(emb_data), valid)
            fake_loss = MSELoss()(discriminator(gen_data.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            dis_optim.step()

            self.train_deembeders([
                (particles, embedder, deembedder),
            ], epochs=2)

            if n_batch % 100 == 0:
                self.print_deemb_quality(particles, embedder, deembedder)
                self.show_heatmaps(emb_data[:30, :], gen_data[:30, :],
                                   reprod=False, save=True,
                                   epoch=epoch, batch=n_batch)

                err_kld, err_wass = self.gen_show_comp_hists(
                    generator, _data,
                    attr_idxs=[FEATURES - 8, FEATURES - 7,
                               FEATURES - 6, FEATURES - 5],
                    embedders=[embedder],
                    emb=False,
                    deembedder=deembedder,
                    save=True, epoch=epoch, batch=n_batch)

                self.errs_kld.append(err_kld)
                self.errs_wass.append(err_wass)

                print(
                    f'Epoch: {epoch}/{epochs} :: '
                    f'Batch: {n_batch}/{len(data_train)} :: '
                    f'generator loss: {g_loss.item():.6f} :: '
                    f'discriminator loss: {d_loss.item():.6f} :: '
                    f'err kld: {err_kld:.6f} :: '
                    f'err wass: {err_wass:.6f}')

        self._save_models(generator, discriminator, embedder, deembedder)

        with open(self.ERRS_SAVE_PATH, 'wb') as handle:
            pickle.dump((self.errs_kld, self.errs_wass), handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    return generator, discriminator, embedder, deembedder
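# Sampling sketch: after training, new (embedded) particles can be drawn by
# feeding standard-normal latent vectors through the generator, exactly as in
# the training loop above. The helper below is an illustration only; its name
# and signature are not part of the repo.
def sample_particles_sketch(generator, latent_size, sample_cnt, device):
    with torch.no_grad():
        noise = torch.randn(sample_cnt, latent_size, device=device)
        return generator(noise)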