import os
import time

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.nn import BCELoss, NLLLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.utils import save_image

# NOTE: Dataset is the project-local csv dataset class; adjust this import
# to wherever it is defined in this repository.
from dataset import Dataset


class Runner:
    def __init__(self, device):
        self.device = device
        self.data_path, self.sample_path, self.ts = "", "", ""
        self.train_data, self.test_data = None, None
        self.train_loader, self.test_loader = None, None
        self.input_dim, self.num_label = 0, 0
        self.num_samples = 0
        self.num_batch = 0
        self.anneal_param = 0
        self.model, self.opt = None, None
        self.reconst_loss_x, self.reconst_loss_w = None, None
        self.batch_size = 0
        self.train_loss, self.eval_loss = [], []

    def get_data(self, data_path):
        self.data_path = data_path
        train_path = data_path + '/train.csv'
        # train_path = data_path + '/test.csv'  # swap in for a quick debugging run
        test_path = data_path + '/test.csv'
        self.train_data = Dataset(train_path)
        self.test_data = Dataset(test_path)
        self.input_dim = self.train_data.x_dim
        self.num_label = self.train_data.num_label
        self.num_samples = len(self.train_data)

    def set_save_dir(self, sample_path, ts):
        self.ts = ts
        self.sample_path = sample_path
        if not os.path.exists(sample_path):
            os.mkdir(sample_path)

    def train(self, model, optim, num_epoch, batch_size, learning_rate,
              save_samples=True, save_reconstructions=True):
        self.model = model
        self.model.train()
        self.batch_size = batch_size
        self.train_loader = DataLoader(dataset=self.train_data, batch_size=self.batch_size, shuffle=True)
        self.num_batch = len(self.train_loader)
        self.opt = self.set_opt(optim, learning_rate)
        self.reconst_loss_x, self.reconst_loss_w = self.set_reconst_loss()
        self.set_weight(num_epoch, self.num_batch)

        weight = 1
        for epoch in range(num_epoch):
            rx_loss, rw_loss, kl_loss, tot_loss = 0, 0, 0, 0
            tic = time.time()
            for i, (x, w, l) in enumerate(self.train_loader):
                x = x.to(device=self.device, dtype=torch.float).view(-1, self.input_dim)
                w = w.to(device=self.device, dtype=torch.float).view(-1, self.num_label)
                in_put = {'x': x, 'w': w}

                output, mean, log_var, z_sample = self.model(in_put)

                # KL divergence between q(z|x) and the standard normal prior
                loss_kl = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
                loss_x = self.reconst_loss_x(output['x'], x)
                loss = loss_x + loss_kl

                loss_w = 0
                if 'JMVAE' == self.model.whoami:
                    # label reconstruction term; NLLLoss expects class-index targets on the same device
                    loss_w = self.reconst_loss_w(output['w'], l.to(self.device))
                    weight = min(1, self.get_weight(epoch, i))
                    loss = (loss_x + loss_w) + weight * loss_kl

                loss = self.num_samples * loss / batch_size

                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

                # accumulate detached scalars so the computation graphs are not kept alive
                rx_loss += loss_x.item()
                rw_loss += float(loss_w)
                kl_loss += loss_kl.item()
                tot_loss += loss.item()

                if i + 1 == self.num_batch:
                    print("Epoch[{}/{}], Loss: {:.4f}, KL Div: {:.4f}, X reconst Loss: {:.4f}, "
                          "W reconst Loss: {:.4f}, Annealing Param: {:.4f}, Time: {:.4f}".format(
                              epoch + 1, num_epoch, tot_loss / self.num_batch, kl_loss / self.num_batch,
                              rx_loss / self.num_batch, rw_loss / self.num_batch, weight,
                              time.time() - tic))

            with torch.no_grad():
                if save_samples:
                    self.save_s(epoch)
                if save_reconstructions:
                    self.save_r(in_put, epoch)

            self.train_loss.append(tot_loss / self.num_batch)

    # def eval(self):
    #     # test_loader = DataLoader(dataset=test, batch_size=1, shuffle=True)

    def set_opt(self, optim, learning_rate):
        # Optimizer
        if 'adam' == optim:
            opt = Adam(self.model.parameters(), lr=learning_rate)
        else:
            raise Exception('Fix me!')
        return opt

    def set_reconst_loss(self):
        loss_x, loss_w = None, None
        loss_x = BCELoss(reduction='sum')
        if 'JMVAE' == self.model.whoami:
            loss_w = NLLLoss(reduction='sum')
        return loss_x, loss_w

    def set_weight(self, num_epoch, num_batch):
        # KL annealing: the weight ramps up linearly and reaches 1 after 40% of all training steps
        self.anneal_param = 1 / (2 / 5 * num_epoch * num_batch)

    def get_weight(self, epoch, step):
        return self.anneal_param * (epoch * self.num_batch + step)

    def save_s(self, epoch):
        z = torch.randn(10, self.model.z_dim).to(self.device)
        if 'CVAE' == self.model.whoami:
            # condition each sample on a different one-hot label
            z = torch.cat(tensors=(z, torch.eye(10, device=self.device)), dim=1)
        out = self.model.decoder(z)
        save_image(out['x'].view(-1, 1, 28, 28),
                   os.path.join(self.sample_path, 'sampled-{}.png'.format(epoch + 1)))

    def save_r(self, in_put, epoch):
        out, _, _, _ = self.model(in_put)
        x_concat = torch.cat((in_put['x'].view(-1, 1, 28, 28), out['x'].view(-1, 1, 28, 28)), dim=3)
        save_image(x_concat, os.path.join(self.sample_path, 'reconst-{}.png'.format(epoch + 1)))
        if 'JMVAE' == self.model.whoami:
            with open("./samples/reconst_w_{}.txt".format(self.ts), "a") as f:
                f.write(" ".join(str(e) for e in torch.argmax(in_put['w'], dim=1).tolist()))
                f.write("\n")
                f.write(" ".join(str(e) for e in torch.argmax(out['w'], dim=1).tolist()))
                f.write("\n\n")

    def plot_mean(self, path):
        if not 2 == self.model.z_dim:
            raise Exception("Cannot plot latent means unless z_dim == 2: model has {} dimensions".format(self.model.z_dim))
        self.model.eval()
        self.test_loader = DataLoader(dataset=self.test_data, batch_size=3000, shuffle=True)
        data, label = [], []
        for i, (x, w, l) in enumerate(self.test_loader):
            if i == 0:
                x = x.to(device=self.device, dtype=torch.float).view(-1, self.input_dim)
                w = w.to(device=self.device, dtype=torch.float).view(-1, self.num_label)
                in_put = {'x': x, 'w': w}
                output, mean, log_var, _ = self.model(in_put)
                data = mean.detach().cpu().numpy()
                label = l.detach().numpy()
                break  # only the first batch is plotted

        color_iter = ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462',
                      '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd', '#ccebc5', '#ffed6f']
        for i in range(self.num_label):
            idx = np.where(label == i)
            plt.scatter(data[idx, 0], data[idx, 1], color=color_iter[i], label=i)
        plt.legend(loc='best')
        plt.title(self.model.whoami, fontsize=8)
        plt.savefig(path)
        plt.close()
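

# ---------------------------------------------------------------------------
# Minimal usage sketch. The model class name (`JMVAE`), its constructor
# signature and import path, and the './data' / './samples' paths below are
# illustrative assumptions, not part of this module; only the call order
# (get_data -> set_save_dir -> train -> plot_mean) is what Runner expects.
# ---------------------------------------------------------------------------
# if __name__ == '__main__':
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     runner = Runner(device)
#     runner.get_data('./data')                        # expects train.csv / test.csv inside
#     runner.set_save_dir('./samples', ts='run-001')   # ts tags the reconst_w text file
#     model = JMVAE(x_dim=runner.input_dim, w_dim=runner.num_label, z_dim=2).to(device)  # hypothetical ctor
#     runner.train(model, optim='adam', num_epoch=20, batch_size=128, learning_rate=1e-3)
#     runner.plot_mean('./samples/latent_means.png')   # only valid when z_dim == 2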