Ejemplo n.º 1
0
    def forward(self, inp, return_logits=False):
        """
        Evaluates the input to determine its reward.

        Arguments
        :param inp: list / tuple
           [0] Input from encoder of shape (seq_len, batch_size, embed_dim)
           [1] SMILES validity flag. It is read as ``inp[-1]`` and only when
               ``self.v_flag`` is set, in which case it is concatenated to the
               reward-head features along the last dimension.
        :param return_logits: bool
            If True, also return the feature tensor fed to ``self.reward_net``
            *before* the validity flag is appended.
        :return: tensor, or (tensor, tensor) when ``return_logits`` is True
            Reward of shape (batch_size, 1), optionally followed by the
            pre-flag features described above.
        """
        x = inp[0]
        # Time-major layout per the docstring: dim 0 = sequence, dim 1 = batch.
        seq_len, batch_size = x.shape[:2]

        # Project embedding to a low dimension space (assumes the input has a higher dimension)
        # NOTE(review): seq_len/batch_size were read before projection; this
        # assumes proj_net preserves the first two dimensions — confirm.
        x = self.proj_net(x)

        # Construct initial states
        # `hidden`: initial state for the (possibly multi-layer/multi-direction) base RNN.
        hidden = init_hidden(self.num_layers, batch_size, self.hidden_size,
                             self.num_dir, x.device)
        # `hidden_`: single-layer, single-direction state reshaped to
        # (batch_size, hidden_size) for the step-wise `post_rnn` cell below.
        hidden_ = init_hidden(1, batch_size, self.hidden_size, 1,
                              x.device).view(batch_size, self.hidden_size)
        if self.has_cell:
            cell = init_cell(self.num_layers, batch_size, self.hidden_size,
                             self.num_dir, x.device)
            cell_ = init_cell(1, batch_size, self.hidden_size, 1,
                              x.device).view(batch_size, self.hidden_size)
            # Cell-based units take (hidden, cell) tuples (presumably LSTM).
            hidden, hidden_ = (hidden, cell), (hidden_, cell_)

        # Apply base rnn
        # The returned final base-RNN state is not used further below.
        output, hidden = self.base_rnn(x, hidden)

        # Additive attention, see: http://arxiv.org/abs/1409.0473
        if self.use_attention:
            # The loop index is unused: attention is applied seq_len times,
            # each pass querying with the post_rnn state from the previous pass.
            for i in range(seq_len):
                # Query = current post_rnn hidden state, (batch_size, hidden_size).
                h = hidden_[0] if self.has_cell else hidden_
                # Repeat the query for every time step: (seq_len, batch_size, hidden_size).
                s = h.unsqueeze(0).expand(seq_len, *h.shape)
                # Pair every RNN output with the query along the feature dim.
                x_ = torch.cat([output, s], dim=-1)
                # Score each (time step, batch) row; the view below requires one
                # score per row.
                logits = self.attn_linear(x_.contiguous().view(
                    -1, x_.shape[-1]))
                # Normalise scores over time steps -> (batch_size, seq_len, 1).
                wts = torch.softmax(logits.view(seq_len, batch_size).t(),
                                    -1).unsqueeze(2)
                # (batch_size, features, seq_len) so bmm takes a weighted sum over time.
                x_ = x_.permute(1, 2, 0)
                # Context vector: (batch_size, features).
                ctx = x_.bmm(wts).squeeze(dim=2)
                # Advance the query state with the context.
                hidden_ = self.post_rnn(ctx, hidden_)
            rw_x = hidden_[0] if self.has_cell else hidden_
        else:
            # Without attention, use the base RNN output at the last time step.
            rw_x = output[-1]
        # `logits` is rebound here: it now holds the reward-head features
        # before the optional validity flag is appended.
        logits = rw_x
        if self.v_flag:
            rw_x = torch.cat([rw_x, inp[-1]], dim=-1)
        reward = self.reward_net(rw_x)
        if return_logits:
            return reward, logits
        return reward
Ejemplo n.º 2
0
def get_initial_states(batch_size, hidden_size, num_layers, stack_depth, stack_width, unit_type,
                       dvc=None):
    """
    Build the initial (hidden, cell, stack) triple for a single-direction stack RNN.

    :param batch_size: number of sequences in the batch.
    :param hidden_size: size of each hidden (and cell) state.
    :param num_layers: number of recurrent layers the states are created for.
    :param stack_depth: depth passed to ``init_stack``.
    :param stack_width: width passed to ``init_stack``.
    :param unit_type: recurrent unit name; only the exact string ``'lstm'``
        gets a cell state.
    :param dvc: device on which the states are created. Defaults to the
        module-level ``device``, so existing callers behave exactly as before.
    :return: tuple ``(hidden, cell, stack)``; ``cell`` is None unless
        ``unit_type == 'lstm'``.
    """
    # Fall back to the module-level device for backward compatibility.
    if dvc is None:
        dvc = device
    hidden = init_hidden(num_layers=num_layers, batch_size=batch_size, hidden_size=hidden_size, num_dir=1,
                         dvc=dvc)
    cell = None
    if unit_type == 'lstm':
        cell = init_cell(num_layers=num_layers, batch_size=batch_size, hidden_size=hidden_size, num_dir=1,
                         dvc=dvc)
    stack = init_stack(batch_size, stack_width, stack_depth, dvc=dvc)
    return hidden, cell, stack
Ejemplo n.º 3
0
 def forward(self, x):
     """
     Critic network forward pass.

     :param x: tensor, or a list / tuple whose first element is that tensor
         Layout (seq. len, batch, dim)
     :return: tensor
         ``self.linear`` applied to the normalised RNN outputs,
         (seq_len/states, batch, 1)
     """
     seq = x[0] if isinstance(x, (list, tuple)) else x
     n_batch = seq.shape[1]
     # num_dir is fixed at 2 here (presumably a bidirectional self.rnn).
     state = init_hidden(self.num_layers, n_batch, self.hidden_size, 2,
                         seq.device)
     if self.has_cell:
         cell_state = init_cell(self.num_layers, n_batch, self.hidden_size, 2,
                                seq.device)
         state = (state, cell_state)
     rnn_out, _ = self.rnn(seq, state)
     normalised = self.norm(rnn_out)
     return self.linear(normalised)
Ejemplo n.º 4
0
    def train(model,
              optimizer,
              gen_data,
              rnn_args,
              n_iters=5000,
              sim_data_node=None,
              epoch_ckpt=(1, 2.0),
              tb_writer=None,
              is_hsearch=False):
        """
        Train the generator ``model`` with a cross-entropy loss on mini-batches
        drawn from ``gen_data``, saving a checkpoint after every epoch.

        :param model: generator; called as ``model([inputs] + hidden_states)``
            and its first output is used as the per-token predictions.
        :param optimizer: torch optimizer over ``model``'s parameters.
        :param gen_data: data generator exposing ``file_len``, ``batch_size``,
            ``char2idx``, ``pad_symbol`` and ``random_training_set()``.
        :param rnn_args: dict with keys 'num_layers', 'hidden_size', 'num_dir',
            'device', 'has_cell', 'has_stack' and, when 'has_stack' is true,
            'stack_width' and 'stack_depth'.
        :param n_iters: approximate number of mini-batch updates; training
            runs ``ceil(n_iters / num_batches)`` epochs.
        :param sim_data_node: optional DataNode that receives the train-loss,
            train-score, validation-metrics and validation-score sub-nodes.
        :param epoch_ckpt: unused; kept for interface compatibility.
        :param tb_writer: unused (TensorBoard logging is disabled here).
        :param is_hsearch: unused; kept for interface compatibility.
        :return: dict with the trained ``model``, the mean training ``score``
            over all completed batches (NaN if none completed), and ``epoch``
            (the number of epochs scheduled).
        """
        start = time.time()
        num_batches = math.ceil(gen_data.file_len / gen_data.batch_size)
        # Enough epochs to cover (at least) n_iters mini-batch updates.
        n_epochs = math.ceil(n_iters / num_batches)
        grad_stats = GradStats(model, beta=0.)

        # Padding positions are excluded from the loss.
        criterion = nn.CrossEntropyLoss(
            ignore_index=gen_data.char2idx[gen_data.pad_symbol])

        # Sub-nodes of the sim data resource. The validation metrics/score
        # nodes are registered but not filled by this routine.
        loss_lst = []
        train_loss_node = DataNode(label="train_loss", data=loss_lst)
        metrics_dict = {}
        metrics_node = DataNode(label="validation_metrics", data=metrics_dict)
        train_scores_lst = []
        train_scores_node = DataNode(label="train_score",
                                     data=train_scores_lst)
        scores_lst = []
        scores_node = DataNode(label="validation_score", data=scores_lst)

        # add sim data nodes to parent node
        if sim_data_node:
            sim_data_node.data = [
                train_loss_node, train_scores_node, metrics_node, scores_node
            ]

        epoch_scores = []
        phase = 'train'
        try:
            # Main training loop; honours n_iters via n_epochs.
            for epoch in range(n_epochs):
                with grad_stats:
                    for b in trange(0,
                                    num_batches,
                                    desc=f'{phase} in progress...'):
                        inputs, labels = gen_data.random_training_set()
                        batch_size, _ = inputs.shape[:2]
                        optimizer.zero_grad()

                        # track history if only in train
                        with torch.set_grad_enabled(phase == "train"):
                            # One (hidden, cell, stack) triple per layer.
                            hidden_states = []
                            for _ in range(rnn_args['num_layers']):
                                hidden = init_hidden(
                                    num_layers=1,
                                    batch_size=batch_size,
                                    hidden_size=rnn_args['hidden_size'],
                                    num_dir=rnn_args['num_dir'],
                                    dvc=rnn_args['device'])
                                if rnn_args['has_cell']:
                                    cell = init_cell(
                                        num_layers=1,
                                        batch_size=batch_size,
                                        hidden_size=rnn_args['hidden_size'],
                                        num_dir=rnn_args['num_dir'],
                                        dvc=rnn_args['device'])
                                else:
                                    cell = None
                                if rnn_args['has_stack']:
                                    stack = init_stack(batch_size,
                                                       rnn_args['stack_width'],
                                                       rnn_args['stack_depth'],
                                                       dvc=rnn_args['device'])
                                else:
                                    stack = None
                                hidden_states.append((hidden, cell, stack))

                            # forward propagation
                            outputs = model([inputs] + hidden_states)
                            # Swap the first two axes, then flatten to 2-D rows
                            # (one per token) as CrossEntropyLoss expects.
                            predictions = outputs[0].permute(1, 0, -1)
                            predictions = predictions.contiguous().view(
                                -1, predictions.shape[-1])
                            labels = labels.contiguous().view(-1)

                            # calculate loss
                            loss = criterion(predictions, labels)

                        # metrics
                        eval_dict = {}
                        score = IreleasePretrain.evaluate(
                            eval_dict, predictions, labels)

                        # backward pass
                        loss.backward()
                        optimizer.step()

                        loss_val = loss.item()
                        # for sim data resource
                        train_scores_lst.append(score)
                        loss_lst.append(loss_val)
                        # for epoch stats
                        epoch_scores.append(score)

                        print("\t{}: Epoch={}/{}, batch={}/{}, "
                              "pred_loss={:.4f}, accuracy: {:.2f}, sample: {}".
                              format(
                                  time_since(start),
                                  epoch + 1, n_epochs, b + 1, num_batches,
                                  loss_val, eval_dict['accuracy'],
                                  generate_smiles(generator=model,
                                                  gen_data=gen_data,
                                                  init_args=rnn_args,
                                                  num_samples=1)))
                    # Checkpoint once per epoch.
                    IreleasePretrain.save_model(
                        model,
                        './model_dir/',
                        name=f'irelease-pretrained_stack-rnn_gru_'
                        f'{date_label}_epoch_{epoch}')
                # End of mini=batch iterations.
        except RuntimeError as e:
            print(str(e))

        duration = time.time() - start
        print('\nModel training duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
        # np.mean([]) warns and returns NaN; report NaN explicitly instead.
        mean_score = (round(np.mean(epoch_scores), 3)
                      if epoch_scores else float('nan'))
        return {
            'model': model,
            'score': mean_score,
            'epoch': n_epochs
        }
    def train(generator,
              optimizer,
              rnn_args,
              pretrained_net_path=None,
              pretrained_net_name=None,
              n_iters=5000,
              sim_data_node=None,
              tb_writer=None,
              is_hsearch=False,
              is_pretraining=True,
              grad_clipping=None):
        """
        Pretrain or fine-tune a stack-RNN SMILES generator. After every update,
        score freshly sampled SMILES with the expert model.

        :param generator: generator; called as ``generator([inputs] + hidden_states)``
            and its first output is used as the per-token predictions.
        :param optimizer: torch optimizer over ``generator``'s parameters.
        :param rnn_args: dict. Keys read: 'expert_model', 'demo_data_gen',
            'unbiased_data_gen', 'prior_data_gen', 'exp_type', 'num_layers',
            'hidden_size', 'num_dir', 'device', 'has_cell', 'has_stack' and,
            when 'has_stack' is true, 'stack_width' and 'stack_depth'.
        :param pretrained_net_path: directory passed to
            ``StackRNNBaseline.load_model``. Used only together with
            ``pretrained_net_name``.
        :param pretrained_net_name: model name passed to
            ``StackRNNBaseline.load_model``.
        :param n_iters: approximate number of updates; training runs
            ``ceil(n_iters / num_batches)`` epochs.
        :param sim_data_node: optional DataNode that receives the mean-prediction
            and train-loss sub-nodes.
        :param tb_writer: callable returning a writer that supports
            ``add_scalars``. It is called once and is required.
        :param is_hsearch: unused; kept for interface compatibility.
        :param is_pretraining: train on the prior data if True, otherwise on
            the demo data.
        :param grad_clipping: optional max gradient norm for ``clip_grad_norm_``.
        :return: dict with ``model`` (the generator with the best-scoring
            weights restored), ``score`` (that best exponentially averaged
            score, rounded to 3 places) and ``epoch`` (the epoch it occurred in).
        :raises TypeError: if ``tb_writer`` is None.
        """
        if tb_writer is None:
            raise TypeError('tb_writer must be a callable returning a '
                            'summary writer')
        expert_model = rnn_args['expert_model']
        tb_writer = tb_writer()
        best_score = -1000
        best_epoch = -1
        demo_data_gen = rnn_args['demo_data_gen']
        unbiased_data_gen = rnn_args['unbiased_data_gen']
        prior_data_gen = rnn_args['prior_data_gen']
        score_exp_avg = ExpAverage(beta=0.6)
        exp_type = rnn_args['exp_type']
        # Same device for inputs/labels as for the hidden states below.
        dvc = rnn_args['device']
        # Number of SMILES sampled after every update to score the generator.
        n_to_generate = 200

        # Pretraining iterates over the prior data, fine tuning over the demo data.
        gen_data = prior_data_gen if is_pretraining else demo_data_gen
        num_batches = math.ceil(gen_data.file_len / gen_data.batch_size)
        n_epochs = math.ceil(n_iters / num_batches)
        grad_stats = GradStats(generator, beta=0.)

        # Padding positions are excluded from the loss.
        criterion = nn.CrossEntropyLoss(
            ignore_index=prior_data_gen.char2idx[prior_data_gen.pad_symbol])

        # sub-nodes of sim data resource
        loss_lst = []
        train_loss_node = DataNode(label="train_loss", data=loss_lst)

        # collect mean predictions
        unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
        unbiased_smiles_mean_pred_data_node = DataNode(
            'baseline_mean_vals', unbiased_smiles_mean_pred)
        biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals',
                                                     biased_smiles_mean_pred)
        gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals',
                                                  gen_smiles_mean_pred)
        if sim_data_node:
            sim_data_node.data = [
                unbiased_smiles_mean_pred_data_node,
                biased_smiles_mean_pred_data_node,
                gen_smiles_mean_pred_data_node, train_loss_node
            ]

        # load pretrained model
        if pretrained_net_path and pretrained_net_name:
            print('Loading pretrained model...')
            weights = StackRNNBaseline.load_model(pretrained_net_path,
                                                  pretrained_net_name)
            generator.load_state_dict(weights)
            print('Pretrained model loaded successfully!')
        # Snapshot the starting weights (after any pretrained load). The copy
        # stays independent of later in-place parameter updates.
        best_model_wts = copy.deepcopy(generator.state_dict())

        start = time.time()
        try:
            # Reference expert-model scores for the demo and unbiased data.
            demo_score = np.mean(
                expert_model(
                    demo_data_gen.random_training_set_smiles(1000))[1])
            baseline_score = np.mean(
                expert_model(
                    unbiased_data_gen.random_training_set_smiles(1000))[1])
            step_idx = Count()
            with TBMeanTracker(tb_writer, 1) as tracker:
                mode = 'Pretraining' if is_pretraining else 'Fine tuning'
                for epoch in range(n_epochs):
                    epoch_losses = []
                    epoch_mean_preds = []
                    epoch_per_valid = []
                    with grad_stats:
                        for b in trange(
                                0,
                                num_batches,
                                desc=
                                f'Epoch {epoch + 1}/{n_epochs}, {mode} in progress...'
                        ):
                            inputs, labels = gen_data.random_training_set()
                            inputs = inputs.to(dvc)
                            labels = labels.to(dvc)
                            batch_size, _ = inputs.shape[:2]
                            optimizer.zero_grad()

                            # One (hidden, cell, stack) triple per layer.
                            hidden_states = []
                            for _ in range(rnn_args['num_layers']):
                                hidden = init_hidden(
                                    num_layers=1,
                                    batch_size=batch_size,
                                    hidden_size=rnn_args['hidden_size'],
                                    num_dir=rnn_args['num_dir'],
                                    dvc=dvc)
                                if rnn_args['has_cell']:
                                    cell = init_cell(
                                        num_layers=1,
                                        batch_size=batch_size,
                                        hidden_size=rnn_args['hidden_size'],
                                        num_dir=rnn_args['num_dir'],
                                        dvc=dvc)
                                else:
                                    cell = None
                                if rnn_args['has_stack']:
                                    stack = init_stack(batch_size,
                                                       rnn_args['stack_width'],
                                                       rnn_args['stack_depth'],
                                                       dvc=dvc)
                                else:
                                    stack = None
                                hidden_states.append((hidden, cell, stack))

                            # forward propagation
                            outputs = generator([inputs] + hidden_states)
                            # Swap the first two axes, then flatten to 2-D rows
                            # (one per token) as CrossEntropyLoss expects.
                            predictions = outputs[0].permute(1, 0, -1)
                            predictions = predictions.contiguous().view(
                                -1, predictions.shape[-1])
                            labels = labels.contiguous().view(-1)

                            # calculate loss
                            loss = criterion(predictions, labels)

                            # Gradients must exist before clipping and stepping;
                            # without backward() the update has nothing to apply.
                            loss.backward()
                            if grad_clipping:
                                torch.nn.utils.clip_grad_norm_(
                                    generator.parameters(), grad_clipping)
                            optimizer.step()

                            loss_val = loss.item()
                            epoch_losses.append(loss_val)
                            # for sim data resource
                            loss_lst.append(loss_val)

                            # Sample from the generator and score with the expert.
                            with torch.set_grad_enabled(False):
                                samples = generate_smiles(
                                    generator,
                                    demo_data_gen,
                                    rnn_args,
                                    num_samples=n_to_generate)
                            samples_pred = expert_model(samples)[1]

                            # metrics (evaluate fills eval_dict)
                            eval_dict = {}
                            StackRNNBaseline.evaluate(
                                eval_dict, samples,
                                demo_data_gen.random_training_set_smiles(1000))
                            # TBoard info
                            tracker.track('loss', loss_val,
                                          step_idx.IncAndGet())
                            for k in eval_dict:
                                tracker.track(f'{k}', eval_dict[k], step_idx.i)
                            mean_preds = np.mean(samples_pred)
                            epoch_mean_preds.append(mean_preds)
                            if exp_type == 'drd2':
                                per_qualified = float(
                                    len([v for v in samples_pred if v >= 0.8
                                         ])) / len(samples_pred)
                                score = mean_preds
                            elif exp_type == 'logp':
                                per_qualified = np.sum(
                                    (samples_pred >= 1.0)
                                    & (samples_pred < 5.0)) / len(samples_pred)
                                score = mean_preds
                            elif exp_type == 'jak2_max':
                                per_qualified = np.sum(
                                    (samples_pred >=
                                     demo_score)) / len(samples_pred)
                                diff = mean_preds - demo_score
                                score = np.exp(diff)
                            elif exp_type == 'jak2_min':
                                per_qualified = np.sum(
                                    (samples_pred <=
                                     demo_score)) / len(samples_pred)
                                diff = demo_score - mean_preds
                                score = np.exp(diff)
                            else:  # pretraining
                                score = -loss_val
                                per_qualified = 0.
                            # Fraction of requested samples that came back from the expert.
                            per_valid = len(samples_pred) / n_to_generate
                            epoch_per_valid.append(per_valid)
                            unbiased_smiles_mean_pred.append(
                                float(baseline_score))
                            biased_smiles_mean_pred.append(float(demo_score))
                            gen_smiles_mean_pred.append(float(mean_preds))
                            tb_writer.add_scalars(
                                'qsar_score', {
                                    'sampled': mean_preds,
                                    'baseline': baseline_score,
                                    'demo_data': demo_score
                                }, step_idx.i)
                            tb_writer.add_scalars(
                                'SMILES stats', {
                                    'per. of valid': per_valid,
                                    'per. of qualified': per_qualified
                                }, step_idx.i)
                            avg_len = np.nanmean([len(s) for s in samples])
                            tracker.track('Average SMILES length', avg_len,
                                          step_idx.i)

                            # Keep the weights with the best smoothed score.
                            score_exp_avg.update(score)
                            if score_exp_avg.value > best_score:
                                best_model_wts = copy.deepcopy(
                                    generator.state_dict())
                                best_score = score_exp_avg.value
                                best_epoch = epoch

                            if step_idx.i > 0 and step_idx.i % 1000 == 0:
                                smiles = generate_smiles(generator=generator,
                                                         gen_data=gen_data,
                                                         init_args=rnn_args,
                                                         num_samples=3)
                                print(f'Sample SMILES = {smiles}')
                        # End of mini=batch iterations.
                        print(
                            f'{time_since(start)}: Epoch {epoch + 1}/{n_epochs}, loss={np.mean(epoch_losses)},'
                            f'Mean value of predictions = {np.mean(epoch_mean_preds)}, '
                            f'% of valid SMILES = {np.mean(epoch_per_valid)}')

        except ValueError as e:
            print(str(e))

        duration = time.time() - start
        print('Model training duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
        generator.load_state_dict(best_model_wts)
        return {
            'model': generator,
            'score': round(best_score, 3),
            'epoch': best_epoch
        }