Example #1
def _preprocess_states_actions(actions, states, device):
    # Process states and actions
    states = [''.join(list(state)) for state in states]
    states, states_len = pad_sequences(states)
    states, _ = seq2tensor(states, get_default_tokens())
    states = torch.from_numpy(states).long().to(device)
    states_len = torch.tensor(states_len).long().to(device)
    actions, _ = seq2tensor(actions, get_default_tokens())
    actions = torch.from_numpy(actions.reshape(-1)).long().to(device)
    return (states, states_len), actions
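For context, _preprocess_states_actions relies on the project's pad_sequences and seq2tensor helpers to turn variable-length token sequences into index tensors. Below is a minimal, self-contained sketch of that pad-then-index step; the toy vocabulary and the space pad symbol are illustrative assumptions, not the project's actual token set.

import numpy as np
import torch

def toy_pad_sequences(strings, pad=' '):
    # Right-pad every string to the length of the longest one; keep original lengths.
    lengths = [len(s) for s in strings]
    max_len = max(lengths)
    return [s.ljust(max_len, pad) for s in strings], lengths

def toy_seq2tensor(strings, tokens):
    # Map each character to its index in the token list -> (N, L) integer array.
    token2idx = {t: i for i, t in enumerate(tokens)}
    return np.array([[token2idx[ch] for ch in s] for s in strings])

states = ['<CCO', '<c1ccccc1']                 # toy partial SMILES prefixes
tokens = [' ', '<', '>', '1', 'C', 'O', 'c']   # toy vocabulary, not get_default_tokens()
padded, lengths = toy_pad_sequences(states)
states_t = torch.from_numpy(toy_seq2tensor(padded, tokens)).long()  # shape (2, 9)
states_len = torch.tensor(lengths).long()                           # tensor([4, 9])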
 def data_provider(k, flags):
     tokens = get_default_tokens()
     demo_data = GeneratorData(training_data_path=flags.demo_file,
                               delimiter='\t',
                               cols_to_read=[0],
                               keep_header=True,
                               pad_symbol=' ',
                               max_len=120,
                               tokens=tokens,
                               use_cuda=use_cuda)
     unbiased_data = GeneratorData(training_data_path=flags.unbiased_file,
                                   delimiter='\t',
                                   cols_to_read=[0],
                                   keep_header=True,
                                   pad_symbol=' ',
                                   max_len=120,
                                   tokens=tokens,
                                   use_cuda=use_cuda)
     prior_data = GeneratorData(training_data_path=flags.prior_data,
                                delimiter='\t',
                                cols_to_read=[0],
                                keep_header=True,
                                pad_symbol=' ',
                                max_len=120,
                                tokens=tokens,
                                use_cuda=use_cuda)
     return {
         'demo_data': demo_data,
         'unbiased_data': unbiased_data,
         'prior_data': prior_data
     }
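A hypothetical call site for the loader above, assuming use_cuda is defined at module scope (as the function references it) and using placeholder file paths:

from types import SimpleNamespace

flags = SimpleNamespace(demo_file='data/demo.smi',        # placeholder paths
                        unbiased_file='data/unbiased.smi',
                        prior_data='data/prior.smi')
data = data_provider(k=0, flags=flags)
demo_gen = data['demo_data']   # GeneratorData over the demonstration SMILES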
Example #3
    def initialize(hparams, train_data, val_data, test_data):
        # Create pytorch data loaders
        train_loader = DataLoader(SmilesDataset(train_data[0], train_data[1]),
                                  batch_size=hparams['batch'],
                                  collate_fn=lambda x: x)
        if val_data:
            val_loader = DataLoader(SmilesDataset(val_data[0], val_data[1]),
                                    batch_size=hparams['batch'],
                                    collate_fn=lambda x: x)
        else:
            val_loader = None
        test_loader = DataLoader(SmilesDataset(test_data[0], test_data[1]),
                                 batch_size=hparams['batch'],
                                 collate_fn=lambda x: x)
        # Create model and optimizer
        model = RNNPredictorModel(d_model=int(hparams['d_model']),
                                  tokens=get_default_tokens(),
                                  num_layers=int(hparams['rnn_num_layers']),
                                  dropout=float(hparams['dropout']),
                                  bidirectional=hparams['is_bidirectional'],
                                  unit_type=hparams['unit_type'],
                                  device=device).to(device)
        optimizer = parse_optimizer(hparams, model)
        metrics = [mean_squared_error, root_mean_squared_error, r2_score]

        return {
            'data_loaders': {
                'train': train_loader,
                'val': val_loader if val_data else None,
                'test': test_loader
            },
            'model': model,
            'optimizer': optimizer,
            'metrics': metrics
        }
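Note that collate_fn=lambda x: x keeps each batch as a plain Python list of (smiles, label) pairs instead of letting the default collate try to stack variable-length strings. A self-contained sketch of that behaviour, with a toy dataset standing in for SmilesDataset (whose exact interface is not shown here):

from torch.utils.data import Dataset, DataLoader

class ToySmilesDataset(Dataset):
    # Stands in for SmilesDataset: returns raw (smiles, label) pairs.
    def __init__(self, smiles, labels):
        self.smiles, self.labels = smiles, labels

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, i):
        return self.smiles[i], self.labels[i]

ds = ToySmilesDataset(['CCO', 'c1ccccc1', 'CC(=O)O'], [0.3, 1.2, 0.7])
loader = DataLoader(ds, batch_size=2, collate_fn=lambda x: x)
for batch in loader:
    # batch is a list such as [('CCO', 0.3), ('c1ccccc1', 1.2)];
    # padding/tokenization happens later, inside the model's own preprocessing.
    print(batch)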
Example #4
 def __init__(self, hparams, device, is_binary=False):
     expert_model_dir = hparams['model_dir']
     assert (os.path.isdir(expert_model_dir)), 'Expert model(s) should be in a dedicated folder'
     self.models = []
     self.tokens = get_default_tokens()
     self.device = device
     model_paths = os.listdir(expert_model_dir)
     self.transformer = None
     self.is_binary = is_binary
     for model_file in model_paths:
         if 'transformer' in model_file:
             with open(os.path.join(expert_model_dir, model_file), 'rb') as f:
                 self.transformer = joblib.load(f)
                 continue
         model = RNNPredictorModel(d_model=hparams['d_model'],
                                   tokens=self.tokens,
                                   num_layers=hparams['rnn_num_layers'],
                                   dropout=hparams['dropout'],
                                   bidirectional=hparams['is_bidirectional'],
                                   unit_type=hparams['unit_type'],
                                   device=device).to(device)
         if is_binary:
             model = torch.nn.Sequential(model, torch.nn.Sigmoid()).to(device)
         model.load_state_dict(torch.load(os.path.join(expert_model_dir, model_file),
                                          map_location=torch.device(device)))
         model = model.eval()
         self.models.append(model)
Example #5
    def initialize(hparams, train_data, val_data, test_data):
        # Create pytorch data loaders
        train_loader = DataLoader(SmilesDataset(train_data[0], train_data[1]),
                                  batch_size=hparams['batch'],
                                  shuffle=True,
                                  collate_fn=lambda x: x)
        if val_data:
            val_loader = DataLoader(SmilesDataset(val_data[0], val_data[1]),
                                    batch_size=hparams['batch'],
                                    collate_fn=lambda x: x)
        else:
            val_loader = None
        test_loader = DataLoader(SmilesDataset(test_data[0], test_data[1]),
                                 batch_size=hparams['batch'],
                                 collate_fn=lambda x: x)
        # Create model and optimizer
        model = torch.nn.Sequential(RNNPredictorModel(d_model=int(hparams['d_model']),
                                                      tokens=get_default_tokens(),
                                                      num_layers=int(hparams['rnn_num_layers']),
                                                      dropout=float(hparams['dropout']),
                                                      bidirectional=hparams['is_bidirectional'],
                                                      unit_type=hparams['unit_type'],
                                                      device=device),
                                    torch.nn.Sigmoid()).to(device)
        optimizer = parse_optimizer(hparams, model)
        metrics = [accuracy_score, precision_score, recall_score, f1_score]

        return {'data_loaders': {'train': train_loader,
                                 'val': val_loader if val_data else None,
                                 'test': test_loader},
                'model': model,
                'optimizer': optimizer,
                'metrics': metrics}
Example #6
    def calc_adv_ref(self, trajectory):
        states, actions, _ = unpack_batch([trajectory], self.gamma)
        last_state = ''.join(list(states[-1]))
        inp, _ = seq2tensor([last_state], tokens=get_default_tokens())
        inp = torch.from_numpy(inp).long().to(self.device)
        values_v = self.critic(inp)
        values = values_v.view(-1, ).data.cpu().numpy()
        last_gae = 0.0
        result_adv = []
        result_ref = []
        for val, next_val, exp in zip(reversed(values[:-1]),
                                      reversed(values[1:]),
                                      reversed(trajectory[:-1])):
            if exp.last_state is None:  # for terminal state
                delta = exp.reward - val
                last_gae = delta
            else:
                delta = exp.reward + self.gamma * next_val - val
                last_gae = delta + self.gamma * self.gae_lambda * last_gae
            result_adv.append(last_gae)
            result_ref.append(last_gae + val)

        adv_v = torch.FloatTensor(list(reversed(result_adv))).to(self.device)
        ref_v = torch.FloatTensor(list(reversed(result_ref))).to(self.device)
        return states[:-1], actions[:-1], adv_v, ref_v
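calc_adv_ref is generalized advantage estimation (GAE) computed backwards over the trajectory: delta_t = r_t + gamma * V(s_t+1) - V(s_t), adv_t = delta_t + gamma * lambda * adv_t+1, and ref_t = adv_t + V(s_t). A minimal NumPy sketch of the same recursion on toy rewards and critic values (terminal-state handling omitted):

import numpy as np

def gae(rewards, values, gamma=0.99, lam=0.95):
    # Backward recursion over the trajectory, as in calc_adv_ref above.
    adv, ref, last_gae = [], [], 0.0
    for val, next_val, r in zip(reversed(values[:-1]),
                                reversed(values[1:]),
                                reversed(rewards)):
        delta = r + gamma * next_val - val
        last_gae = delta + gamma * lam * last_gae
        adv.append(last_gae)
        ref.append(last_gae + val)
    return np.array(adv[::-1]), np.array(ref[::-1])

rewards = [0.0, 0.0, 1.0]        # toy per-step rewards
values = [0.1, 0.2, 0.5, 0.0]    # toy critic values V(s_0) .. V(s_3)
advantages, references = gae(rewards, values)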
Example #7
def smiles_to_tensor(smiles):
    smiles = list(smiles)
    _, valid_vec = canonical_smiles(smiles)
    valid_vec = torch.tensor(valid_vec).view(-1, 1).float().to(device)
    smiles, _ = pad_sequences(smiles)
    inp, _ = seq2tensor(smiles, tokens=get_default_tokens())
    inp = torch.from_numpy(inp).long().to(device)
    return inp, valid_vec
Example #8
 def data_provider(k, flags):
     tokens = get_default_tokens()
     gen_data = GeneratorData(training_data_path=flags.data_file,
                              delimiter='\t',
                              cols_to_read=[0],
                              keep_header=True,
                              pad_symbol=' ',
                              max_len=120,
                              tokens=tokens,
                              use_cuda=use_cuda)
     return {"train": gen_data, "val": gen_data, "test": gen_data}
Example #9
    def fit(self, trajectories):
        """Train the reward function / model using the GRL algorithm."""
        """Train the reward function / model using the GRL algorithm."""
        if self.use_buffer:
            extra_trajs = self.replay_buffer.sample(self.batch_size)
            trajectories.extend(extra_trajs)
            self.replay_buffer.populate(trajectories)
        d_traj, d_traj_probs = [], []
        for traj in trajectories:
            d_traj.append(''.join(list(traj.terminal_state.state)) +
                          traj.terminal_state.action)
            d_traj_probs.append(traj.traj_prob)
        _, valid_vec_samp = canonical_smiles(d_traj)
        valid_vec_samp = torch.tensor(valid_vec_samp).view(-1, 1).float().to(
            self.device)
        d_traj, _ = pad_sequences(d_traj)
        d_samp, _ = seq2tensor(d_traj, tokens=get_default_tokens())
        d_samp = torch.from_numpy(d_samp).long().to(self.device)
        losses = []
        for i in trange(self.k, desc='IRL optimization...'):
            # D_demo processing
            demo_states, demo_actions = self.demo_gen_data.random_training_set(
            )
            d_demo = torch.cat(
                [demo_states, demo_actions[:, -1].reshape(-1, 1)],
                dim=1).to(self.device)
            valid_vec_demo = torch.ones(d_demo.shape[0]).view(
                -1, 1).float().to(self.device)
            d_demo_out = self.model([d_demo, valid_vec_demo])

            # D_samp processing
            d_samp_out = self.model([d_samp, valid_vec_samp])
            d_out_combined = torch.cat([d_samp_out, d_demo_out], dim=0)
            if d_samp_out.shape[0] < 1000:
                d_samp_out = torch.cat([d_samp_out, d_demo_out], dim=0)
            z = torch.ones(d_samp_out.shape[0]).float().to(
                self.device)  # dummy importance weights TODO: replace this
            d_samp_out = z.view(-1, 1) * torch.exp(d_samp_out)

            # objective
            loss = torch.mean(d_demo_out) - torch.log(torch.mean(d_samp_out))
            losses.append(loss.item())
            loss = -loss  # for maximization

            # update params
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # self.lr_sch.step()
        return np.mean(losses)
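The quantity maximized in fit is a guided-cost-learning style IRL objective: the mean model output on demonstrations minus the log of the (importance-weighted) mean exponentiated output on samples, with the loss being its negative. A minimal sketch with random tensors standing in for the reward model's outputs:

import torch

torch.manual_seed(0)
r_demo = torch.randn(32, 1, requires_grad=True)   # stand-in for outputs on D_demo
r_samp = torch.randn(64, 1, requires_grad=True)   # stand-in for outputs on D_samp
z = torch.ones_like(r_samp)                       # dummy importance weights, as in fit()

objective = torch.mean(r_demo) - torch.log(torch.mean(z * torch.exp(r_samp)))
loss = -objective      # negated so a gradient step maximizes the objective
loss.backward()        # in fit(), this backpropagates into the reward model instead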
    def train(init_args,
              agent_net_path=None,
              agent_net_name=None,
              seed=0,
              n_episodes=500,
              sim_data_node=None,
              tb_writer=None,
              is_hsearch=False,
              n_to_generate=200,
              learn_irl=True,
              bias_mode='max'):
        tb_writer = tb_writer()
        agent = init_args['agent']
        probs_reg = init_args['probs_reg']
        drl_algorithm = init_args['drl_alg']
        irl_algorithm = init_args['irl_alg']
        reward_func = init_args['reward_func']
        gamma = init_args['gamma']
        episodes_to_train = init_args['episodes_to_train']
        expert_model = init_args['expert_model']
        demo_data_gen = init_args['demo_data_gen']
        unbiased_data_gen = init_args['unbiased_data_gen']
        best_model_wts = None
        best_score = 0.
        exp_avg = ExpAverage(beta=0.6)

        # load pretrained model
        if agent_net_path and agent_net_name:
            print('Loading pretrained model...')
            agent.model.load_state_dict(
                IReLeaSE.load_model(agent_net_path, agent_net_name))
            print('Pretrained model loaded successfully!')

        # collect mean predictions
        unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
        unbiased_smiles_mean_pred_data_node = DataNode(
            'baseline_mean_vals', unbiased_smiles_mean_pred)
        biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals',
                                                     biased_smiles_mean_pred)
        gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals',
                                                  gen_smiles_mean_pred)
        if sim_data_node:
            sim_data_node.data = [
                unbiased_smiles_mean_pred_data_node,
                biased_smiles_mean_pred_data_node,
                gen_smiles_mean_pred_data_node
            ]

        start = time.time()

        # Begin simulation and training
        total_rewards = []
        irl_trajectories = []
        done_episodes = 0
        batch_episodes = 0
        exp_trajectories = []

        env = MoleculeEnv(actions=get_default_tokens(),
                          reward_func=reward_func)
        exp_source = ExperienceSourceFirstLast(env,
                                               agent,
                                               gamma,
                                               steps_count=1,
                                               steps_delta=1)
        traj_prob = 1.
        exp_traj = []

        demo_score = np.mean(
            expert_model(demo_data_gen.random_training_set_smiles(1000))[1])
        baseline_score = np.mean(
            expert_model(
                unbiased_data_gen.random_training_set_smiles(1000))[1])
        with contextlib.suppress(Exception if is_hsearch else DummyException):
            with TBMeanTracker(tb_writer, 1) as tracker:
                for step_idx, exp in tqdm(enumerate(exp_source)):
                    exp_traj.append(exp)
                    traj_prob *= probs_reg.get(list(exp.state), exp.action)

                    if exp.last_state is None:
                        irl_trajectories.append(
                            Trajectory(terminal_state=EpisodeStep(
                                exp.state, exp.action),
                                       traj_prob=traj_prob))
                        exp_trajectories.append(
                            exp_traj)  # for ExperienceFirstLast objects
                        exp_traj = []
                        traj_prob = 1.
                        probs_reg.clear()
                        batch_episodes += 1

                    new_rewards = exp_source.pop_total_rewards()
                    if new_rewards:
                        reward = new_rewards[0]
                        done_episodes += 1
                        total_rewards.append(reward)
                        mean_rewards = float(np.mean(total_rewards[-100:]))
                        tracker.track('mean_total_reward', mean_rewards,
                                      step_idx)
                        tracker.track('total_reward', reward, step_idx)
                        print(
                            f'Time = {time_since(start)}, step = {step_idx}, reward = {reward:6.2f}, '
                            f'mean_100 = {mean_rewards:6.2f}, episodes = {done_episodes}'
                        )
                        with torch.set_grad_enabled(False):
                            samples = generate_smiles(
                                drl_algorithm.model,
                                demo_data_gen,
                                init_args['gen_args'],
                                num_samples=n_to_generate)
                        predictions = expert_model(samples)[1]
                        mean_preds = np.mean(predictions)
                        try:
                            percentage_in_threshold = np.sum(
                                (predictions >= 7.0)) / len(predictions)
                        except:
                            percentage_in_threshold = 0.
                        per_valid = len(predictions) / n_to_generate
                        print(
                            f'Mean value of predictions = {mean_preds}, '
                            f'% of valid SMILES = {per_valid}, '
                            f'% in drug-like region={percentage_in_threshold}')
                        unbiased_smiles_mean_pred.append(float(baseline_score))
                        biased_smiles_mean_pred.append(float(demo_score))
                        gen_smiles_mean_pred.append(float(mean_preds))
                        tb_writer.add_scalars(
                            'qsar_score', {
                                'sampled': mean_preds,
                                'baseline': baseline_score,
                                'demo_data': demo_score
                            }, step_idx)
                        tb_writer.add_scalars(
                            'SMILES stats', {
                                'per. of valid': per_valid,
                                'per. above threshold': percentage_in_threshold
                            }, step_idx)
                        eval_dict = {}
                        eval_score = IReLeaSE.evaluate(
                            eval_dict, samples,
                            demo_data_gen.random_training_set_smiles(1000))

                        for k in eval_dict:
                            tracker.track(k, eval_dict[k], step_idx)
                        tracker.track('Average SMILES length',
                                      np.nanmean([len(s) for s in samples]),
                                      step_idx)
                        if bias_mode == 'max':
                            diff = mean_preds - demo_score
                        else:
                            diff = demo_score - mean_preds
                        score = np.exp(diff)
                        exp_avg.update(score)
                        tracker.track('score', score, step_idx)
                        if exp_avg.value > best_score:
                            best_model_wts = [
                                copy.deepcopy(
                                    drl_algorithm.model.state_dict()),
                                copy.deepcopy(irl_algorithm.model.state_dict())
                            ]
                            best_score = exp_avg.value
                        if best_score >= np.exp(0.):
                            print(
                                f'threshold reached, best score={mean_preds}, '
                                f'threshold={demo_score}, training completed')
                            break
                        if done_episodes == n_episodes:
                            print('Training completed!')
                            break

                    if batch_episodes < episodes_to_train:
                        continue

                    # Train models
                    print('Fitting models...')
                    irl_stmt = ''
                    if learn_irl:
                        irl_loss = irl_algorithm.fit(irl_trajectories)
                        tracker.track('irl_loss', irl_loss, step_idx)
                        irl_stmt = f'IRL loss = {irl_loss}, '
                    rl_loss = drl_algorithm.fit(exp_trajectories)
                    samples = generate_smiles(drl_algorithm.model,
                                              demo_data_gen,
                                              init_args['gen_args'],
                                              num_samples=3)
                    print(
                        f'{irl_stmt}RL loss = {rl_loss}, samples = {samples}')
                    tracker.track('agent_loss', rl_loss, step_idx)

                    # Reset
                    batch_episodes = 0
                    irl_trajectories.clear()
                    exp_trajectories.clear()

        if best_model_wts:
            drl_algorithm.model.load_state_dict(best_model_wts[0])
            irl_algorithm.model.load_state_dict(best_model_wts[1])
        duration = time.time() - start
        print('\nTraining duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
        return {
            'model': [drl_algorithm.model, irl_algorithm.model],
            'score': round(best_score, 3),
            'epoch': done_episodes
        }
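The stopping rule above scores each evaluation as exp(mean_preds - demo_score) when bias_mode is 'max', smooths it with ExpAverage(beta=0.6), and stops once the smoothed score reaches exp(0) = 1, i.e. once generated molecules match the demo data on average. A toy sketch of that loop; the simple exponential moving average below is an assumption, since ExpAverage's implementation is not shown here:

import numpy as np

class ToyExpAverage:
    # Assumed stand-in for ExpAverage: plain exponential moving average.
    def __init__(self, beta=0.6):
        self.beta, self.value = beta, 0.0

    def update(self, x):
        self.value = self.beta * self.value + (1.0 - self.beta) * x

demo_score = 7.5                                  # toy mean expert score on demo data
exp_avg = ToyExpAverage(beta=0.6)
for mean_preds in [6.8, 7.1, 7.4, 7.6, 7.9]:      # toy per-evaluation means
    score = np.exp(mean_preds - demo_score)       # bias_mode == 'max'
    exp_avg.update(score)
    if exp_avg.value >= np.exp(0.0):              # smoothed score has reached 1
        print('threshold reached')
        break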
                        default='drd2_active.smi',
                        help='The filename for the created dataset')
    args = parser.parse_args()

    assert (os.path.exists(args.svc))
    assert (os.path.exists(args.data))
    assert (0 < args.threshold < 1)

    # Load file containing SMILES
    gen_data = GeneratorData(training_data_path=args.data,
                             delimiter='\t',
                             cols_to_read=[0],
                             keep_header=True,
                             pad_symbol=' ',
                             max_len=120,
                             tokens=get_default_tokens(),
                             use_cuda=False)

    # Load classifier
    clf = DRD2Model(args.svc)

    # Screen SMILES in data file and write active compounds to file.
    os.makedirs(args.save_dir, exist_ok=True)
    num_active = 0
    with open(os.path.join(args.save_dir, args.filename), 'w') as f:
        for i in trange(gen_data.file_len, desc='Screening compounds...'):
            smiles = gen_data.file[i][1:-1]
            p = clf(smiles)
            if p >= args.threshold:
                f.write(smiles + '\n')
                num_active += 1
Example #12
    def train(init_args, model_path=None, agent_net_name=None, reward_net_name=None, n_episodes=500,
              sim_data_node=None, tb_writer=None, is_hsearch=False, n_to_generate=200, learn_irl=True):
        tb_writer = tb_writer()
        agent = init_args['agent']
        probs_reg = init_args['probs_reg']
        drl_algorithm = init_args['drl_alg']
        irl_algorithm = init_args['irl_alg']
        reward_func = init_args['reward_func']
        gamma = init_args['gamma']
        episodes_to_train = init_args['episodes_to_train']
        expert_model = init_args['expert_model']
        demo_data_gen = init_args['demo_data_gen']
        unbiased_data_gen = init_args['unbiased_data_gen']
        best_model_wts = None
        exp_avg = ExpAverage(beta=0.6)
        best_score = -1.

        # load pretrained model
        if model_path and agent_net_name and reward_net_name:
            try:
                print('Loading pretrained model...')
                weights = IReLeaSE.load_model(model_path, agent_net_name)
                agent.model.load_state_dict(weights)
                print('Pretrained model loaded successfully!')
                reward_func.model.load_state_dict(IReLeaSE.load_model(model_path, reward_net_name))
                print('Reward model loaded successfully!')
            except:
                print('Pretrained model could not be loaded. Terminating prematurely.')
                return {'model': [drl_algorithm.actor,
                                  drl_algorithm.critic,
                                  irl_algorithm.model],
                        'score': round(best_score, 3),
                        'epoch': -1}

        start = time.time()

        # Begin simulation and training
        total_rewards = []
        trajectories = []
        done_episodes = 0
        batch_episodes = 0
        exp_trajectories = []
        step_idx = 0

        # collect mean predictions
        unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
        unbiased_smiles_mean_pred_data_node = DataNode('baseline_mean_vals', unbiased_smiles_mean_pred)
        biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals', biased_smiles_mean_pred)
        gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals', gen_smiles_mean_pred)
        if sim_data_node:
            sim_data_node.data = [unbiased_smiles_mean_pred_data_node,
                                  biased_smiles_mean_pred_data_node,
                                  gen_smiles_mean_pred_data_node]

        env = MoleculeEnv(actions=get_default_tokens(), reward_func=reward_func)
        exp_source = ExperienceSourceFirstLast(env, agent, gamma, steps_count=1, steps_delta=1)
        traj_prob = 1.
        exp_traj = []

        demo_score = np.mean(expert_model(demo_data_gen.random_training_set_smiles(1000))[1])
        baseline_score = np.mean(expert_model(unbiased_data_gen.random_training_set_smiles(1000))[1])
        # with contextlib.suppress(RuntimeError if is_hsearch else DummyException):
        try:
            with TBMeanTracker(tb_writer, 1) as tracker:
                for step_idx, exp in tqdm(enumerate(exp_source)):
                    exp_traj.append(exp)
                    traj_prob *= probs_reg.get(list(exp.state), exp.action)

                    if exp.last_state is None:
                        trajectories.append(Trajectory(terminal_state=EpisodeStep(exp.state, exp.action),
                                                       traj_prob=traj_prob))
                        exp_trajectories.append(exp_traj)  # for ExperienceFirstLast objects
                        exp_traj = []
                        traj_prob = 1.
                        probs_reg.clear()
                        batch_episodes += 1

                    new_rewards = exp_source.pop_total_rewards()
                    if new_rewards:
                        reward = new_rewards[0]
                        done_episodes += 1
                        total_rewards.append(reward)
                        mean_rewards = float(np.mean(total_rewards[-100:]))
                        tracker.track('mean_total_reward', mean_rewards, step_idx)
                        tracker.track('total_reward', reward, step_idx)
                        print(f'Time = {time_since(start)}, step = {step_idx}, reward = {reward:6.2f}, '
                              f'mean_100 = {mean_rewards:6.2f}, episodes = {done_episodes}')
                        with torch.set_grad_enabled(False):
                            samples = generate_smiles(drl_algorithm.model, demo_data_gen, init_args['gen_args'],
                                                      num_samples=n_to_generate)
                        predictions = expert_model(samples)[1]
                        mean_preds = np.nanmean(predictions)
                        if math.isnan(mean_preds) or math.isinf(mean_preds):
                            print(f'mean preds is {mean_preds}, terminating')
                            # best_score = -1.
                            break
                        try:
                            percentage_in_threshold = np.sum((predictions <= demo_score)) / len(predictions)
                        except:
                            percentage_in_threshold = 0.
                        per_valid = len(predictions) / n_to_generate
                        if per_valid < 0.2:
                            print(f'Percentage of valid SMILES is = {per_valid}. Terminating...')
                            # best_score = -1.
                            break
                        print(f'Mean value of predictions = {mean_preds}, % of valid SMILES = {per_valid}')
                        unbiased_smiles_mean_pred.append(float(baseline_score))
                        biased_smiles_mean_pred.append(float(demo_score))
                        gen_smiles_mean_pred.append(float(mean_preds))
                        tb_writer.add_scalars('qsar_score', {'sampled': mean_preds,
                                                             'baseline': baseline_score,
                                                             'demo_data': demo_score}, step_idx)
                        tb_writer.add_scalars('SMILES stats', {'per. of valid': per_valid,
                                                               'per. in drug-like region': percentage_in_threshold},
                                              step_idx)
                        eval_dict = {}
                        eval_score = IReLeaSE.evaluate(eval_dict, samples,
                                                       demo_data_gen.random_training_set_smiles(1000))
                        for k in eval_dict:
                            tracker.track(k, eval_dict[k], step_idx)
                        avg_len = np.nanmean([len(s) for s in samples])
                        tracker.track('Average SMILES length', np.nanmean([len(s) for s in samples]), step_idx)
                        d_penalty = eval_score < .5
                        s_penalty = avg_len < 20
                        diff = demo_score - mean_preds
                        # score = 3 * np.exp(diff) + np.log(per_valid + 1e-5) - s_penalty * np.exp(
                        #     diff) - d_penalty * np.exp(diff)
                        score = np.exp(diff)
                        # score = np.exp(diff) + np.mean([np.exp(per_valid), np.exp(percentage_in_threshold)])
                        if math.isnan(score) or math.isinf(score):
                            # best_score = -1.
                            print(f'Score is {score}, terminating.')
                            break
                        tracker.track('score', score, step_idx)
                        exp_avg.update(score)
                        if is_hsearch:
                            best_score = exp_avg.value
                        if exp_avg.value > best_score:
                            best_model_wts = [copy.deepcopy(drl_algorithm.actor.state_dict()),
                                              copy.deepcopy(drl_algorithm.critic.state_dict()),
                                              copy.deepcopy(irl_algorithm.model.state_dict())]
                            best_score = exp_avg.value
                        if best_score >= np.exp(0.):
                            print(f'threshold reached, best score={mean_preds}, '
                                  f'threshold={demo_score}, training completed')
                            break
                        if done_episodes == n_episodes:
                            print('Training completed!')
                            break

                    if batch_episodes < episodes_to_train:
                        continue

                    # Train models
                    print('Fitting models...')
                    irl_loss = 0.
                    # irl_loss = irl_algorithm.fit(trajectories) if learn_irl else 0.
                    rl_loss = drl_algorithm.fit(exp_trajectories)
                    samples = generate_smiles(drl_algorithm.model, demo_data_gen, init_args['gen_args'],
                                              num_samples=3)
                    print(f'IRL loss = {irl_loss}, RL loss = {rl_loss}, samples = {samples}')
                    tracker.track('irl_loss', irl_loss, step_idx)
                    tracker.track('critic_loss', rl_loss[0], step_idx)
                    tracker.track('agent_loss', rl_loss[1], step_idx)

                    # Reset
                    batch_episodes = 0
                    trajectories.clear()
                    exp_trajectories.clear()
        except Exception as e:
            print(str(e))
        if best_model_wts:
            drl_algorithm.actor.load_state_dict(best_model_wts[0])
            drl_algorithm.critic.load_state_dict(best_model_wts[1])
            irl_algorithm.model.load_state_dict(best_model_wts[2])
        duration = time.time() - start
        print('\nTraining duration: {:.0f}m {:.0f}s'.format(duration // 60, duration % 60))
        # if math.isinf(best_score) or math.isnan(best_score):
        #     best_score = -1.
        return {'model': [drl_algorithm.actor,
                          drl_algorithm.critic,
                          irl_algorithm.model],
                'score': round(best_score, 3),
                'epoch': done_episodes}
Example #13
    def fit(self, trajectories):
        sq2ten = lambda x: torch.from_numpy(
            seq2tensor(x, get_default_tokens())[0]).long().to(self.device)
        t_states, t_actions, t_adv, t_ref = [], [], [], []
        t_old_probs = []
        for traj in trajectories:
            states, actions, adv_v, ref_v = self.calc_adv_ref(traj)
            if len(states) == 0:
                continue
            t_states.append(states)
            t_actions.append(actions)
            t_adv.append(adv_v)
            t_ref.append(ref_v)

            with torch.set_grad_enabled(False):
                hidden_states = self.initial_states_func(
                    batch_size=1, **self.initial_states_args)
                trajectory_input = sq2ten(states[-1])
                actions = sq2ten(actions)
                old_probs = []
                for p in range(len(trajectory_input)):
                    outputs = self.model([trajectory_input[p].reshape(1, 1)] +
                                         hidden_states)
                    output, hidden_states = outputs[0], outputs[1:]
                    log_prob = torch.log_softmax(output.view(1, -1), dim=1)
                    old_probs.append(log_prob[0, actions[p]].item())
                t_old_probs.append(old_probs)

        if len(t_states) == 0:
            return 0., 0.

        for epoch in trange(self.ppo_epochs, desc='PPO optimization...'):
            cr_loss = 0.
            ac_loss = 0.
            for i in range(len(t_states)):
                traj_last_state = t_states[i][-1]
                traj_actions = t_actions[i]
                traj_adv = t_adv[i]
                traj_ref = t_ref[i]
                traj_old_probs = t_old_probs[i]
                hidden_states = self.initial_states_func(
                    1, **self.initial_states_args)
                for p in range(len(traj_last_state)):
                    state, action, adv = traj_last_state[p], traj_actions[
                        p], traj_adv[p]
                    old_log_prob = traj_old_probs[p]
                    state, action = sq2ten(state), sq2ten(action)

                    # Critic
                    pred = self.critic(state)
                    cr_loss = cr_loss + F.mse_loss(pred.reshape(-1, 1),
                                                   traj_ref[p].reshape(-1, 1))

                    # Actor
                    outputs = self.actor([state] + hidden_states)
                    output, hidden_states = outputs[0], outputs[1:]
                    logprob_pi_v = torch.log_softmax(output.view(1, -1),
                                                     dim=-1)
                    logprob_pi_v = logprob_pi_v[0, action]
                    ratio_v = torch.exp(logprob_pi_v - old_log_prob)
                    surr_obj_v = adv * ratio_v
                    clipped_surr_v = adv * torch.clamp(
                        ratio_v, 1.0 - self.ppo_eps, 1.0 + self.ppo_eps)
                    loss_policy_v = torch.min(surr_obj_v, clipped_surr_v)

                    # Maximize entropy
                    prob = torch.softmax(output.view(1, -1), dim=1)
                    prob = prob[0, action]
                    entropy = prob * logprob_pi_v
                    entropy_loss = self.entropy_beta * entropy
                    ac_loss = ac_loss - (loss_policy_v + entropy_loss)
            # Update weights
            self.critic_opt.zero_grad()
            self.actor_opt.zero_grad()
            cr_loss = cr_loss / len(trajectories)
            ac_loss = ac_loss / len(trajectories)
            cr_loss.backward()
            ac_loss.backward()
            self.critic_opt.step()
            self.actor_opt.step()
        return cr_loss.item(), -ac_loss.item()
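The actor term above is the standard PPO clipped surrogate: ratio = exp(log_prob_new - log_prob_old) and objective = min(ratio * adv, clip(ratio, 1 - eps, 1 + eps) * adv), maximized per step (the entropy term from the example is omitted here). A minimal single-step sketch with toy numbers:

import torch

ppo_eps = 0.2
adv = torch.tensor(0.8)                   # advantage for one (state, action) pair
old_log_prob = torch.tensor(-1.2)         # log pi_old(a|s), recorded before the update
new_log_prob = torch.tensor(-1.0, requires_grad=True)   # log pi_new(a|s) from the actor

ratio = torch.exp(new_log_prob - old_log_prob)
surr = adv * ratio
clipped = adv * torch.clamp(ratio, 1.0 - ppo_eps, 1.0 + ppo_eps)
policy_objective = torch.min(surr, clipped)   # maximized, so its negative is the loss
(-policy_objective).backward()                # gradient w.r.t. the new log-probability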