def test_stack_rnn_cell(self):
    x, y = gen_data.random_training_set(batch_size=bz)
    d_model = 128
    hidden_size = 16
    stack_width = 10
    stack_depth = 20
    num_layers = 1
    num_dir = 2
    encoder = Encoder(gen_data.n_characters, d_model, gen_data.char2idx[gen_data.pad_symbol])
    x = encoder(x)

    # Build one StackRNN cell per layer; deeper layers consume the
    # (possibly bidirectional) hidden state of the layer below.
    rnn_cells = []
    in_dim = d_model
    cell_type = 'gru'
    for _ in range(num_layers):
        rnn_cells.append(StackRNNCell(in_dim, hidden_size, has_stack=True, unit_type=cell_type,
                                      stack_depth=stack_depth, stack_width=stack_width))
        in_dim = hidden_size * num_dir
    rnn_cells = torch.nn.ModuleList(rnn_cells)

    # Initial hidden, cell, and stack states
    h0 = init_hidden(num_layers=num_layers, batch_size=bz, hidden_size=hidden_size, num_dir=num_dir)
    c0 = init_hidden(num_layers=num_layers, batch_size=bz, hidden_size=hidden_size, num_dir=num_dir)
    s0 = init_stack(bz, stack_width, stack_depth)

    seq_length = x.shape[0]
    hidden_outs = torch.zeros(num_layers, num_dir, seq_length, bz, hidden_size)
    if cell_type == 'lstm':
        cell_outs = torch.zeros(num_layers, num_dir, seq_length, bz, hidden_size)

    assert 0 <= num_dir <= 2
    for l in range(num_layers):
        for d in range(num_dir):
            h, c, stack = h0[l, d, :], c0[l, d, :], s0
            # The forward direction walks the sequence left-to-right,
            # the backward direction right-to-left.
            indices = range(seq_length) if d == 0 else reversed(range(seq_length))
            for i in indices:
                x_t = x[i, :, :]
                hx, stack = rnn_cells[l](x_t, h, c, stack)
                # Carry the updated states into the next time step.
                if cell_type == 'lstm':
                    h, c = hx[0], hx[1]
                    hidden_outs[l, d, i, :, :] = h
                    cell_outs[l, d, i, :, :] = c
                else:
                    h = hx
                    hidden_outs[l, d, i, :, :] = h
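# Hedged sketch (not part of the test suite): collapsing the
# (num_layers, num_dir, seq_len, batch, hidden) buffer built in the test above
# into the (seq_len, batch, num_dir * hidden) layout that torch's built-in RNNs
# return for their top layer. The sizes below are illustrative placeholders.
def _collapse_direction_sketch():
    import torch
    num_layers, num_dir, seq_len, bz, hidden_size = 1, 2, 5, 3, 16
    hidden_outs = torch.zeros(num_layers, num_dir, seq_len, bz, hidden_size)
    top = hidden_outs[-1]                                  # (num_dir, seq_len, batch, hidden)
    top = top.permute(1, 2, 0, 3).contiguous()             # (seq_len, batch, num_dir, hidden)
    return top.view(seq_len, bz, num_dir * hidden_size)    # (seq_len, batch, num_dir * hidden)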
def forward(self, inp, return_logits=False):
    """
    Evaluates the input to determine its reward.

    Arguments
    ---------
    :param inp: list / tuple
        [0] Input from the encoder, of shape (seq_len, batch_size, embed_dim)
        [1] SMILES validity flag
    :param return_logits: bool
        If True, the features fed to the reward head are returned as well.
    :return: tensor
        Reward of shape (batch_size, 1)
    """
    x = inp[0]
    seq_len, batch_size = x.shape[:2]

    # Project the embedding to a lower-dimensional space
    # (assumes the input has a higher dimension).
    x = self.proj_net(x)

    # Construct initial states
    hidden = init_hidden(self.num_layers, batch_size, self.hidden_size, self.num_dir, x.device)
    hidden_ = init_hidden(1, batch_size, self.hidden_size, 1, x.device).view(batch_size, self.hidden_size)
    if self.has_cell:
        cell = init_cell(self.num_layers, batch_size, self.hidden_size, self.num_dir, x.device)
        cell_ = init_cell(1, batch_size, self.hidden_size, 1, x.device).view(batch_size, self.hidden_size)
        hidden, hidden_ = (hidden, cell), (hidden_, cell_)

    # Apply base RNN
    output, hidden = self.base_rnn(x, hidden)

    # Additive attention, see: http://arxiv.org/abs/1409.0473
    if self.use_attention:
        # hidden_ is refined once per time step; the loop index itself is unused.
        for i in range(seq_len):
            h = hidden_[0] if self.has_cell else hidden_
            s = h.unsqueeze(0).expand(seq_len, *h.shape)
            x_ = torch.cat([output, s], dim=-1)
            logits = self.attn_linear(x_.contiguous().view(-1, x_.shape[-1]))
            wts = torch.softmax(logits.view(seq_len, batch_size).t(), -1).unsqueeze(2)
            x_ = x_.permute(1, 2, 0)
            ctx = x_.bmm(wts).squeeze(dim=2)
            hidden_ = self.post_rnn(ctx, hidden_)
        rw_x = hidden_[0] if self.has_cell else hidden_
    else:
        rw_x = output[-1]
    logits = rw_x
    if self.v_flag:
        rw_x = torch.cat([rw_x, inp[-1]], dim=-1)
    reward = self.reward_net(rw_x)
    if return_logits:
        return reward, logits
    return reward
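# A minimal, self-contained sketch (not part of the model) of the additive
# attention scoring used in the forward pass above, with hypothetical sizes.
# It mirrors only the shape manipulations: concatenate the RNN outputs with a
# broadcast query state, score each time step with a linear layer, softmax
# over time, and reduce to one context vector per batch element.
def _additive_attention_sketch():
    import torch
    seq_len, batch_size, hidden_size = 7, 4, 16
    output = torch.randn(seq_len, batch_size, hidden_size)  # RNN outputs
    h = torch.randn(batch_size, hidden_size)                 # query state
    attn_linear = torch.nn.Linear(2 * hidden_size, 1)

    s = h.unsqueeze(0).expand(seq_len, *h.shape)              # (seq_len, batch, hidden)
    x_ = torch.cat([output, s], dim=-1)                        # (seq_len, batch, 2*hidden)
    logits = attn_linear(x_.contiguous().view(-1, x_.shape[-1]))
    wts = torch.softmax(logits.view(seq_len, batch_size).t(), -1).unsqueeze(2)  # (batch, seq_len, 1)
    ctx = x_.permute(1, 2, 0).bmm(wts).squeeze(dim=2)          # (batch, 2*hidden)
    return ctx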
def get_initial_states(batch_size, hidden_size, num_layers, stack_depth, stack_width, unit_type):
    """Builds the initial hidden, cell, and stack states of a unidirectional StackRNN."""
    hidden = init_hidden(num_layers=num_layers, batch_size=batch_size, hidden_size=hidden_size,
                         num_dir=1, dvc=device)
    if unit_type == 'lstm':
        cell = init_cell(num_layers=num_layers, batch_size=batch_size, hidden_size=hidden_size,
                         num_dir=1, dvc=device)
    else:
        cell = None
    stack = init_stack(batch_size, stack_width, stack_depth, dvc=device)
    return hidden, cell, stack
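# Hedged usage sketch: how the states returned by get_initial_states could be
# packaged for a generator call, mirroring the per-layer (hidden, cell, stack)
# tuples assembled in the training loops below. The size values, `num_layers`,
# and the `generator`/`inputs` arguments are placeholders, not values taken
# from this codebase.
def _initial_states_usage_sketch(generator, inputs, num_layers=2):
    batch = inputs.shape[0]
    hidden_states = []
    for _ in range(num_layers):
        hidden, cell, stack = get_initial_states(batch_size=batch, hidden_size=128,
                                                 num_layers=1, stack_depth=20,
                                                 stack_width=10, unit_type='gru')
        hidden_states.append((hidden, cell, stack))
    # The generator expects the input followed by one state tuple per layer.
    return generator([inputs] + hidden_states)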
def forward(self, x):
    """
    Critic net.

    :param x: tensor
        x.shape structure is (seq_len, batch, dim)
    :return: tensor
        of shape (seq_len/states, batch, 1)
    """
    if isinstance(x, (list, tuple)):
        x = x[0]
    batch_size = x.shape[1]
    hidden = init_hidden(self.num_layers, batch_size, self.hidden_size, 2, x.device)
    if self.has_cell:
        cell = init_cell(self.num_layers, batch_size, self.hidden_size, 2, x.device)
        hidden = (hidden, cell)
    x, _ = self.rnn(x, hidden)
    x = self.linear(self.norm(x))
    return x
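# Hedged usage sketch for the critic above: the module consumes a
# (seq_len, batch, dim) tensor (or a list/tuple whose first element is one)
# and returns one value per time step and batch element. The sizes and the
# `critic` instance here are illustrative placeholders.
def _critic_usage_sketch(critic):
    import torch
    seq_len, batch_size, dim = 50, 8, 128
    states = torch.randn(seq_len, batch_size, dim)
    values = critic(states)            # expected shape: (seq_len, batch_size, 1)
    assert values.shape == (seq_len, batch_size, 1)
    return values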
def train(model, optimizer, gen_data, rnn_args, n_iters=5000, sim_data_node=None,
          epoch_ckpt=(1, 2.0), tb_writer=None, is_hsearch=False):
    tb_writer = None  # tb_writer()
    start = time.time()
    best_model_wts = model.state_dict()
    best_score = -10000
    best_epoch = -1
    terminate_training = False
    e_avg = ExpAverage(.01)
    num_batches = math.ceil(gen_data.file_len / gen_data.batch_size)
    n_epochs = math.ceil(n_iters / num_batches)
    grad_stats = GradStats(model, beta=0.)

    # learning rate decay schedulers
    # scheduler = sch.StepLR(optimizer, step_size=500, gamma=0.01)

    # pred_loss functions
    criterion = nn.CrossEntropyLoss(ignore_index=gen_data.char2idx[gen_data.pad_symbol])

    # sub-nodes of sim data resource
    loss_lst = []
    train_loss_node = DataNode(label="train_loss", data=loss_lst)
    metrics_dict = {}
    metrics_node = DataNode(label="validation_metrics", data=metrics_dict)
    train_scores_lst = []
    train_scores_node = DataNode(label="train_score", data=train_scores_lst)
    scores_lst = []
    scores_node = DataNode(label="validation_score", data=scores_lst)

    # add sim data nodes to parent node
    if sim_data_node:
        sim_data_node.data = [train_loss_node, train_scores_node, metrics_node, scores_node]

    try:
        # Main training loop
        tb_idx = {'train': Count(), 'val': Count(), 'test': Count()}
        epoch_losses = []
        epoch_scores = []
        for epoch in range(6):  # NOTE: the epoch count is fixed here rather than taken from n_epochs
            phase = 'train'

            # Iterate through mini-batches
            # with TBMeanTracker(tb_writer, 10) as tracker:
            with grad_stats:
                for b in trange(0, num_batches, desc=f'{phase} in progress...'):
                    inputs, labels = gen_data.random_training_set()
                    batch_size, seq_len = inputs.shape[:2]
                    optimizer.zero_grad()

                    # track history if only in train
                    with torch.set_grad_enabled(phase == "train"):
                        # Create hidden states for each layer
                        hidden_states = []
                        for _ in range(rnn_args['num_layers']):
                            hidden = init_hidden(num_layers=1, batch_size=batch_size,
                                                 hidden_size=rnn_args['hidden_size'],
                                                 num_dir=rnn_args['num_dir'],
                                                 dvc=rnn_args['device'])
                            if rnn_args['has_cell']:
                                cell = init_cell(num_layers=1, batch_size=batch_size,
                                                 hidden_size=rnn_args['hidden_size'],
                                                 num_dir=rnn_args['num_dir'],
                                                 dvc=rnn_args['device'])
                            else:
                                cell = None
                            if rnn_args['has_stack']:
                                stack = init_stack(batch_size, rnn_args['stack_width'],
                                                   rnn_args['stack_depth'],
                                                   dvc=rnn_args['device'])
                            else:
                                stack = None
                            hidden_states.append((hidden, cell, stack))

                        # forward propagation
                        outputs = model([inputs] + hidden_states)
                        predictions = outputs[0]
                        predictions = predictions.permute(1, 0, -1)
                        predictions = predictions.contiguous().view(-1, predictions.shape[-1])
                        labels = labels.contiguous().view(-1)

                        # calculate loss
                        loss = criterion(predictions, labels)

                    # metrics
                    eval_dict = {}
                    score = IreleasePretrain.evaluate(eval_dict, predictions, labels)

                    # TBoard info
                    # tracker.track("%s/loss" % phase, loss.item(), tb_idx[phase].IncAndGet())
                    # tracker.track("%s/score" % phase, score, tb_idx[phase].i)
                    # for k in eval_dict:
                    #     tracker.track('{}/{}'.format(phase, k), eval_dict[k], tb_idx[phase].i)

                    # backward pass
                    loss.backward()
                    optimizer.step()

                    # for epoch stats
                    epoch_losses.append(loss.item())
                    epoch_scores.append(score)

                    # for sim data resource
                    train_scores_lst.append(score)
                    loss_lst.append(loss.item())

                    print("\t{}: Epoch={}/{}, batch={}/{}, "
                          "pred_loss={:.4f}, accuracy: {:.2f}, sample: {}".format(
                              time_since(start), epoch + 1, n_epochs, b + 1, num_batches,
                              loss.item(), eval_dict['accuracy'],
                              generate_smiles(generator=model, gen_data=gen_data,
                                              init_args=rnn_args, num_samples=1)))
            IreleasePretrain.save_model(model, './model_dir/',
                                        name=f'irelease-pretrained_stack-rnn_gru_'
                                             f'{date_label}_epoch_{epoch}')
        # End of mini-batch iterations.
    except RuntimeError as e:
        print(str(e))

    duration = time.time() - start
    print('\nModel training duration: {:.0f}m {:.0f}s'.format(duration // 60, duration % 60))
    return {'model': model, 'score': round(np.mean(epoch_scores), 3), 'epoch': n_epochs}
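# Hedged sketch: the keys that the pretraining loop above reads from rnn_args,
# gathered in one place. Only the key names are taken from the function body;
# the values shown are illustrative placeholders.
def _example_rnn_args(device):
    return {
        'num_layers': 2,
        'hidden_size': 512,
        'num_dir': 1,
        'has_cell': False,   # True when the recurrent unit is an LSTM
        'has_stack': True,
        'stack_width': 10,
        'stack_depth': 20,
        'device': device,
    }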
def train(generator, optimizer, rnn_args, pretrained_net_path=None, pretrained_net_name=None,
          n_iters=5000, sim_data_node=None, tb_writer=None, is_hsearch=False,
          is_pretraining=True, grad_clipping=None):
    expert_model = rnn_args['expert_model']
    tb_writer = tb_writer()
    best_model_wts = generator.state_dict()
    best_score = -1000
    best_epoch = -1
    demo_data_gen = rnn_args['demo_data_gen']
    unbiased_data_gen = rnn_args['unbiased_data_gen']
    prior_data_gen = rnn_args['prior_data_gen']
    score_exp_avg = ExpAverage(beta=0.6)
    exp_type = rnn_args['exp_type']
    if is_pretraining:
        num_batches = math.ceil(prior_data_gen.file_len / prior_data_gen.batch_size)
    else:
        num_batches = math.ceil(demo_data_gen.file_len / demo_data_gen.batch_size)
    n_epochs = math.ceil(n_iters / num_batches)
    grad_stats = GradStats(generator, beta=0.)

    # learning rate decay schedulers
    # scheduler = sch.StepLR(optimizer, step_size=500, gamma=0.01)

    # pred_loss functions
    criterion = nn.CrossEntropyLoss(ignore_index=prior_data_gen.char2idx[prior_data_gen.pad_symbol])

    # sub-nodes of sim data resource
    loss_lst = []
    train_loss_node = DataNode(label="train_loss", data=loss_lst)

    # collect mean predictions
    unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
    unbiased_smiles_mean_pred_data_node = DataNode('baseline_mean_vals', unbiased_smiles_mean_pred)
    biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals', biased_smiles_mean_pred)
    gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals', gen_smiles_mean_pred)

    if sim_data_node:
        sim_data_node.data = [unbiased_smiles_mean_pred_data_node, biased_smiles_mean_pred_data_node,
                              gen_smiles_mean_pred_data_node, train_loss_node]

    # load pretrained model
    if pretrained_net_path and pretrained_net_name:
        print('Loading pretrained model...')
        weights = StackRNNBaseline.load_model(pretrained_net_path, pretrained_net_name)
        generator.load_state_dict(weights)
        print('Pretrained model loaded successfully!')

    start = time.time()
    try:
        demo_score = np.mean(expert_model(demo_data_gen.random_training_set_smiles(1000))[1])
        baseline_score = np.mean(expert_model(unbiased_data_gen.random_training_set_smiles(1000))[1])
        step_idx = Count()
        gen_data = prior_data_gen if is_pretraining else demo_data_gen
        n_epochs = 2  # NOTE: overrides the value computed from n_iters
        with TBMeanTracker(tb_writer, 1) as tracker:
            mode = 'Pretraining' if is_pretraining else 'Fine tuning'
            for epoch in range(n_epochs):
                epoch_losses = []
                epoch_mean_preds = []
                epoch_per_valid = []
                with grad_stats:
                    for b in trange(0, num_batches,
                                    desc=f'Epoch {epoch + 1}/{n_epochs}, {mode} in progress...'):
                        inputs, labels = gen_data.random_training_set()
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        batch_size, seq_len = inputs.shape[:2]
                        optimizer.zero_grad()

                        # track history if only in train
                        # with torch.set_grad_enabled(phase == "train"):

                        # Create hidden states for each layer
                        hidden_states = []
                        for _ in range(rnn_args['num_layers']):
                            hidden = init_hidden(num_layers=1, batch_size=batch_size,
                                                 hidden_size=rnn_args['hidden_size'],
                                                 num_dir=rnn_args['num_dir'],
                                                 dvc=rnn_args['device'])
                            if rnn_args['has_cell']:
                                cell = init_cell(num_layers=1, batch_size=batch_size,
                                                 hidden_size=rnn_args['hidden_size'],
                                                 num_dir=rnn_args['num_dir'],
                                                 dvc=rnn_args['device'])
                            else:
                                cell = None
                            if rnn_args['has_stack']:
                                stack = init_stack(batch_size, rnn_args['stack_width'],
                                                   rnn_args['stack_depth'],
                                                   dvc=rnn_args['device'])
                            else:
                                stack = None
                            hidden_states.append((hidden, cell, stack))

                        # forward propagation
                        outputs = generator([inputs] + hidden_states)
                        predictions = outputs[0]
                        predictions = predictions.permute(1, 0, -1)
                        predictions = predictions.contiguous().view(-1, predictions.shape[-1])
                        labels = labels.contiguous().view(-1)

                        # calculate loss
                        loss = criterion(predictions, labels)
                        epoch_losses.append(loss.item())

                        # backward pass
                        loss.backward()
                        if grad_clipping:
                            torch.nn.utils.clip_grad_norm_(generator.parameters(), grad_clipping)
                        optimizer.step()

                        # for sim data resource
                        n_to_generate = 200
                        with torch.set_grad_enabled(False):
                            samples = generate_smiles(generator, demo_data_gen, rnn_args,
                                                      num_samples=n_to_generate)
                        samples_pred = expert_model(samples)[1]

                        # metrics
                        eval_dict = {}
                        eval_score = StackRNNBaseline.evaluate(
                            eval_dict, samples, demo_data_gen.random_training_set_smiles(1000))

                        # TBoard info
                        tracker.track('loss', loss.item(), step_idx.IncAndGet())
                        for k in eval_dict:
                            tracker.track(f'{k}', eval_dict[k], step_idx.i)

                        mean_preds = np.mean(samples_pred)
                        epoch_mean_preds.append(mean_preds)
                        if exp_type == 'drd2':
                            per_qualified = float(len([v for v in samples_pred if v >= 0.8])) / len(samples_pred)
                            score = mean_preds
                        elif exp_type == 'logp':
                            per_qualified = np.sum((samples_pred >= 1.0) & (samples_pred < 5.0)) / len(samples_pred)
                            score = mean_preds
                        elif exp_type == 'jak2_max':
                            per_qualified = np.sum(samples_pred >= demo_score) / len(samples_pred)
                            diff = mean_preds - demo_score
                            score = np.exp(diff)
                        elif exp_type == 'jak2_min':
                            per_qualified = np.sum(samples_pred <= demo_score) / len(samples_pred)
                            diff = demo_score - mean_preds
                            score = np.exp(diff)
                        else:  # pretraining
                            score = -loss.item()
                            per_qualified = 0.
                        per_valid = len(samples_pred) / n_to_generate
                        epoch_per_valid.append(per_valid)
                        unbiased_smiles_mean_pred.append(float(baseline_score))
                        biased_smiles_mean_pred.append(float(demo_score))
                        gen_smiles_mean_pred.append(float(mean_preds))
                        tb_writer.add_scalars('qsar_score',
                                              {'sampled': mean_preds,
                                               'baseline': baseline_score,
                                               'demo_data': demo_score}, step_idx.i)
                        tb_writer.add_scalars('SMILES stats',
                                              {'per. of valid': per_valid,
                                               'per. of qualified': per_qualified}, step_idx.i)
                        avg_len = np.nanmean([len(s) for s in samples])
                        tracker.track('Average SMILES length', avg_len, step_idx.i)

                        # keep the weights with the best smoothed score
                        score_exp_avg.update(score)
                        if score_exp_avg.value > best_score:
                            best_model_wts = copy.deepcopy(generator.state_dict())
                            best_score = score_exp_avg.value
                            best_epoch = epoch
                        if step_idx.i > 0 and step_idx.i % 1000 == 0:
                            smiles = generate_smiles(generator=generator, gen_data=gen_data,
                                                     init_args=rnn_args, num_samples=3)
                            print(f'Sample SMILES = {smiles}')
                # End of mini-batch iterations.
                print(f'{time_since(start)}: Epoch {epoch + 1}/{n_epochs}, loss={np.mean(epoch_losses)}, '
                      f'Mean value of predictions = {np.mean(epoch_mean_preds)}, '
                      f'% of valid SMILES = {np.mean(epoch_per_valid)}')
    except ValueError as e:
        print(str(e))

    duration = time.time() - start
    print('Model training duration: {:.0f}m {:.0f}s'.format(duration // 60, duration % 60))
    generator.load_state_dict(best_model_wts)
    return {'model': generator, 'score': round(best_score, 3), 'epoch': best_epoch}
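# Hedged sketch: the additional rnn_args entries that the fine-tuning/pretraining
# routine above reads, on top of the RNN sizing keys shown in the earlier sketch.
# Only the key names and the recognised exp_type values ('drd2', 'logp',
# 'jak2_max', 'jak2_min') come from the function body; the arguments and values
# here are illustrative placeholders.
def _example_finetune_args(expert_model, demo_data_gen, unbiased_data_gen, prior_data_gen, device):
    args = dict(_example_rnn_args(device))       # RNN sizing keys from the sketch above
    args.update({
        'expert_model': expert_model,             # scores generated SMILES
        'demo_data_gen': demo_data_gen,           # demonstration (biased) data
        'unbiased_data_gen': unbiased_data_gen,   # baseline data
        'prior_data_gen': prior_data_gen,         # data used for pretraining
        'exp_type': 'drd2',                       # or 'logp', 'jak2_max', 'jak2_min'
    })
    return args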