Example #1
0
    def evaluate_model(model,
                       gen_data,
                       rnn_args,
                       sim_data_node=None,
                       num_smiles=1000):
        """Sample SMILES from a trained generator and report validity stats.

        Generation happens in batches of 100 to bound memory use. Valid
        (canonicalized, de-duplicated) SMILES and the raw invalid samples
        are attached to ``sim_data_node`` (when given) as child data nodes.

        :param model: trained generator model; switched to eval mode here.
        :param gen_data: data/vocabulary helper forwarded to ``generate_smiles``.
        :param rnn_args: generator init arguments forwarded to ``generate_smiles``.
        :param sim_data_node: optional DataNode that receives the result nodes.
        :param num_smiles: total number of samples to request.
        """
        start = time.time()
        model.eval()

        # Sample in fixed-size batches; the final batch covers any remainder
        # so exactly `num_smiles` samples are requested overall.  (The
        # original code duplicated the generate_smiles call for the residual
        # batch — this loop issues the same sequence of requests once.)
        samples = []
        step = 100
        remaining = num_smiles
        while remaining > 0:
            batch = min(step, remaining)
            samples.extend(
                generate_smiles(generator=model,
                                gen_data=gen_data,
                                init_args=rnn_args,
                                num_samples=batch,
                                is_train=False,
                                verbose=True,
                                # NOTE(review): smiles_max_len is presumably a
                                # module-level constant — confirm it is defined.
                                max_len=smiles_max_len))
            remaining -= batch

        # An empty canonical string marks the sample as invalid (presumably
        # canonical_smiles yields '' for unparseable input — confirm).
        smiles, _ = canonical_smiles(samples)
        valid_smiles = []
        invalid_smiles = []
        for idx, sm in enumerate(smiles):
            if len(sm) > 0:
                valid_smiles.append(sm)
            else:
                invalid_smiles.append(samples[idx])
        v = len(valid_smiles)  # count before de-duplication, for dup stats
        valid_smiles = list(set(valid_smiles))
        print(
            f'Percentage of valid SMILES = {float(len(valid_smiles)) / float(len(samples)):.2f}, '
            f'Num. samples = {len(samples)}, Num. valid = {len(valid_smiles)}, '
            f'Num. requested = {num_smiles}, Num. dups = {v - len(valid_smiles)}'
        )

        # sub-nodes of sim data resource
        smiles_node = DataNode(label="valid_smiles", data=valid_smiles)
        invalid_smiles_node = DataNode(label='invalid_smiles',
                                       data=invalid_smiles)

        # add sim data nodes to parent node
        if sim_data_node:
            sim_data_node.data = [smiles_node, invalid_smiles_node]

        duration = time.time() - start
        print('\nModel evaluation duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
    def train(init_args,
              agent_net_path=None,
              agent_net_name=None,
              seed=0,
              n_episodes=500,
              sim_data_node=None,
              tb_writer=None,
              is_hsearch=False,
              n_to_generate=200,
              learn_irl=True,
              bias_mode='max'):
        """Train the generator with DRL (and optionally IRL) against an expert.

        Streams experience from an ``ExperienceSourceFirstLast``, scores a
        sample of generated SMILES with ``expert_model`` after each finished
        episode, tracks the best-scoring weights via an exponential moving
        average, and periodically fits the IRL reward model and the DRL
        policy on the accumulated trajectories.

        :param init_args: dict supplying the agent, algorithms, reward/expert
            models, data generators and hyperparameters (see unpacking below).
        :param agent_net_path: directory of pretrained agent weights (optional).
        :param agent_net_name: file name of pretrained agent weights (optional).
        :param seed: unused in this body — presumably consumed by the caller.
        :param n_episodes: stop after this many completed episodes.
        :param sim_data_node: optional DataNode that collects mean predictions.
        :param tb_writer: factory returning a TensorBoard-style writer; it is
            called below, so a constructed writer must not be passed directly.
        :param is_hsearch: when True, exceptions inside the training loop are
            suppressed so a hyperparameter search can keep the partial result.
        :param n_to_generate: SMILES sampled per per-episode evaluation round.
        :param learn_irl: when False, the IRL reward model is not updated.
        :param bias_mode: 'max' rewards generated means above the demo-set
            mean; any other value rewards means below it.
        :return: dict with the trained models, rounded best score and the
            number of completed episodes.
        """
        tb_writer = tb_writer()
        agent = init_args['agent']
        probs_reg = init_args['probs_reg']
        drl_algorithm = init_args['drl_alg']
        irl_algorithm = init_args['irl_alg']
        reward_func = init_args['reward_func']
        gamma = init_args['gamma']
        episodes_to_train = init_args['episodes_to_train']
        expert_model = init_args['expert_model']
        demo_data_gen = init_args['demo_data_gen']
        unbiased_data_gen = init_args['unbiased_data_gen']
        best_model_wts = None
        best_score = 0.
        exp_avg = ExpAverage(beta=0.6)

        # load pretrained model
        if agent_net_path and agent_net_name:
            print('Loading pretrained model...')
            agent.model.load_state_dict(
                IReLeaSE.load_model(agent_net_path, agent_net_name))
            print('Pretrained model loaded successfully!')

        # collect mean predictions
        unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
        unbiased_smiles_mean_pred_data_node = DataNode(
            'baseline_mean_vals', unbiased_smiles_mean_pred)
        biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals',
                                                     biased_smiles_mean_pred)
        gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals',
                                                  gen_smiles_mean_pred)
        if sim_data_node:
            sim_data_node.data = [
                unbiased_smiles_mean_pred_data_node,
                biased_smiles_mean_pred_data_node,
                gen_smiles_mean_pred_data_node
            ]

        start = time.time()

        # Begin simulation and training
        total_rewards = []
        irl_trajectories = []
        done_episodes = 0
        batch_episodes = 0
        exp_trajectories = []

        env = MoleculeEnv(actions=get_default_tokens(),
                          reward_func=reward_func)
        exp_source = ExperienceSourceFirstLast(env,
                                               agent,
                                               gamma,
                                               steps_count=1,
                                               steps_delta=1)
        traj_prob = 1.
        exp_traj = []

        # Reference scores: mean expert prediction on the biased demo set and
        # on the unbiased baseline set (1000 random SMILES each).
        demo_score = np.mean(
            expert_model(demo_data_gen.random_training_set_smiles(1000))[1])
        baseline_score = np.mean(
            expert_model(
                unbiased_data_gen.random_training_set_smiles(1000))[1])
        # Suppress everything only during hyperparameter search; otherwise
        # DummyException (presumably never raised) lets errors propagate.
        with contextlib.suppress(Exception if is_hsearch else DummyException):
            with TBMeanTracker(tb_writer, 1) as tracker:
                for step_idx, exp in tqdm(enumerate(exp_source)):
                    exp_traj.append(exp)
                    # Accumulate the trajectory probability as the product of
                    # per-step action probabilities from the registry.
                    traj_prob *= probs_reg.get(list(exp.state), exp.action)

                    # last_state is None when the episode has terminated.
                    if exp.last_state is None:
                        irl_trajectories.append(
                            Trajectory(terminal_state=EpisodeStep(
                                exp.state, exp.action),
                                       traj_prob=traj_prob))
                        exp_trajectories.append(
                            exp_traj)  # for ExperienceFirstLast objects
                        exp_traj = []
                        traj_prob = 1.
                        probs_reg.clear()
                        batch_episodes += 1

                    new_rewards = exp_source.pop_total_rewards()
                    if new_rewards:
                        reward = new_rewards[0]
                        done_episodes += 1
                        total_rewards.append(reward)
                        # Running mean over the last 100 episodes.
                        mean_rewards = float(np.mean(total_rewards[-100:]))
                        tracker.track('mean_total_reward', mean_rewards,
                                      step_idx)
                        tracker.track('total_reward', reward, step_idx)
                        print(
                            f'Time = {time_since(start)}, step = {step_idx}, reward = {reward:6.2f}, '
                            f'mean_100 = {mean_rewards:6.2f}, episodes = {done_episodes}'
                        )
                        # Evaluate the current policy: sample SMILES without
                        # gradient tracking and score them with the expert.
                        with torch.set_grad_enabled(False):
                            samples = generate_smiles(
                                drl_algorithm.model,
                                demo_data_gen,
                                init_args['gen_args'],
                                num_samples=n_to_generate)
                        predictions = expert_model(samples)[1]
                        mean_preds = np.mean(predictions)
                        try:
                            percentage_in_threshold = np.sum(
                                (predictions >= 7.0)) / len(predictions)
                        # NOTE(review): bare except — presumably guards against
                        # predictions not supporting ndarray ops; narrow it.
                        except:
                            percentage_in_threshold = 0.
                        per_valid = len(predictions) / n_to_generate
                        print(
                            f'Mean value of predictions = {mean_preds}, '
                            f'% of valid SMILES = {per_valid}, '
                            f'% in drug-like region={percentage_in_threshold}')
                        unbiased_smiles_mean_pred.append(float(baseline_score))
                        biased_smiles_mean_pred.append(float(demo_score))
                        gen_smiles_mean_pred.append(float(mean_preds))
                        tb_writer.add_scalars(
                            'qsar_score', {
                                'sampled': mean_preds,
                                'baseline': baseline_score,
                                'demo_data': demo_score
                            }, step_idx)
                        tb_writer.add_scalars(
                            'SMILES stats', {
                                'per. of valid': per_valid,
                                'per. above threshold': percentage_in_threshold
                            }, step_idx)
                        eval_dict = {}
                        eval_score = IReLeaSE.evaluate(
                            eval_dict, samples,
                            demo_data_gen.random_training_set_smiles(1000))

                        for k in eval_dict:
                            tracker.track(k, eval_dict[k], step_idx)
                        tracker.track('Average SMILES length',
                                      np.nanmean([len(s) for s in samples]),
                                      step_idx)
                        # Score the round: exp(diff) > 1 iff the generated mean
                        # beats (bias_mode 'max') or undercuts (otherwise) the
                        # demo-set mean.
                        if bias_mode == 'max':
                            diff = mean_preds - demo_score
                        else:
                            diff = demo_score - mean_preds
                        score = np.exp(diff)
                        exp_avg.update(score)
                        tracker.track('score', score, step_idx)
                        if exp_avg.value > best_score:
                            best_model_wts = [
                                copy.deepcopy(
                                    drl_algorithm.model.state_dict()),
                                copy.deepcopy(irl_algorithm.model.state_dict())
                            ]
                            best_score = exp_avg.value
                        # np.exp(0.) == 1.0, i.e. the smoothed score has
                        # crossed the demo-set mean.
                        if best_score >= np.exp(0.):
                            print(
                                f'threshold reached, best score={mean_preds}, '
                                f'threshold={demo_score}, training completed')
                            break
                        if done_episodes == n_episodes:
                            print('Training completed!')
                            break

                    # Wait until enough episodes are buffered before fitting.
                    if batch_episodes < episodes_to_train:
                        continue

                    # Train models
                    print('Fitting models...')
                    irl_stmt = ''
                    if learn_irl:
                        irl_loss = irl_algorithm.fit(irl_trajectories)
                        tracker.track('irl_loss', irl_loss, step_idx)
                        irl_stmt = f'IRL loss = {irl_loss}, '
                    rl_loss = drl_algorithm.fit(exp_trajectories)
                    samples = generate_smiles(drl_algorithm.model,
                                              demo_data_gen,
                                              init_args['gen_args'],
                                              num_samples=3)
                    print(
                        f'{irl_stmt}RL loss = {rl_loss}, samples = {samples}')
                    tracker.track('agent_loss', rl_loss, step_idx)

                    # Reset
                    batch_episodes = 0
                    irl_trajectories.clear()
                    exp_trajectories.clear()

        # Restore the best weights seen during training (if any round scored).
        if best_model_wts:
            drl_algorithm.model.load_state_dict(best_model_wts[0])
            irl_algorithm.model.load_state_dict(best_model_wts[1])
        duration = time.time() - start
        print('\nTraining duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
        return {
            'model': [drl_algorithm.model, irl_algorithm.model],
            'score': round(best_score, 3),
            'epoch': done_episodes
        }
Example #3
0
    def train(model,
              optimizer,
              gen_data,
              rnn_args,
              n_iters=5000,
              sim_data_node=None,
              epoch_ckpt=(1, 2.0),
              tb_writer=None,
              is_hsearch=False):
        """Pretrain the stack-RNN generator with cross-entropy teacher forcing.

        Iterates random mini-batches from ``gen_data``, builds per-layer
        hidden/cell/stack states each step, and saves a model checkpoint at
        the end of every epoch.  Losses and scores are appended to the
        ``sim_data_node`` children when provided.

        :param model: generator model to train.
        :param optimizer: optimizer over ``model``'s parameters.
        :param gen_data: batch provider (``random_training_set`` etc.).
        :param rnn_args: RNN configuration (num_layers, hidden_size, ...).
        :param n_iters: total iterations used to derive ``n_epochs``.
        :param sim_data_node: optional DataNode for recording metrics.
        :param epoch_ckpt: unused in this variant — NOTE(review): the
            early-stop machinery (``e_avg``, ``terminate_training``,
            ``best_*``) is set up but never consulted here.
        :param tb_writer: unused; TensorBoard writing is disabled below.
        :param is_hsearch: unused in this variant.
        :return: dict with the model, mean epoch score and ``n_epochs``.
        """
        # TensorBoard writing is intentionally disabled in this variant.
        tb_writer = None  # tb_writer()
        start = time.time()
        best_model_wts = model.state_dict()
        best_score = -10000
        best_epoch = -1
        terminate_training = False
        e_avg = ExpAverage(.01)
        num_batches = math.ceil(gen_data.file_len / gen_data.batch_size)
        n_epochs = math.ceil(n_iters / num_batches)
        grad_stats = GradStats(model, beta=0.)

        # learning rate decay schedulers
        # scheduler = sch.StepLR(optimizer, step_size=500, gamma=0.01)

        # pred_loss functions
        # Padding positions are excluded from the loss.
        criterion = nn.CrossEntropyLoss(
            ignore_index=gen_data.char2idx[gen_data.pad_symbol])

        # sub-nodes of sim data resource
        loss_lst = []
        train_loss_node = DataNode(label="train_loss", data=loss_lst)
        metrics_dict = {}
        metrics_node = DataNode(label="validation_metrics", data=metrics_dict)
        train_scores_lst = []
        train_scores_node = DataNode(label="train_score",
                                     data=train_scores_lst)
        scores_lst = []
        scores_node = DataNode(label="validation_score", data=scores_lst)

        # add sim data nodes to parent node
        if sim_data_node:
            sim_data_node.data = [
                train_loss_node, train_scores_node, metrics_node, scores_node
            ]

        try:
            # Main training loop
            tb_idx = {'train': Count(), 'val': Count(), 'test': Count()}
            epoch_losses = []
            epoch_scores = []
            # NOTE(review): hard-coded 6 epochs although n_epochs is computed
            # above and shown in the progress logs — confirm this is intended.
            for epoch in range(6):
                phase = 'train'

                # Iterate through mini-batches
                # with TBMeanTracker(tb_writer, 10) as tracker:
                with grad_stats:
                    for b in trange(0,
                                    num_batches,
                                    desc=f'{phase} in progress...'):
                        inputs, labels = gen_data.random_training_set()
                        batch_size, seq_len = inputs.shape[:2]
                        optimizer.zero_grad()

                        # track history if only in train
                        with torch.set_grad_enabled(phase == "train"):
                            # Create hidden states for each layer
                            hidden_states = []
                            for _ in range(rnn_args['num_layers']):
                                hidden = init_hidden(
                                    num_layers=1,
                                    batch_size=batch_size,
                                    hidden_size=rnn_args['hidden_size'],
                                    num_dir=rnn_args['num_dir'],
                                    dvc=rnn_args['device'])
                                # Cell state only for LSTM-style cells.
                                if rnn_args['has_cell']:
                                    cell = init_cell(
                                        num_layers=1,
                                        batch_size=batch_size,
                                        hidden_size=rnn_args['hidden_size'],
                                        num_dir=rnn_args['num_dir'],
                                        dvc=rnn_args['device'])
                                else:
                                    cell = None
                                # External stack memory, when configured.
                                if rnn_args['has_stack']:
                                    stack = init_stack(batch_size,
                                                       rnn_args['stack_width'],
                                                       rnn_args['stack_depth'],
                                                       dvc=rnn_args['device'])
                                else:
                                    stack = None
                                hidden_states.append((hidden, cell, stack))
                            # forward propagation
                            outputs = model([inputs] + hidden_states)
                            predictions = outputs[0]
                            # Flatten (seq, batch, vocab) to (seq*batch, vocab)
                            # for CrossEntropyLoss.
                            predictions = predictions.permute(1, 0, -1)
                            predictions = predictions.contiguous().view(
                                -1, predictions.shape[-1])
                            labels = labels.contiguous().view(-1)

                            # calculate loss
                            loss = criterion(predictions, labels)

                        # metrics
                        eval_dict = {}
                        score = IreleasePretrain.evaluate(
                            eval_dict, predictions, labels)

                        # TBoard info
                        # tracker.track("%s/loss" % phase, loss.item(), tb_idx[phase].IncAndGet())
                        # tracker.track("%s/score" % phase, score, tb_idx[phase].i)
                        # for k in eval_dict:
                        #     tracker.track('{}/{}'.format(phase, k), eval_dict[k], tb_idx[phase].i)

                        # backward pass
                        loss.backward()
                        optimizer.step()

                        # for epoch stats
                        epoch_losses.append(loss.item())

                        # for sim data resource
                        train_scores_lst.append(score)
                        loss_lst.append(loss.item())

                        # for epoch stats
                        epoch_scores.append(score)

                        print("\t{}: Epoch={}/{}, batch={}/{}, "
                              "pred_loss={:.4f}, accuracy: {:.2f}, sample: {}".
                              format(
                                  time_since(start),
                                  epoch + 1, n_epochs, b + 1, num_batches,
                                  loss.item(), eval_dict['accuracy'],
                                  generate_smiles(generator=model,
                                                  gen_data=gen_data,
                                                  init_args=rnn_args,
                                                  num_samples=1)))
                    # Checkpoint once per epoch.  NOTE(review): date_label is
                    # presumably a module-level constant — confirm.
                    IreleasePretrain.save_model(
                        model,
                        './model_dir/',
                        name=f'irelease-pretrained_stack-rnn_gru_'
                        f'{date_label}_epoch_{epoch}')
                # End of mini=batch iterations.
        except RuntimeError as e:
            print(str(e))

        duration = time.time() - start
        print('\nModel training duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
        return {
            'model': model,
            'score': round(np.mean(epoch_scores), 3),
            'epoch': n_epochs
        }
Example #4
0
    def train(model,
              optimizer,
              gen_data,
              init_args,
              n_iters=5000,
              sim_data_node=None,
              epoch_ckpt=(2, 4.0),
              tb_writer=None,
              is_hsearch=False):
        """Pretrain the stack-augmented transformer generator.

        Runs cross-entropy teacher forcing over random mini-batches; training
        terminates early when the loss becomes NaN or when the smoothed epoch
        loss fails the ``epoch_ckpt`` criterion.  The best weights seen on
        the (currently disabled) evaluation phase are restored at the end.

        :param model: generator model to train.
        :param optimizer: optimizer over ``model``'s parameters.
        :param gen_data: batch provider (``random_training_set`` etc.).
        :param init_args: stack/device configuration for ``init_stack_2d``.
        :param n_iters: total iterations used to derive ``n_epochs``.
        :param sim_data_node: optional DataNode for recording metrics.
        :param epoch_ckpt: (interval, max_loss) — every ``interval - 1``
            epochs, training stops if the smoothed loss exceeds ``max_loss``.
        :param tb_writer: unused; TensorBoard writing is disabled below.
        :param is_hsearch: unused here (the val phase is commented out).
        :return: dict with the model, best validation score and best epoch.
        """
        # TensorBoard writing is intentionally disabled in this variant.
        tb_writer = None  # tb_writer()
        start = time.time()
        best_model_wts = model.state_dict()
        best_score = -10000
        best_epoch = -1
        terminate_training = False
        # Exponential moving average of the epoch loss, for early stopping.
        e_avg = ExpAverage(.01)
        num_batches = math.ceil(gen_data.file_len / gen_data.batch_size)
        n_epochs = math.ceil(n_iters / num_batches)
        grad_stats = GradStats(model, beta=0.)

        # learning rate decay schedulers
        # scheduler = sch.StepLR(optimizer, step_size=500, gamma=0.01)

        # pred_loss functions
        # Padding positions are excluded from the loss.
        criterion = nn.CrossEntropyLoss(
            ignore_index=gen_data.char2idx[gen_data.pad_symbol])
        # criterion = LabelSmoothing(gen_data.n_characters, gen_data.char2idx[gen_data.pad_symbol], 0.1)

        # sub-nodes of sim data resource
        loss_lst = []
        train_loss_node = DataNode(label="train_loss", data=loss_lst)
        metrics_dict = {}
        metrics_node = DataNode(label="validation_metrics", data=metrics_dict)
        train_scores_lst = []
        train_scores_node = DataNode(label="train_score",
                                     data=train_scores_lst)
        scores_lst = []
        scores_node = DataNode(label="validation_score", data=scores_lst)

        # add sim data nodes to parent node
        if sim_data_node:
            sim_data_node.data = [
                train_loss_node, train_scores_node, metrics_node, scores_node
            ]

        try:
            # Main training loop
            tb_idx = {'train': Count(), 'val': Count(), 'test': Count()}
            for epoch in range(n_epochs):
                if terminate_training:
                    print("Terminating training...")
                    break
                # Evaluation phase is currently disabled — only "train" runs,
                # so the else-branches below are dead until it is re-enabled.
                for phase in ["train"]:  # , "val" if is_hsearch else "test"]:
                    if phase == "train":
                        print("Training....")
                        # Training mode
                        model.train()
                    else:
                        print("Validation...")
                        # Evaluation mode
                        model.eval()

                    epoch_losses = []
                    epoch_scores = []

                    # Iterate through mini-batches
                    # with TBMeanTracker(tb_writer, 10) as tracker:
                    with grad_stats:
                        for b in trange(0,
                                        num_batches,
                                        desc=f'{phase} in progress...'):
                            inputs, labels = gen_data.random_training_set()

                            optimizer.zero_grad()

                            # track history if only in train
                            with torch.set_grad_enabled(phase == "train"):
                                # forward propagation
                                stack = init_stack_2d(inputs.shape[0],
                                                      inputs.shape[1],
                                                      init_args['stack_depth'],
                                                      init_args['stack_width'],
                                                      dvc=init_args['device'])
                                predictions = model([inputs, stack])
                                # Flatten (seq, batch, vocab) to
                                # (seq*batch, vocab) for CrossEntropyLoss.
                                predictions = predictions.permute(1, 0, -1)
                                predictions = predictions.contiguous().view(
                                    -1, predictions.shape[-1])
                                labels = labels.contiguous().view(-1)

                                # calculate loss
                                loss = criterion(predictions, labels)

                            # fail fast
                            # NOTE(review): string compare works for NaN but
                            # math.isnan(loss.item()) would be the usual form.
                            if str(loss.item()) == "nan":
                                terminate_training = True
                                break

                            # metrics
                            eval_dict = {}
                            score = GpmtPretrain.evaluate(
                                eval_dict, predictions, labels)

                            # TBoard info
                            # tracker.track("%s/loss" % phase, loss.item(), tb_idx[phase].IncAndGet())
                            # tracker.track("%s/score" % phase, score, tb_idx[phase].i)
                            # for k in eval_dict:
                            #     tracker.track('{}/{}'.format(phase, k), eval_dict[k], tb_idx[phase].i)

                            if phase == "train":
                                # backward pass
                                loss.backward()
                                optimizer.step()

                                # for epoch stats
                                epoch_losses.append(loss.item())

                                # for sim data resource
                                train_scores_lst.append(score)
                                loss_lst.append(loss.item())

                                print(
                                    "\t{}: Epoch={}/{}, batch={}/{}, "
                                    "pred_loss={:.4f}, accuracy: {:.2f}, sample: {}"
                                    .format(
                                        time_since(start), epoch + 1,
                                        n_epochs, b + 1, num_batches,
                                        loss.item(), eval_dict['accuracy'],
                                        generate_smiles(generator=model,
                                                        gen_data=gen_data,
                                                        init_args=init_args,
                                                        num_samples=1,
                                                        gen_type='trans')))
                            else:
                                # for epoch stats
                                epoch_scores.append(score)

                                # for sim data resource
                                scores_lst.append(score)
                                for m in eval_dict:
                                    if m in metrics_dict:
                                        metrics_dict[m].append(eval_dict[m])
                                    else:
                                        metrics_dict[m] = [eval_dict[m]]

                                print("\nEpoch={}/{}, batch={}/{}, "
                                      "evaluation results= {}, accuracy={}".
                                      format(epoch + 1, n_epochs, b + 1,
                                             num_batches, eval_dict, score))
                    # End of mini=batch iterations.

                    if phase == "train":
                        ep_loss = np.nanmean(epoch_losses)
                        e_avg.update(ep_loss)
                        # Early stop when the smoothed loss is still above the
                        # checkpoint threshold at the checkpoint epoch.
                        if epoch % (epoch_ckpt[0] - 1) == 0 and epoch > 0:
                            if e_avg.value > epoch_ckpt[1]:
                                terminate_training = True
                        print(
                            "\nPhase: {}, avg task pred_loss={:.4f}, ".format(
                                phase, np.nanmean(epoch_losses)))
                        # scheduler.step()
                    else:
                        # Dead branch while the val phase is disabled (above).
                        mean_score = np.mean(epoch_scores)
                        if best_score < mean_score:
                            best_score = mean_score
                            best_model_wts = copy.deepcopy(model.state_dict())
                            best_epoch = epoch
                    # scheduler.step() intentionally left commented out above.
        except RuntimeError as e:
            print(str(e))

        duration = time.time() - start
        print('\nModel training duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
        # Restore best weights; may fail if shapes changed, hence the guard.
        try:
            model.load_state_dict(best_model_wts)
        except RuntimeError as e:
            print(str(e))
        return {'model': model, 'score': best_score, 'epoch': best_epoch}
Example #5
0
    def train(init_args, model_path=None, agent_net_name=None, reward_net_name=None, n_episodes=500,
              sim_data_node=None, tb_writer=None, is_hsearch=False, n_to_generate=200, learn_irl=True):
        tb_writer = tb_writer()
        agent = init_args['agent']
        probs_reg = init_args['probs_reg']
        drl_algorithm = init_args['drl_alg']
        irl_algorithm = init_args['irl_alg']
        reward_func = init_args['reward_func']
        gamma = init_args['gamma']
        episodes_to_train = init_args['episodes_to_train']
        expert_model = init_args['expert_model']
        demo_data_gen = init_args['demo_data_gen']
        unbiased_data_gen = init_args['unbiased_data_gen']
        best_model_wts = None
        exp_avg = ExpAverage(beta=0.6)
        best_score = -1.

        # load pretrained model
        if model_path and agent_net_name and reward_net_name:
            try:
                print('Loading pretrained model...')
                weights = IReLeaSE.load_model(model_path, agent_net_name)
                agent.model.load_state_dict(weights)
                print('Pretrained model loaded successfully!')
                reward_func.model.load_state_dict(IReLeaSE.load_model(model_path, reward_net_name))
                print('Reward model loaded successfully!')
            except:
                print('Pretrained model could not be loaded. Terminating prematurely.')
                return {'model': [drl_algorithm.actor,
                                  drl_algorithm.critic,
                                  irl_algorithm.model],
                        'score': round(best_score, 3),
                        'epoch': -1}

        start = time.time()

        # Begin simulation and training
        total_rewards = []
        trajectories = []
        done_episodes = 0
        batch_episodes = 0
        exp_trajectories = []
        step_idx = 0

        # collect mean predictions
        unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
        unbiased_smiles_mean_pred_data_node = DataNode('baseline_mean_vals', unbiased_smiles_mean_pred)
        biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals', biased_smiles_mean_pred)
        gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals', gen_smiles_mean_pred)
        if sim_data_node:
            sim_data_node.data = [unbiased_smiles_mean_pred_data_node,
                                  biased_smiles_mean_pred_data_node,
                                  gen_smiles_mean_pred_data_node]

        env = MoleculeEnv(actions=get_default_tokens(), reward_func=reward_func)
        exp_source = ExperienceSourceFirstLast(env, agent, gamma, steps_count=1, steps_delta=1)
        traj_prob = 1.
        exp_traj = []

        demo_score = np.mean(expert_model(demo_data_gen.random_training_set_smiles(1000))[1])
        baseline_score = np.mean(expert_model(unbiased_data_gen.random_training_set_smiles(1000))[1])
        # with contextlib.suppress(RuntimeError if is_hsearch else DummyException):
        try:
            with TBMeanTracker(tb_writer, 1) as tracker:
                for step_idx, exp in tqdm(enumerate(exp_source)):
                    exp_traj.append(exp)
                    traj_prob *= probs_reg.get(list(exp.state), exp.action)

                    if exp.last_state is None:
                        trajectories.append(Trajectory(terminal_state=EpisodeStep(exp.state, exp.action),
                                                       traj_prob=traj_prob))
                        exp_trajectories.append(exp_traj)  # for ExperienceFirstLast objects
                        exp_traj = []
                        traj_prob = 1.
                        probs_reg.clear()
                        batch_episodes += 1

                    new_rewards = exp_source.pop_total_rewards()
                    if new_rewards:
                        reward = new_rewards[0]
                        done_episodes += 1
                        total_rewards.append(reward)
                        mean_rewards = float(np.mean(total_rewards[-100:]))
                        tracker.track('mean_total_reward', mean_rewards, step_idx)
                        tracker.track('total_reward', reward, step_idx)
                        print(f'Time = {time_since(start)}, step = {step_idx}, reward = {reward:6.2f}, '
                              f'mean_100 = {mean_rewards:6.2f}, episodes = {done_episodes}')
                        with torch.set_grad_enabled(False):
                            samples = generate_smiles(drl_algorithm.model, demo_data_gen, init_args['gen_args'],
                                                      num_samples=n_to_generate)
                        predictions = expert_model(samples)[1]
                        mean_preds = np.nanmean(predictions)
                        if math.isnan(mean_preds) or math.isinf(mean_preds):
                            print(f'mean preds is {mean_preds}, terminating')
                            # best_score = -1.
                            break
                        try:
                            percentage_in_threshold = np.sum((predictions <= demo_score)) / len(predictions)
                        except:
                            percentage_in_threshold = 0.
                        per_valid = len(predictions) / n_to_generate
                        if per_valid < 0.2:
                            print(f'Percentage of valid SMILES is = {per_valid}. Terminating...')
                            # best_score = -1.
                            break
                        print(f'Mean value of predictions = {mean_preds}, % of valid SMILES = {per_valid}')
                        unbiased_smiles_mean_pred.append(float(baseline_score))
                        biased_smiles_mean_pred.append(float(demo_score))
                        gen_smiles_mean_pred.append(float(mean_preds))
                        tb_writer.add_scalars('qsar_score', {'sampled': mean_preds,
                                                             'baseline': baseline_score,
                                                             'demo_data': demo_score}, step_idx)
                        tb_writer.add_scalars('SMILES stats', {'per. of valid': per_valid,
                                                               'per. in drug-like region': percentage_in_threshold},
                                              step_idx)
                        eval_dict = {}
                        eval_score = IReLeaSE.evaluate(eval_dict, samples,
                                                       demo_data_gen.random_training_set_smiles(1000))
                        for k in eval_dict:
                            tracker.track(k, eval_dict[k], step_idx)
                        avg_len = np.nanmean([len(s) for s in samples])
                        tracker.track('Average SMILES length', np.nanmean([len(s) for s in samples]), step_idx)
                        d_penalty = eval_score < .5
                        s_penalty = avg_len < 20
                        diff = demo_score - mean_preds
                        # score = 3 * np.exp(diff) + np.log(per_valid + 1e-5) - s_penalty * np.exp(
                        #     diff) - d_penalty * np.exp(diff)
                        score = np.exp(diff)
                        # score = np.exp(diff) + np.mean([np.exp(per_valid), np.exp(percentage_in_threshold)])
                        if math.isnan(score) or math.isinf(score):
                            # best_score = -1.
                            print(f'Score is {score}, terminating.')
                            break
                        tracker.track('score', score, step_idx)
                        exp_avg.update(score)
                        if is_hsearch:
                            best_score = exp_avg.value
                        if exp_avg.value > best_score:
                            best_model_wts = [copy.deepcopy(drl_algorithm.actor.state_dict()),
                                              copy.deepcopy(drl_algorithm.critic.state_dict()),
                                              copy.deepcopy(irl_algorithm.model.state_dict())]
                            best_score = exp_avg.value
                        if best_score >= np.exp(0.):
                            print(f'threshold reached, best score={mean_preds}, '
                                  f'threshold={demo_score}, training completed')
                            break
                        if done_episodes == n_episodes:
                            print('Training completed!')
                            break

                    if batch_episodes < episodes_to_train:
                        continue

                    # Train models
                    print('Fitting models...')
                    irl_loss = 0.
                    # irl_loss = irl_algorithm.fit(trajectories) if learn_irl else 0.
                    rl_loss = drl_algorithm.fit(exp_trajectories)
                    samples = generate_smiles(drl_algorithm.model, demo_data_gen, init_args['gen_args'],
                                              num_samples=3)
                    print(f'IRL loss = {irl_loss}, RL loss = {rl_loss}, samples = {samples}')
                    tracker.track('irl_loss', irl_loss, step_idx)
                    tracker.track('critic_loss', rl_loss[0], step_idx)
                    tracker.track('agent_loss', rl_loss[1], step_idx)

                    # Reset
                    batch_episodes = 0
                    trajectories.clear()
                    exp_trajectories.clear()
        except Exception as e:
            print(str(e))
        if best_model_wts:
            drl_algorithm.actor.load_state_dict(best_model_wts[0])
            drl_algorithm.critic.load_state_dict(best_model_wts[1])
            irl_algorithm.model.load_state_dict(best_model_wts[2])
        duration = time.time() - start
        print('\nTraining duration: {:.0f}m {:.0f}s'.format(duration // 60, duration % 60))
        # if math.isinf(best_score) or math.isnan(best_score):
        #     best_score = -1.
        return {'model': [drl_algorithm.actor,
                          drl_algorithm.critic,
                          irl_algorithm.model],
                'score': round(best_score, 3),
                'epoch': done_episodes}
Example #6
0
    def train(generator,
              optimizer,
              rnn_args,
              pretrained_net_path=None,
              pretrained_net_name=None,
              n_iters=5000,
              sim_data_node=None,
              tb_writer=None,
              is_hsearch=False,
              is_pretraining=True,
              grad_clipping=5):
        """Train (pretrain or fine-tune) the SMILES generator by maximum likelihood.

        Runs teacher-forced next-token cross-entropy training over batches drawn
        from the prior corpus (pretraining) or the demonstrations corpus
        (fine-tuning). After every optimizer step a batch of SMILES is sampled
        and scored with the expert (QSAR) model; an exponentially smoothed score
        selects the best weight snapshot, which is restored before returning.

        Args:
            generator: the RNN language model over SMILES tokens.
            optimizer: torch optimizer wrapping ``generator``'s parameters.
            rnn_args: dict with keys ``expert_model``, ``demo_data_gen``,
                ``unbiased_data_gen``, ``prior_data_gen`` and ``exp_type``
                (one of 'drd2', 'logp', 'jak2_max', 'jak2_min', else pretraining).
            pretrained_net_path / pretrained_net_name: optional checkpoint to
                load into ``generator`` before training.
            n_iters: total number of mini-batch iterations; converted to a
                number of epochs via the corpus size.
            sim_data_node: optional DataNode that receives loss / score series.
            tb_writer: zero-arg callable returning a TensorBoard writer.
            is_hsearch: kept for interface compatibility (unused here).
            is_pretraining: selects the prior corpus (True) or demo corpus.
            grad_clipping: max gradient norm; falsy disables clipping.

        Returns:
            dict with the trained ``model``, the best smoothed ``score``
            (rounded to 3 d.p.) and the ``epoch`` it was achieved.
        """
        expert_model = rnn_args['expert_model']
        tb_writer = tb_writer()
        # Deep-copy the snapshot: state_dict() returns references to the live
        # parameter tensors, so without a copy later optimizer steps would
        # silently mutate the "best" weights.
        best_model_wts = copy.deepcopy(generator.state_dict())
        best_score = -1000
        best_epoch = -1
        demo_data_gen = rnn_args['demo_data_gen']
        unbiased_data_gen = rnn_args['unbiased_data_gen']
        prior_data_gen = rnn_args['prior_data_gen']
        score_exp_avg = ExpAverage(beta=0.6)
        exp_type = rnn_args['exp_type']

        # Epoch count is derived from the requested iteration budget and the
        # size of whichever corpus we train on.
        if is_pretraining:
            num_batches = math.ceil(prior_data_gen.file_len /
                                    prior_data_gen.batch_size)
        else:
            num_batches = math.ceil(demo_data_gen.file_len /
                                    demo_data_gen.batch_size)
        # BUG FIX: n_epochs used to be clobbered with a hard-coded 30 right
        # before the epoch loop, which silently ignored the n_iters argument.
        n_epochs = math.ceil(n_iters / num_batches)
        grad_stats = GradStats(generator, beta=0.)

        # pred_loss function; padding positions contribute nothing to the loss.
        criterion = nn.CrossEntropyLoss(
            ignore_index=prior_data_gen.char2idx[prior_data_gen.pad_symbol])

        # sub-nodes of sim data resource
        loss_lst = []
        train_loss_node = DataNode(label="train_loss", data=loss_lst)

        # collect mean predictions
        unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
        unbiased_smiles_mean_pred_data_node = DataNode(
            'baseline_mean_vals', unbiased_smiles_mean_pred)
        biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals',
                                                     biased_smiles_mean_pred)
        gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals',
                                                  gen_smiles_mean_pred)
        if sim_data_node:
            sim_data_node.data = [
                train_loss_node, unbiased_smiles_mean_pred_data_node,
                biased_smiles_mean_pred_data_node,
                gen_smiles_mean_pred_data_node
            ]

        # load pretrained model
        if pretrained_net_path and pretrained_net_name:
            print('Loading pretrained model...')
            weights = RNNBaseline.load_model(pretrained_net_path,
                                             pretrained_net_name)
            generator.load_state_dict(weights)
            print('Pretrained model loaded successfully!')

        start = time.time()
        try:
            # Reference expert scores: demo corpus (target) and unbiased
            # corpus (baseline), each estimated on 1000 random SMILES.
            demo_score = np.mean(
                expert_model(
                    demo_data_gen.random_training_set_smiles(1000))[1])
            baseline_score = np.mean(
                expert_model(
                    unbiased_data_gen.random_training_set_smiles(1000))[1])
            step_idx = Count()
            gen_data = prior_data_gen if is_pretraining else demo_data_gen
            n_to_generate = 200  # hoisted: loop-invariant sampling budget
            with TBMeanTracker(tb_writer, 1) as tracker:
                mode = 'Pretraining' if is_pretraining else 'Fine tuning'
                for epoch in range(n_epochs):
                    epoch_losses = []
                    epoch_mean_preds = []
                    epoch_per_valid = []
                    with grad_stats:
                        for b in trange(
                                0,
                                num_batches,
                                desc=
                                f'Epoch {epoch + 1}/{n_epochs}, {mode} in progress...'
                        ):
                            inputs, labels = gen_data.random_training_set()
                            optimizer.zero_grad()

                            # Teacher forcing: flatten (seq, batch, vocab)
                            # logits and (seq * batch,) targets for the loss.
                            predictions = generator(inputs)[0]
                            predictions = predictions.permute(1, 0, -1)
                            predictions = predictions.contiguous().view(
                                -1, predictions.shape[-1])
                            labels = labels.contiguous().view(-1)

                            # calculate loss
                            loss = criterion(predictions, labels)
                            epoch_losses.append(loss.item())
                            # BUG FIX: record the loss in the sim-data list,
                            # which was created but never populated.
                            loss_lst.append(loss.item())

                            # backward pass
                            loss.backward()
                            if grad_clipping:
                                torch.nn.utils.clip_grad_norm_(
                                    generator.parameters(), grad_clipping)
                            optimizer.step()

                            # for sim data resource: sample SMILES and score
                            # them with the expert model (no gradients needed).
                            with torch.set_grad_enabled(False):
                                samples = generate_smiles(
                                    generator,
                                    demo_data_gen,
                                    rnn_args,
                                    num_samples=n_to_generate,
                                    max_len=smiles_max_len)
                            samples_pred = expert_model(samples)[1]
                            # Guard denominator: no valid SMILES would
                            # otherwise raise ZeroDivisionError below.
                            num_preds = max(len(samples_pred), 1)

                            # metrics (evaluate fills eval_dict in place)
                            eval_dict = {}
                            RNNBaseline.evaluate(
                                eval_dict, samples,
                                demo_data_gen.random_training_set_smiles(1000))
                            # TBoard info
                            tracker.track('loss', loss.item(),
                                          step_idx.IncAndGet())
                            for k in eval_dict:
                                tracker.track(f'{k}', eval_dict[k], step_idx.i)
                            mean_preds = np.mean(samples_pred)
                            epoch_mean_preds.append(mean_preds)
                            per_valid = len(samples_pred) / n_to_generate
                            epoch_per_valid.append(per_valid)
                            # Experiment-specific score + qualified fraction.
                            if exp_type == 'drd2':
                                per_qualified = float(
                                    len([v for v in samples_pred if v >= 0.8
                                         ])) / num_preds
                                score = mean_preds
                            elif exp_type == 'logp':
                                per_qualified = np.sum(
                                    (samples_pred >= 1.0)
                                    & (samples_pred < 5.0)) / num_preds
                                score = mean_preds
                            elif exp_type == 'jak2_max':
                                per_qualified = np.sum(
                                    (samples_pred >=
                                     demo_score)) / num_preds
                                diff = mean_preds - demo_score
                                score = np.exp(diff)
                            elif exp_type == 'jak2_min':
                                per_qualified = np.sum(
                                    (samples_pred <=
                                     demo_score)) / num_preds
                                diff = demo_score - mean_preds
                                score = np.exp(diff)
                            else:  # pretraining: validity rate is the score
                                score = per_valid
                                per_qualified = 0.
                            unbiased_smiles_mean_pred.append(
                                float(baseline_score))
                            biased_smiles_mean_pred.append(float(demo_score))
                            gen_smiles_mean_pred.append(float(mean_preds))
                            tb_writer.add_scalars(
                                'qsar_score', {
                                    'sampled': mean_preds,
                                    'baseline': baseline_score,
                                    'demo_data': demo_score
                                }, step_idx.i)
                            tb_writer.add_scalars(
                                'SMILES stats', {
                                    'per. of valid': per_valid,
                                    'per. of qualified': per_qualified
                                }, step_idx.i)
                            avg_len = np.nanmean([len(s) for s in samples])
                            tracker.track('Average SMILES length', avg_len,
                                          step_idx.i)

                            # Keep the weights with the best smoothed score.
                            score_exp_avg.update(score)
                            if score_exp_avg.value > best_score:
                                best_model_wts = copy.deepcopy(
                                    generator.state_dict())
                                best_score = score_exp_avg.value
                                best_epoch = epoch

                            # Periodically print a few samples for inspection.
                            if step_idx.i > 0 and step_idx.i % 1000 == 0:
                                smiles = generate_smiles(
                                    generator=generator,
                                    gen_data=gen_data,
                                    init_args=rnn_args,
                                    num_samples=3,
                                    max_len=smiles_max_len)
                                print(f'Sample SMILES = {smiles}')
                        # End of mini-batch iterations.
                        print(
                            f'{time_since(start)}: Epoch {epoch + 1}/{n_epochs}, loss={np.mean(epoch_losses)},'
                            f'Mean value of predictions = {np.mean(epoch_mean_preds)}, '
                            f'% of valid SMILES = {np.mean(epoch_per_valid)}')

        except RuntimeError as e:
            # Best-effort: keep the best snapshot found so far on CUDA/OOM
            # style failures instead of crashing the whole run.
            print(str(e))

        duration = time.time() - start
        print('Model training duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
        # Restore the best-scoring weights before handing the model back.
        generator.load_state_dict(best_model_wts)
        return {
            'model': generator,
            'score': round(best_score, 3),
            'epoch': best_epoch
        }