def train(self, train_set, batch_size, num_epoche, g_eta, d_eta, show=True):
    """Train the generator and discriminator adversarially and save the result.

    Every mini-batch performs one generator update followed by one
    discriminator update. The discriminator's optimizer step is skipped
    whenever the generator loss is not below three times the discriminator
    loss.

    Args:
        train_set: Training samples; converted with ``torch.tensor`` to a
            float tensor on ``self.device``. The first dimension indexes
            samples. Only the first ``N - N % batch_size`` samples are ever
            drawn, so every batch is full.
        batch_size: Number of generated and of real samples per step.
        num_epoche: Number of passes over the training set.
        g_eta: Adam learning rate for the generator.
        d_eta: Adam learning rate for the discriminator.
        show: When true, every 100 steps of an epoch the current losses are
            appended to ``self.losses``, an empty row is appended to
            ``self.scores`` and ``report`` is called.

    Side effects:
        Moves the model to ``self.device`` for training and back to the CPU
        afterwards (``self.in_cpu`` mirrors this). Saves the final weights
        under ``config.DATA_PATH/Trained_Models/Simple_GAN``. If files named
        ``BEST_ASDS``, ``BEST_CCA`` or ``BEST_SCCA`` exist in the working
        directory they are loaded into the model, deleted and re-saved next
        to the final weights; the model is left holding the weights of the
        last such file that loaded successfully.
    """
    print('Start training | batch_size:{a} | eta:{b}'.format(a=batch_size, b=g_eta))
    self.to(self.device)
    self.in_cpu = False
    train_set = torch.tensor(train_set, dtype=torch.float, device=self.device)
    g_optimizer = optim.Adam(self.generator.parameters(), lr=g_eta)
    d_optimizer = optim.Adam(self.discriminator.parameters(), lr=d_eta)

    # The generator is rewarded when the discriminator outputs 1 for its samples.
    g_target = torch.ones(batch_size, 1).to(self.device)
    # Discriminator batches are laid out as [generated, real] -> [0, 1].
    d_target = torch.cat(
        [torch.zeros(batch_size, 1), torch.ones(batch_size, 1)], 0).to(self.device)

    N = train_set.size()[0]
    N = N - N % batch_size  # drop the incomplete trailing batch

    for epoch in range(num_epoche):
        tic = time.time()
        perm = torch.randperm(N)
        steps = 0
        for i in range(0, N, batch_size):
            # optimize generator
            g_optimizer.zero_grad()
            noise = self.noise(batch_size)
            g_out = self.generator(noise)
            gd_out = self.discriminator(g_out)
            g_loss = self.discriminator.criterion(gd_out, g_target)
            g_loss.backward()
            g_optimizer.step()

            # optimize discriminator
            d_optimizer.zero_grad()
            indices = perm[i:i + batch_size]
            # detach() keeps the discriminator loss from back-propagating
            # into the generator.
            d_input = torch.cat([g_out.detach(), train_set[indices]], 0)
            d_out = self.discriminator(d_input)
            d_loss = self.discriminator.criterion(d_out, d_target)
            d_loss.backward()
            # Hold the discriminator back while the generator is far behind.
            if g_loss < 3 * d_loss:
                d_optimizer.step()

            steps += 1
            if show and steps % 100 == 0:
                # record training losses
                g_r = float(g_loss.detach().cpu())
                d_r = float(d_loss.detach().cpu())
                self.losses.append([g_r, d_r])
                with torch.no_grad():
                    example = self.example()
                    # No model-score metric is computed here, so an empty
                    # score row is recorded for the report.
                    self.scores.append([])
                    report(loss_title='Training Loss Curve',
                           losses=self.losses,
                           loss_labels=['Generator', 'Discriminator'],
                           score_title='Model Score Curve',
                           scores=self.scores,
                           score_labels=[],
                           interval=100,
                           example=example)
        dt = time.time() - tic
        print('epoch ' + str(epoch) + ' finished! Time usage: ' + str(dt))

    self.to(torch.device('cpu'))
    self.in_cpu = True

    model_dir = os.path.join(config.DATA_PATH, 'Trained_Models', 'Simple_GAN')
    run_tag = ('|BC:' + str(batch_size) + '|g_eta:' + str(g_eta)
               + '|d_eta:' + str(d_eta))
    last_path = os.path.join(model_dir, time_stamp() + '|LAST' + run_tag)
    torch.save(self.state_dict(), last_path)

    # Export the temporary "best" checkpoints (best ASD, CCA and SCCA score)
    # if they exist. Nothing in this method writes them, so a missing file
    # is the normal case and is skipped without noise.
    for temp_file, label in (('BEST_ASDS', '|ASDS'),
                             ('BEST_CCA', '|CCA'),
                             ('BEST_SCCA', '|SCCA')):
        if not os.path.isfile(temp_file):
            continue
        try:
            # The model is on the CPU at this point, so load there directly.
            self.load_state_dict(torch.load(temp_file, map_location='cpu'))
            os.remove(temp_file)
            best_path = os.path.join(model_dir, time_stamp() + label + run_tag)
            torch.save(self.state_dict(), best_path)
        except Exception as err:
            # Exporting is best-effort: report the failure instead of hiding it.
            print('Could not export checkpoint ' + temp_file + ': ' + str(err))
    return
def train(self, train_set, batch_size, num_epoche, g_eta, d_eta, n_critic, clip_value, show=True):
    """Train the critic and generator with the weight-clipped WGAN objective.

    The critic (``self.discriminator``) is updated on every mini-batch and
    its parameters are clamped to ``[-clip_value, clip_value]`` afterwards.
    The generator is updated on the first mini-batch of every epoch and then
    once every ``n_critic`` mini-batches.

    Args:
        train_set: Training samples; converted with ``torch.tensor`` to a
            float tensor on ``self.device``. The first dimension indexes
            samples. Only the first ``N - N % batch_size`` samples are ever
            drawn, so every batch is full.
        batch_size: Number of real and of generated samples per step.
        num_epoche: Number of passes over the training set.
        g_eta: RMSprop learning rate for the generator.
        d_eta: RMSprop learning rate for the critic.
        n_critic: Number of critic updates per generator update.
        clip_value: Bound applied to every critic parameter after each
            critic update.
        show: When true, every 100 steps of an epoch the losses and the
            Wasserstein estimate (the negated critic loss) are recorded in
            ``self.losses`` / ``self.scores``, ``report`` is called, and,
            after the first epoch, the weights with the smallest estimate
            seen so far are written to the temporary file ``BEST_WS``.

    Side effects:
        Moves the model to ``self.device`` for training and back to the CPU
        afterwards (``self.in_cpu`` mirrors this). Saves the final weights
        under ``config.DATA_PATH/Trained_Models/
        Complex_Fully_Connected_WGAN_LPF_W``. If ``BEST_WS`` exists it is
        loaded into the model, deleted and re-saved next to the final
        weights, which leaves the model holding those weights.
    """
    print('Start training | batch_size:{a} | eta:{b}'.format(a=batch_size, b=g_eta))
    self.to(self.device)
    self.in_cpu = False
    train_set = torch.tensor(train_set, dtype=torch.float, device=self.device)
    g_optimizer = optim.RMSprop(self.generator.parameters(), lr=g_eta)
    d_optimizer = optim.RMSprop(self.discriminator.parameters(), lr=d_eta)

    N = train_set.size()[0]
    N = N - N % batch_size  # drop the incomplete trailing batch

    best_ws_dist = sys.float_info.max
    for epoch in range(num_epoche):
        tic = time.time()
        perm = torch.randperm(N)
        steps = 0
        for i in range(0, N, batch_size):
            # optimize discriminator
            d_optimizer.zero_grad()
            indices = perm[i:i + batch_size]
            # detach() keeps the critic loss from back-propagating into the
            # generator.
            fake = self.generate(batch_size).detach()
            real = train_set[indices]

            # The critic loss is the negative Wasserstein distance estimate:
            # the critic tries to maximize the Wasserstein distance.
            d_loss = (-torch.mean(self.discriminator(real))
                      + torch.mean(self.discriminator(fake)))
            d_loss.backward()
            d_optimizer.step()

            # Clip weights of discriminator
            for p in self.discriminator.parameters():
                p.data.clamp_(-clip_value, clip_value)

            # Optimize the generator every n_critic iterations. The step
            # counter is used rather than the sample offset `i`, which
            # advances by batch_size and therefore does not count iterations.
            # Step 0 always updates, so g_loss exists before the first report.
            if steps % n_critic == 0:
                g_optimizer.zero_grad()
                fake = self.generate(batch_size)
                g_loss = -torch.mean(self.discriminator(fake))
                g_loss.backward()
                g_optimizer.step()

            steps += 1
            if show and steps % 100 == 0:
                # record training losses (g_loss is the value from the most
                # recent generator update)
                g_r = float(g_loss.detach().cpu())
                d_r = float(d_loss.detach().cpu())
                self.losses.append([g_r, d_r])

                # record model score
                ws_dist = -d_r
                self.scores.append([ws_dist])
                if ws_dist < best_ws_dist:
                    best_ws_dist = ws_dist
                    # Epoch 0 only updates the running minimum; nothing is
                    # saved until the second epoch.
                    if epoch > 0:
                        torch.save(self.state_dict(), 'BEST_WS')

                report(loss_title='Training loss curve',
                       losses=self.losses,
                       loss_labels=['Generator', 'Discriminator'],
                       score_title='Model score curve',
                       scores=self.scores,
                       score_labels=['Wasserstein estimate'],
                       interval=100,
                       example=self.example())
        dt = time.time() - tic
        print('epoch ' + str(epoch) + ' finished! Time usage: ' + str(dt))

    self.to(torch.device('cpu'))
    self.in_cpu = True

    model_dir = os.path.join(config.DATA_PATH, 'Trained_Models',
                             'Complex_Fully_Connected_WGAN_LPF_W')
    run_tag = ('|BC:' + str(batch_size) + '|g_eta:' + str(g_eta)
               + '|d_eta:' + str(d_eta) + '|n_critic:' + str(n_critic)
               + '|clip_value:' + str(clip_value))
    last_path = os.path.join(model_dir, time_stamp() + '|LAST' + run_tag)
    torch.save(self.state_dict(), last_path)

    # Store the model with the least Wasserstein estimate. The temporary
    # file only exists if a best checkpoint was written, so a missing file
    # is skipped without noise.
    if os.path.isfile('BEST_WS'):
        try:
            # The model is on the CPU at this point, so load there directly.
            self.load_state_dict(torch.load('BEST_WS', map_location='cpu'))
            os.remove('BEST_WS')
            best_ws_path = os.path.join(model_dir, time_stamp() + '|WS' + run_tag)
            torch.save(self.state_dict(), best_ws_path)
        except Exception as err:
            # Exporting is best-effort: report the failure instead of hiding it.
            print('Could not export checkpoint BEST_WS: ' + str(err))
    return
def train(self, train_set, batch_size, num_epoche, e_eta, d_eta, show=True):
    """Train the encoder and decoder to reconstruct the input under MSE loss.

    Args:
        train_set: Training samples; converted with ``torch.tensor`` to a
            float tensor on ``self.device``. The first dimension indexes
            samples. Only the first ``N - N % batch_size`` samples are ever
            drawn, so every batch is full.
        batch_size: Number of samples per step.
        num_epoche: Number of passes over the training set.
        e_eta: Adam learning rate for the encoder.
        d_eta: Adam learning rate for the decoder.
        show: When true, every 100 steps of an epoch the batch loss is
            appended to ``self.losses`` and ``self.scores``, ``report`` is
            called, and, once ``epoch > 10``, the weights are written to the
            temporary file ``BEST_MSE`` whenever the recorded loss reaches a
            new minimum.

    Side effects:
        Moves the model to ``self.device`` for training and back to the CPU
        afterwards (``self.in_cpu`` mirrors this). Saves the final weights
        under ``config.DATA_PATH/Trained_Models/Autoencoder``. If
        ``BEST_MSE`` exists it is loaded into the model, deleted and re-saved
        next to the final weights, which leaves the model holding those
        weights.
    """
    print('Start training | batch_size:{a} | e_eta:{b} | d_eta:{c}'.format(
        a=batch_size, b=e_eta, c=d_eta))
    self.to(self.device)
    self.in_cpu = False
    train_set = torch.tensor(train_set, dtype=torch.float, device=self.device)
    e_optimizer = optim.Adam(self.encoder.parameters(), lr=e_eta)
    d_optimizer = optim.Adam(self.decoder.parameters(), lr=d_eta)
    mse = nn.MSELoss(reduction='mean')

    N = train_set.size()[0]
    N = N - N % batch_size  # drop the incomplete trailing batch

    lowest_loss = sys.float_info.max
    for epoch in range(num_epoche):
        tic = time.time()
        perm = torch.randperm(N)
        steps = 0
        # Loss tensor of the most recent batch; stays None when the training
        # set holds fewer than batch_size samples and no batch is run.
        batch_loss = None
        for i in range(0, N, batch_size):
            e_optimizer.zero_grad()
            d_optimizer.zero_grad()
            indices = perm[i:i + batch_size]
            real = train_set[indices]
            fake = self(real)
            batch_loss = mse(real, fake)
            batch_loss.backward()
            e_optimizer.step()
            d_optimizer.step()

            steps += 1
            if show and steps % 100 == 0:
                # Record training losses
                loss_value = float(batch_loss.detach().cpu())
                self.losses.append([loss_value])
                self.scores.append([loss_value])
                if loss_value < lowest_loss:
                    lowest_loss = loss_value
                    # Epochs 0-10 only update the running minimum; nothing
                    # is saved until epoch 11.
                    if epoch > 10:
                        torch.save(self.state_dict(), 'BEST_MSE')
                report(loss_title='Training loss curve',
                       losses=self.losses,
                       loss_labels=['MSE Loss'],
                       score_title='Model score curve',
                       scores=self.scores,
                       score_labels=['MSE Loss'],
                       interval=100,
                       example=self.example())
        dt = time.time() - tic
        # Always print a plain float: the loss of the last batch of the
        # epoch, or nan when no batch was run.
        if batch_loss is None:
            epoch_loss = float('nan')
        else:
            epoch_loss = float(batch_loss.detach().cpu())
        print('epoch ' + str(epoch) + '\tfinished! Time usage: ' + str(dt)
              + '\t Loss: ' + str(epoch_loss))

    self.to(torch.device('cpu'))
    self.in_cpu = True

    model_dir = os.path.join(config.DATA_PATH, 'Trained_Models', 'Autoencoder')
    run_tag = ('|BC:' + str(batch_size) + '|e_eta:' + str(e_eta)
               + '|d_eta:' + str(d_eta))
    last_path = os.path.join(model_dir, time_stamp() + '|LAST' + run_tag)
    torch.save(self.state_dict(), last_path)

    # Store the model with the lowest loss. The temporary file only exists
    # if a best checkpoint was written, so a missing file is skipped
    # without noise.
    if os.path.isfile('BEST_MSE'):
        try:
            # The model is on the CPU at this point, so load there directly.
            self.load_state_dict(torch.load('BEST_MSE', map_location='cpu'))
            os.remove('BEST_MSE')
            best_mse_path = os.path.join(model_dir, time_stamp() + '|MSE' + run_tag)
            torch.save(self.state_dict(), best_mse_path)
        except Exception as err:
            # Exporting is best-effort: report the failure instead of hiding it.
            print('Could not export checkpoint BEST_MSE: ' + str(err))
    return