def main(args):
    """Sample sentences from a trained SVAE checkpoint and print a
    latent-space interpolation between two random latent points.

    Args:
        args: parsed CLI namespace. Must provide: data_dir,
            embedding_dimension, hidden_dimension, latent_dimension, dropout,
            bidirectional, n_layer, max_src_length, max_tgt_length,
            load_checkpoint, num_samples.

    Raises:
        FileNotFoundError: if the checkpoint path does not exist.
    """

    def interpolate(start, end, steps):
        # Linear interpolation between latent vectors `start` and `end`,
        # endpoints included -> array of shape (steps + 2, latent_dim).
        interpolation = np.zeros((start.shape[0], steps + 2))
        for dim, (s, e) in enumerate(zip(start, end)):
            interpolation[dim] = np.linspace(s, e, steps + 2)
        return interpolation.T

    def idx2word(sent_list, i2w, pad_idx):
        # Decode batches of token-index sequences to strings, dropping pads.
        sent = []
        for s in sent_list:
            # FIX: was `int(idx) is not pad_idx` -- identity comparison on
            # ints is unreliable (only "works" via CPython's small-int
            # cache); use value inequality instead.
            sent.append(" ".join(i2w[str(int(idx))]
                                 for idx in s if int(idx) != pad_idx))
        return sent

    with open(args.data_dir + '/vocab.json', 'r') as file:
        vocab = json.load(file)
    w2i, i2w = vocab['w2i'], vocab['i2w']

    # Load model
    model = SVAE(
        vocab_size=len(w2i),
        embed_dim=args.embedding_dimension,
        hidden_dim=args.hidden_dimension,
        latent_dim=args.latent_dimension,
        teacher_forcing=False,
        dropout=args.dropout,
        n_direction=(2 if args.bidirectional else 1),
        n_parallel=args.n_layer,
        max_src_len=args.max_src_length,  # influence in inference stage
        max_tgt_len=args.max_tgt_length,
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        unk_idx=w2i['<unk>'],
    )

    path = os.path.join('checkpoint', args.load_checkpoint)
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    model.load_state_dict(torch.load(path))
    print("Model loaded from %s" % (path))

    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()

    # Unconditional samples drawn from the prior.
    samples, z = model.inference(n=args.num_samples)
    print('----------SAMPLES----------')
    print(*idx2word(sent_list=samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    # Decode each step of a straight line between two random latent points.
    z1 = torch.randn([args.latent_dimension]).numpy()
    z2 = torch.randn([args.latent_dimension]).numpy()
    z = torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float()
    samples, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(*idx2word(sent_list=samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
class ModularTrainer(Sampler):
    """Trainer that pretrains an SVAE on pre-processed sequence data.

    NOTE(review): inherits Sampler but never calls super().__init__() --
    confirm the base class needs no initialisation.
    """

    def __init__(self):
        self.config = load_config()
        self.model_config = self.config['Models']
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Model
        self.task_type = self.config['Utils']['task_type']
        self.max_sequence_length = self.config['Utils'][
            self.task_type]['max_sequence_length']

        # Real data
        self.data_name = self.config['Utils'][self.task_type]['data_name']
        self.data_splits = self.config['Utils'][self.task_type]['data_split']
        self.pad_idx = self.config['Utils']['special_token2idx']['<PAD>']

        # Test run properties
        self.epochs = self.config['Train']['epochs']
        self.svae_iterations = self.config['Train']['svae_iterations']
        self.kfold_xval = False

    def _init_data(self, batch_size=None):
        """Load pre-processed data and vocab from disk, build one dataset per
        split and the train/valid/test dataloaders.

        Args:
            batch_size: dataloader batch size; defaults to the configured
                Train.batch_size when None.
        """
        if batch_size is None:
            batch_size = self.config['Train']['batch_size']

        # TODO(review): hard-coded absolute paths -- should come from config.
        path_data = os.path.join('/home/tyler/Desktop/Repos/s-vaal/data',
                                 self.task_type, self.data_name, 'pretrain',
                                 'data.json')
        path_vocab = os.path.join('/home/tyler/Desktop/Repos/s-vaal/data',
                                  self.task_type, self.data_name, 'pretrain',
                                  'vocab.json')  # not vocabs
        data = load_json(path_data)
        self.vocab = load_json(
            path_vocab
        )  # Required for decoding sequences for interpretations. TODO: Find suitable location... or leave be...
        self.vocab_size = len(self.vocab['word2idx'])
        self.idx2word = self.vocab['idx2word']
        self.word2idx = self.vocab['word2idx']

        self.datasets = dict()
        if self.kfold_xval:
            # Perform k-fold cross-validation
            # Join all datasets and then randomly assign train/val/test.
            # NOTE(review): unfinished branch (debug prints only), and
            # self.x_y_pair_name is never set in __init__, so enabling it
            # would raise AttributeError -- confirm before use.
            print('hello')
            for split in self.data_splits:
                print(data[split][self.x_y_pair_name])
        else:
            for split in self.data_splits:
                # Access data
                split_data = data[split]
                # Convert lists of encoded sequences into tensors and stack
                # them into one large (num_sequences, seq_len) tensor.
                split_inputs = torch.stack([
                    torch.tensor(value['input'])
                    for key, value in split_data.items()
                ])
                split_targets = torch.stack([
                    torch.tensor(value['target'])
                    for key, value in split_data.items()
                ])

                # Create torch dataset from tensors
                split_dataset = RealDataset(sequences=split_inputs,
                                            tags=split_targets)
                # Add to dictionary
                self.datasets[split] = split_dataset  # split_dataloader

                # Create torch dataloader generator from dataset.
                # NOTE(review): shuffle=True on the test loader is unusual --
                # confirm it is intentional.
                if split == 'test':
                    self.test_dataloader = DataLoader(dataset=split_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      num_workers=0)
                if split == 'valid':
                    self.val_dataloader = DataLoader(dataset=split_dataset,
                                                     batch_size=batch_size,
                                                     shuffle=True,
                                                     num_workers=0)
                # FIX: was `if split == 'test'` (a duplicate of the branch
                # above), so the train dataloader was built from the *test*
                # split and never from the 'train' split.
                if split == 'train':
                    self.train_dataloader = DataLoader(dataset=split_dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=0)

        print(f'{datetime.now()}: Data loaded succesfully')

    def _init_svae_model(self):
        """Instantiate the SVAE and its Adam optimiser from config and put the
        model in training mode on the selected device."""
        self.svae = SVAE(**self.model_config['SVAE']['Parameters'],
                         vocab_size=self.vocab_size).to(self.device)
        self.svae_optim = optim.Adam(
            self.svae.parameters(),
            lr=self.model_config['SVAE']['learning_rate'])
        self.svae.train()
        print(f'{datetime.now()}: Initialised SVAE successfully')

    def interpolate(self, start, end, steps):
        """Linearly interpolate between latent vectors `start` and `end`,
        endpoints included. Returns an array of shape (steps + 2, latent_dim)."""
        interpolation = np.zeros((start.shape[0], steps + 2))
        for dim, (s, e) in enumerate(zip(start, end)):
            interpolation[dim] = np.linspace(s, e, steps + 2)
        return interpolation.T

    def _idx2word_inf(self, idx, i2w, pad_idx):  # inf-erence
        """Decode a batch of token-index tensors into strings, stopping each
        sentence at the first pad token."""
        sent_str = [str()] * len(idx)
        for i, sent in enumerate(idx):
            for word_id in sent:
                if word_id == pad_idx:
                    break
                sent_str[i] += i2w[str(word_id.item())] + " "
            sent_str[i] = sent_str[i].strip()
        return sent_str

    def _pretrain_svae(self):
        """Pretrain the SVAE on the train split, logging losses to
        TensorBoard, printing sample inferences every 5 epochs, and saving the
        final model state dict."""
        self._init_data()
        self._init_svae_model()

        tb_writer = SummaryWriter(
            comment=f"pretrain svae {self.data_name}",
            filename_suffix=f"pretrain svae {self.data_name}")

        print(f'{datetime.now()}: Training started')
        step = 0
        # Consistency: use self.epochs captured in __init__ (same config value
        # the original re-read here).
        for epoch in range(1, self.epochs + 1, 1):
            for batch_inputs, batch_lengths, batch_targets in self.train_dataloader:
                if torch.cuda.is_available():
                    batch_inputs = batch_inputs.to(self.device)
                    batch_lengths = batch_lengths.to(self.device)
                    batch_targets = batch_targets.to(self.device)
                batch_size = batch_inputs.size(0)

                # NOTE(review): pretrain=False inside a *pretraining* loop
                # looks suspicious -- confirm the SVAE forward flag semantics.
                logp, mean, logv, _ = self.svae(batch_inputs,
                                                batch_lengths,
                                                pretrain=False)

                NLL_loss, KL_loss, KL_weight = self.svae.loss_fn(
                    logp=logp,
                    target=batch_targets,
                    length=batch_lengths,
                    mean=mean,
                    logv=logv,
                    anneal_fn=self.model_config['SVAE']['anneal_function'],
                    step=step,
                    k=self.model_config['SVAE']['k'],
                    x0=self.model_config['SVAE']['x0'])
                # Annealed ELBO, averaged over the batch.
                svae_loss = (NLL_loss + KL_weight * KL_loss) / batch_size

                self.svae_optim.zero_grad()
                svae_loss.backward()
                self.svae_optim.step()

                tb_writer.add_scalar('Loss/train/KLL', KL_loss, step)
                tb_writer.add_scalar('Loss/train/NLL', NLL_loss, step)
                tb_writer.add_scalar('Loss/train/Total', svae_loss, step)
                tb_writer.add_scalar('Utils/train/KL_weight', KL_weight, step)

                # Increment step after each batch of data
                step += 1

            if epoch % 1 == 0:
                print(
                    f'{datetime.now()}: Epoch {epoch} Loss {svae_loss:0.2f} Step {step}'
                )

            if epoch % 5 == 0:
                # Perform inference (best-effort: failures are reported, not
                # fatal).
                self.svae.eval()
                try:
                    samples, z = self.svae.inference(n=2)
                    print(*self._idx2word_inf(samples,
                                              i2w=self.idx2word,
                                              pad_idx=self.config['Utils']
                                              ['special_token2idx']['<PAD>']),
                          sep='\n')
                # FIX: was a bare `except:`, which also swallows
                # KeyboardInterrupt/SystemExit.
                except Exception:
                    traceback.print_exc(file=sys.stdout)
                self.svae.train()

        # Save final model
        save_path = os.getcwd() + '/best models/svae.pt'
        torch.save(self.svae.state_dict(), save_path)
        print(f'{datetime.now()}: Model saved')
        print(f'{datetime.now()}: Training finished')