Example #1
    def test_simple(self):
        writer = TestTBWriter()
        with TBMeanTracker(writer, batch_size=4) as tracker:
            tracker.track("param_1", value=10, iter_index=1)
            self.assertEqual(writer.buffer, [])
            tracker.track("param_1", value=10, iter_index=2)
            self.assertEqual(writer.buffer, [])
            tracker.track("param_1", value=10, iter_index=3)
            self.assertEqual(writer.buffer, [])
            tracker.track("param_1", value=10, iter_index=4)
            self.assertEqual(writer.buffer, [("param_1", 10.0, 4)])
            writer.reset()

            tracker.track("param_1", value=1.0, iter_index=1)
            self.assertEqual(writer.buffer, [])
            tracker.track("param_1", value=2.0, iter_index=2)
            self.assertEqual(writer.buffer, [])
            tracker.track("param_1", value=-3.0, iter_index=3)
            self.assertEqual(writer.buffer, [])
            tracker.track("param_1", value=1.0, iter_index=4)
            self.assertEqual(writer.buffer, [("param_1", (1.0 + 2.0 - 3.0 + 1.0) / 4, 4)])
            writer.reset()


        with self.assertRaises(AssertionError):
            writer.add_scalar("bla", 1.0, 10)
        self.assertTrue(writer.closed)
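
The example above drives TBMeanTracker through a TestTBWriter test double that is not shown on this page. A minimal stub consistent with the assertions (a buffer of (name, value, iter_index) tuples, reset(), a closed flag, and add_scalar() rejecting writes after close()) could look like the sketch below; the original helper may differ.

class TestTBWriter:
    """Fake SummaryWriter for tests (sketch inferred from the assertions, not the original helper)."""

    def __init__(self):
        self.buffer = []      # collected (name, value, iter_index) tuples
        self.closed = False

    def add_scalar(self, name, value, iter_index):
        # The tracker closes the writer on exit (the test asserts writer.closed),
        # so any write after close() is an error.
        assert not self.closed
        self.buffer.append((name, value, iter_index))

    def reset(self):
        self.buffer = []

    def close(self):
        self.closed = True
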
Example #2
    def test_tensor_large(self):
        writer = TestTBWriter()

        with TBMeanTracker(writer, batch_size=2) as tracker:
            t = torch.LongTensor([1, 2, 3])
            tracker.track("p1", t, iter_index=1)
            tracker.track("p1", t, iter_index=2)
            self.assertEquals(writer.buffer, [("p1", 2.0, 2)])
Example #3
    def test_as_float(self):
        self.assertAlmostEqual(1.0, TBMeanTracker._as_float(1.0))
        self.assertAlmostEqual(0.33333333333333, TBMeanTracker._as_float(1.0/3.0))
        self.assertAlmostEqual(2.0, TBMeanTracker._as_float(torch.LongTensor([1, 2, 3])))
        self.assertAlmostEqual(0.6666666666666, TBMeanTracker._as_float(torch.LongTensor([1, 1, 0])))
        self.assertAlmostEqual(1.0, TBMeanTracker._as_float(np.array([1.0, 1.0, 1.0])))
        self.assertAlmostEqual(1.0, TBMeanTracker._as_float(np.sqrt(np.array([1.0], dtype=np.float32)[0])))
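
TBMeanTracker._as_float is expected to collapse plain numbers, torch tensors and numpy arrays into a single Python float equal to the element mean, as the assertions above show. A hypothetical helper matching those assertions could be written as follows; it is an illustration, not ptan's implementation.

import numpy as np
import torch


def as_float(value):
    """Reduce a scalar, torch tensor or numpy array to its mean as a float (sketch)."""
    if isinstance(value, torch.Tensor):
        return float(value.to(torch.float64).mean().item())
    if isinstance(value, np.ndarray):
        return float(np.mean(value))
    return float(value)
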
Example #4
    def train(init_args,
              agent_net_path=None,
              agent_net_name=None,
              seed=0,
              n_episodes=500,
              sim_data_node=None,
              tb_writer=None,
              is_hsearch=False,
              n_to_generate=200,
              learn_irl=True,
              bias_mode='max'):
        tb_writer = tb_writer()
        agent = init_args['agent']
        probs_reg = init_args['probs_reg']
        drl_algorithm = init_args['drl_alg']
        irl_algorithm = init_args['irl_alg']
        reward_func = init_args['reward_func']
        gamma = init_args['gamma']
        episodes_to_train = init_args['episodes_to_train']
        expert_model = init_args['expert_model']
        demo_data_gen = init_args['demo_data_gen']
        unbiased_data_gen = init_args['unbiased_data_gen']
        best_model_wts = None
        best_score = 0.
        exp_avg = ExpAverage(beta=0.6)

        # load pretrained model
        if agent_net_path and agent_net_name:
            print('Loading pretrained model...')
            agent.model.load_state_dict(
                IReLeaSE.load_model(agent_net_path, agent_net_name))
            print('Pretrained model loaded successfully!')

        # collect mean predictions
        unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
        unbiased_smiles_mean_pred_data_node = DataNode(
            'baseline_mean_vals', unbiased_smiles_mean_pred)
        biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals',
                                                     biased_smiles_mean_pred)
        gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals',
                                                  gen_smiles_mean_pred)
        if sim_data_node:
            sim_data_node.data = [
                unbiased_smiles_mean_pred_data_node,
                biased_smiles_mean_pred_data_node,
                gen_smiles_mean_pred_data_node
            ]

        start = time.time()

        # Begin simulation and training
        total_rewards = []
        irl_trajectories = []
        done_episodes = 0
        batch_episodes = 0
        exp_trajectories = []

        env = MoleculeEnv(actions=get_default_tokens(),
                          reward_func=reward_func)
        exp_source = ExperienceSourceFirstLast(env,
                                               agent,
                                               gamma,
                                               steps_count=1,
                                               steps_delta=1)
        traj_prob = 1.
        exp_traj = []

        demo_score = np.mean(
            expert_model(demo_data_gen.random_training_set_smiles(1000))[1])
        baseline_score = np.mean(
            expert_model(
                unbiased_data_gen.random_training_set_smiles(1000))[1])
        with contextlib.suppress(Exception if is_hsearch else DummyException):
            with TBMeanTracker(tb_writer, 1) as tracker:
                for step_idx, exp in tqdm(enumerate(exp_source)):
                    exp_traj.append(exp)
                    traj_prob *= probs_reg.get(list(exp.state), exp.action)

                    if exp.last_state is None:
                        irl_trajectories.append(
                            Trajectory(terminal_state=EpisodeStep(
                                exp.state, exp.action),
                                       traj_prob=traj_prob))
                        exp_trajectories.append(
                            exp_traj)  # for ExperienceFirstLast objects
                        exp_traj = []
                        traj_prob = 1.
                        probs_reg.clear()
                        batch_episodes += 1

                    new_rewards = exp_source.pop_total_rewards()
                    if new_rewards:
                        reward = new_rewards[0]
                        done_episodes += 1
                        total_rewards.append(reward)
                        mean_rewards = float(np.mean(total_rewards[-100:]))
                        tracker.track('mean_total_reward', mean_rewards,
                                      step_idx)
                        tracker.track('total_reward', reward, step_idx)
                        print(
                            f'Time = {time_since(start)}, step = {step_idx}, reward = {reward:6.2f}, '
                            f'mean_100 = {mean_rewards:6.2f}, episodes = {done_episodes}'
                        )
                        with torch.set_grad_enabled(False):
                            samples = generate_smiles(
                                drl_algorithm.model,
                                demo_data_gen,
                                init_args['gen_args'],
                                num_samples=n_to_generate)
                        predictions = expert_model(samples)[1]
                        mean_preds = np.mean(predictions)
                        try:
                            percentage_in_threshold = np.sum(
                                (predictions >= 7.0)) / len(predictions)
                        except Exception:
                            percentage_in_threshold = 0.
                        per_valid = len(predictions) / n_to_generate
                        print(
                            f'Mean value of predictions = {mean_preds}, '
                            f'% of valid SMILES = {per_valid}, '
                            f'% in drug-like region={percentage_in_threshold}')
                        unbiased_smiles_mean_pred.append(float(baseline_score))
                        biased_smiles_mean_pred.append(float(demo_score))
                        gen_smiles_mean_pred.append(float(mean_preds))
                        tb_writer.add_scalars(
                            'qsar_score', {
                                'sampled': mean_preds,
                                'baseline': baseline_score,
                                'demo_data': demo_score
                            }, step_idx)
                        tb_writer.add_scalars(
                            'SMILES stats', {
                                'per. of valid': per_valid,
                                'per. above threshold': percentage_in_threshold
                            }, step_idx)
                        eval_dict = {}
                        eval_score = IReLeaSE.evaluate(
                            eval_dict, samples,
                            demo_data_gen.random_training_set_smiles(1000))

                        for k in eval_dict:
                            tracker.track(k, eval_dict[k], step_idx)
                        tracker.track('Average SMILES length',
                                      np.nanmean([len(s) for s in samples]),
                                      step_idx)
                        if bias_mode == 'max':
                            diff = mean_preds - demo_score
                        else:
                            diff = demo_score - mean_preds
                        score = np.exp(diff)
                        exp_avg.update(score)
                        tracker.track('score', score, step_idx)
                        if exp_avg.value > best_score:
                            best_model_wts = [
                                copy.deepcopy(
                                    drl_algorithm.model.state_dict()),
                                copy.deepcopy(irl_algorithm.model.state_dict())
                            ]
                            best_score = exp_avg.value
                        if best_score >= np.exp(0.):
                            print(
                                f'threshold reached, best score={mean_preds}, '
                                f'threshold={demo_score}, training completed')
                            break
                        if done_episodes == n_episodes:
                            print('Training completed!')
                            break

                    if batch_episodes < episodes_to_train:
                        continue

                    # Train models
                    print('Fitting models...')
                    irl_stmt = ''
                    if learn_irl:
                        irl_loss = irl_algorithm.fit(irl_trajectories)
                        tracker.track('irl_loss', irl_loss, step_idx)
                        irl_stmt = f'IRL loss = {irl_loss}, '
                    rl_loss = drl_algorithm.fit(exp_trajectories)
                    samples = generate_smiles(drl_algorithm.model,
                                              demo_data_gen,
                                              init_args['gen_args'],
                                              num_samples=3)
                    print(
                        f'{irl_stmt}RL loss = {rl_loss}, samples = {samples}')
                    tracker.track('agent_loss', rl_loss, step_idx)

                    # Reset
                    batch_episodes = 0
                    irl_trajectories.clear()
                    exp_trajectories.clear()

        if best_model_wts:
            drl_algorithm.model.load_state_dict(best_model_wts[0])
            irl_algorithm.model.load_state_dict(best_model_wts[1])
        duration = time.time() - start
        print('\nTraining duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
        return {
            'model': [drl_algorithm.model, irl_algorithm.model],
            'score': round(best_score, 3),
            'epoch': done_episodes
        }
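
The training loop above uses two small helpers that are not defined in the snippet: ExpAverage, which smooths the score, and DummyException, which makes contextlib.suppress a no-op outside hyperparameter search (only during a search are real exceptions swallowed). Plausible sketches, assuming ExpAverage is a plain bias-uncorrected exponential moving average, are shown below.

class DummyException(Exception):
    """Never raised: passing it to contextlib.suppress means nothing is suppressed."""


class ExpAverage:
    """Exponential moving average (assumed form; the project's helper may differ)."""

    def __init__(self, beta=0.9):
        self.beta = beta
        self.value = 0.

    def update(self, x):
        self.value = self.beta * self.value + (1. - self.beta) * x
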
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda",
                        default=False,
                        action='store_true',
                        help='Enable CUDA')
    parser.add_argument("-n", "--name", required=True, help="Name of the run")
    args = parser.parse_args()
    device = torch.device(
        "cuda" if args.cuda and torch.cuda.is_available() else "cpu")

    save_path = os.path.join("saves", "d4pg-" + args.name)
    os.makedirs(save_path, exist_ok=True)

    env = gym.make(ENV_ID)
    test_env = gym.make(ENV_ID)

    act_net = model.D4PGActor(env.observation_space.shape[0],
                              env.action_space.shape[0]).to(device)
    crt_net = model.D4PGCritic(env.observation_space.shape[0],
                               env.action_space.shape[0], N_ATOMS, VMIN,
                               VMAX).to(device)
    tgt_act_net = ptan.agent.TargetNet(act_net)
    tgt_crt_net = ptan.agent.TargetNet(crt_net)

    writer = SummaryWriter(comment="-d4pg_" + args.name)

    agent = model.AgentD4PG(act_net, device=device)
    exp_source = ExperienceSourceFirstLast(env,
                                           agent,
                                           gamma=GAMMA,
                                           steps_count=REWARD_STEPS)
    buffer = PrioritizedReplayBuffer(exp_source, REPLAY_SIZE,
                                     PRIO_REPLAY_ALPHA)
    act_opt = optim.Adam(act_net.parameters(), lr=LEARNING_RATE)
    crt_opt = optim.Adam(crt_net.parameters(), lr=LEARNING_RATE)

    frame_idx = 0
    best_reward = None  # set after the first test run so the None checks below take effect
    with RewardTracker(writer) as tracker:
        with TBMeanTracker(writer, batch_size=10) as tb_tracker:
            while True:
                frame_idx += 1
                buffer.populate(1)
                beta = min(
                    1.0,
                    BETA_START + frame_idx * (1.0 - BETA_START) / BETA_FRAMES)

                rewards_steps = exp_source.pop_rewards_steps()
                if rewards_steps:
                    rewards, steps = zip(*rewards_steps)
                    tb_tracker.track('episode_steps', steps[0], frame_idx)
                    tracker.reward(rewards[0], frame_idx)

                if len(buffer) < REPLAY_INITIAL:
                    continue

                batch, batch_indices, batch_weights = buffer.sample(
                    BATCH_SIZE, beta)
                actor_loss, critic_loss, sample_prios = calc_loss(
                    batch, batch_weights, act_net, crt_net, tgt_act_net,
                    tgt_crt_net, device)

                train(actor_loss, critic_loss, act_net, crt_net, tgt_act_net,
                      tgt_crt_net, act_opt, crt_opt, device)

                tb_tracker.track('loss_actor', actor_loss, frame_idx)
                tb_tracker.track('loss_critic', critic_loss, frame_idx)

                buffer.update_priorities(batch_indices,
                                         sample_prios.data.cpu().numpy())

                if frame_idx % TEST_ITERS == 0:
                    ts = time.time()
                    rewards, steps = test(act_net, test_env, device=device)
                    print('Test done in %.2f sec, reward %.3f, steps %d' %
                          (time.time() - ts, rewards, steps))
                    writer.add_scalar('test_reward', rewards, frame_idx)
                    writer.add_scalar('test_steps', steps, frame_idx)
                    if best_reward is None or best_reward < rewards:
                        if best_reward is not None:
                            print('Best reward updated: %.3f -> %.3f' %
                                  (best_reward, rewards))
                            name = 'best_%+.3f_%d.dat' % (rewards, frame_idx)
                            fname = os.path.join(save_path, name)
                            torch.save(act_net.state_dict(), fname)
                        best_reward = rewards

    writer.close()
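
The calc_loss and train helpers used in this loop are defined elsewhere in the repository. Assuming the usual ptan-based D4PG pattern, in which the actor and critic losses come from separate forward passes, the train step roughly amounts to one optimizer step per network followed by a Polyak (soft) update of the target networks via TargetNet.alpha_sync. The sketch below is an assumption, not the original helper.

def train(actor_loss, critic_loss, act_net, crt_net, tgt_act_net,
          tgt_crt_net, act_opt, crt_opt, device, tau=1e-3):
    """Sketch of a D4PG update step (assumed implementation)."""
    # Critic step on the distributional TD loss.
    crt_opt.zero_grad()
    critic_loss.backward()
    crt_opt.step()

    # Actor step on -Q(s, pi(s)).
    act_opt.zero_grad()
    actor_loss.backward()
    act_opt.step()

    # Soft-update the target networks, keeping (1 - tau) of the old target weights.
    tgt_act_net.alpha_sync(alpha=1 - tau)
    tgt_crt_net.alpha_sync(alpha=1 - tau)
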
Example #6
    def train(init_args, model_path=None, agent_net_name=None, reward_net_name=None, n_episodes=500,
              sim_data_node=None, tb_writer=None, is_hsearch=False, n_to_generate=200, learn_irl=True):
        tb_writer = tb_writer()
        agent = init_args['agent']
        probs_reg = init_args['probs_reg']
        drl_algorithm = init_args['drl_alg']
        irl_algorithm = init_args['irl_alg']
        reward_func = init_args['reward_func']
        gamma = init_args['gamma']
        episodes_to_train = init_args['episodes_to_train']
        expert_model = init_args['expert_model']
        demo_data_gen = init_args['demo_data_gen']
        unbiased_data_gen = init_args['unbiased_data_gen']
        best_model_wts = None
        exp_avg = ExpAverage(beta=0.6)
        best_score = -1.

        # load pretrained model
        if model_path and agent_net_name and reward_net_name:
            try:
                print('Loading pretrained model...')
                weights = IReLeaSE.load_model(model_path, agent_net_name)
                agent.model.load_state_dict(weights)
                print('Pretrained model loaded successfully!')
                reward_func.model.load_state_dict(IReLeaSE.load_model(model_path, reward_net_name))
                print('Reward model loaded successfully!')
            except Exception:
                print('Pretrained model could not be loaded. Terminating prematurely.')
                return {'model': [drl_algorithm.actor,
                                  drl_algorithm.critic,
                                  irl_algorithm.model],
                        'score': round(best_score, 3),
                        'epoch': -1}

        start = time.time()

        # Begin simulation and training
        total_rewards = []
        trajectories = []
        done_episodes = 0
        batch_episodes = 0
        exp_trajectories = []
        step_idx = 0

        # collect mean predictions
        unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
        unbiased_smiles_mean_pred_data_node = DataNode('baseline_mean_vals', unbiased_smiles_mean_pred)
        biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals', biased_smiles_mean_pred)
        gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals', gen_smiles_mean_pred)
        if sim_data_node:
            sim_data_node.data = [unbiased_smiles_mean_pred_data_node,
                                  biased_smiles_mean_pred_data_node,
                                  gen_smiles_mean_pred_data_node]

        env = MoleculeEnv(actions=get_default_tokens(), reward_func=reward_func)
        exp_source = ExperienceSourceFirstLast(env, agent, gamma, steps_count=1, steps_delta=1)
        traj_prob = 1.
        exp_traj = []

        demo_score = np.mean(expert_model(demo_data_gen.random_training_set_smiles(1000))[1])
        baseline_score = np.mean(expert_model(unbiased_data_gen.random_training_set_smiles(1000))[1])
        # with contextlib.suppress(RuntimeError if is_hsearch else DummyException):
        try:
            with TBMeanTracker(tb_writer, 1) as tracker:
                for step_idx, exp in tqdm(enumerate(exp_source)):
                    exp_traj.append(exp)
                    traj_prob *= probs_reg.get(list(exp.state), exp.action)

                    if exp.last_state is None:
                        trajectories.append(Trajectory(terminal_state=EpisodeStep(exp.state, exp.action),
                                                       traj_prob=traj_prob))
                        exp_trajectories.append(exp_traj)  # for ExperienceFirstLast objects
                        exp_traj = []
                        traj_prob = 1.
                        probs_reg.clear()
                        batch_episodes += 1

                    new_rewards = exp_source.pop_total_rewards()
                    if new_rewards:
                        reward = new_rewards[0]
                        done_episodes += 1
                        total_rewards.append(reward)
                        mean_rewards = float(np.mean(total_rewards[-100:]))
                        tracker.track('mean_total_reward', mean_rewards, step_idx)
                        tracker.track('total_reward', reward, step_idx)
                        print(f'Time = {time_since(start)}, step = {step_idx}, reward = {reward:6.2f}, '
                              f'mean_100 = {mean_rewards:6.2f}, episodes = {done_episodes}')
                        with torch.set_grad_enabled(False):
                            samples = generate_smiles(drl_algorithm.model, demo_data_gen, init_args['gen_args'],
                                                      num_samples=n_to_generate)
                        predictions = expert_model(samples)[1]
                        mean_preds = np.nanmean(predictions)
                        if math.isnan(mean_preds) or math.isinf(mean_preds):
                            print(f'mean preds is {mean_preds}, terminating')
                            # best_score = -1.
                            break
                        try:
                            percentage_in_threshold = np.sum((predictions <= demo_score)) / len(predictions)
                        except Exception:
                            percentage_in_threshold = 0.
                        per_valid = len(predictions) / n_to_generate
                        if per_valid < 0.2:
                            print(f'Percentage of valid SMILES is = {per_valid}. Terminating...')
                            # best_score = -1.
                            break
                        print(f'Mean value of predictions = {mean_preds}, % of valid SMILES = {per_valid}')
                        unbiased_smiles_mean_pred.append(float(baseline_score))
                        biased_smiles_mean_pred.append(float(demo_score))
                        gen_smiles_mean_pred.append(float(mean_preds))
                        tb_writer.add_scalars('qsar_score', {'sampled': mean_preds,
                                                             'baseline': baseline_score,
                                                             'demo_data': demo_score}, step_idx)
                        tb_writer.add_scalars('SMILES stats', {'per. of valid': per_valid,
                                                               'per. in drug-like region': percentage_in_threshold},
                                              step_idx)
                        eval_dict = {}
                        eval_score = IReLeaSE.evaluate(eval_dict, samples,
                                                       demo_data_gen.random_training_set_smiles(1000))
                        for k in eval_dict:
                            tracker.track(k, eval_dict[k], step_idx)
                        avg_len = np.nanmean([len(s) for s in samples])
                        tracker.track('Average SMILES length', avg_len, step_idx)
                        d_penalty = eval_score < .5
                        s_penalty = avg_len < 20
                        diff = demo_score - mean_preds
                        # score = 3 * np.exp(diff) + np.log(per_valid + 1e-5) - s_penalty * np.exp(
                        #     diff) - d_penalty * np.exp(diff)
                        score = np.exp(diff)
                        # score = np.exp(diff) + np.mean([np.exp(per_valid), np.exp(percentage_in_threshold)])
                        if math.isnan(score) or math.isinf(score):
                            # best_score = -1.
                            print(f'Score is {score}, terminating.')
                            break
                        tracker.track('score', score, step_idx)
                        exp_avg.update(score)
                        if is_hsearch:
                            best_score = exp_avg.value
                        if exp_avg.value > best_score:
                            best_model_wts = [copy.deepcopy(drl_algorithm.actor.state_dict()),
                                              copy.deepcopy(drl_algorithm.critic.state_dict()),
                                              copy.deepcopy(irl_algorithm.model.state_dict())]
                            best_score = exp_avg.value
                        if best_score >= np.exp(0.):
                            print(f'threshold reached, best score={mean_preds}, '
                                  f'threshold={demo_score}, training completed')
                            break
                        if done_episodes == n_episodes:
                            print('Training completed!')
                            break

                    if batch_episodes < episodes_to_train:
                        continue

                    # Train models
                    print('Fitting models...')
                    irl_loss = 0.
                    # irl_loss = irl_algorithm.fit(trajectories) if learn_irl else 0.
                    rl_loss = drl_algorithm.fit(exp_trajectories)
                    samples = generate_smiles(drl_algorithm.model, demo_data_gen, init_args['gen_args'],
                                              num_samples=3)
                    print(f'IRL loss = {irl_loss}, RL loss = {rl_loss}, samples = {samples}')
                    tracker.track('irl_loss', irl_loss, step_idx)
                    tracker.track('critic_loss', rl_loss[0], step_idx)
                    tracker.track('agent_loss', rl_loss[1], step_idx)

                    # Reset
                    batch_episodes = 0
                    trajectories.clear()
                    exp_trajectories.clear()
        except Exception as e:
            print(str(e))
        if best_model_wts:
            drl_algorithm.actor.load_state_dict(best_model_wts[0])
            drl_algorithm.critic.load_state_dict(best_model_wts[1])
            irl_algorithm.model.load_state_dict(best_model_wts[2])
        duration = time.time() - start
        print('\nTraining duration: {:.0f}m {:.0f}s'.format(duration // 60, duration % 60))
        # if math.isinf(best_score) or math.isnan(best_score):
        #     best_score = -1.
        return {'model': [drl_algorithm.actor,
                          drl_algorithm.critic,
                          irl_algorithm.model],
                'score': round(best_score, 3),
                'epoch': done_episodes}
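
DataNode and sim_data_node come from the surrounding project and are not shown. From the way they are used here (a labelled node whose data attribute holds either collected values or child nodes), a minimal stand-in could be the following sketch.

class DataNode:
    """Labelled container for simulation results (sketch inferred from usage, not the project's class)."""

    def __init__(self, label, data=None):
        self.label = label
        self.data = data if data is not None else []
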
Example #7
    def train(generator,
              optimizer,
              rnn_args,
              pretrained_net_path=None,
              pretrained_net_name=None,
              n_iters=5000,
              sim_data_node=None,
              tb_writer=None,
              is_hsearch=False,
              is_pretraining=True,
              grad_clipping=5):
        expert_model = rnn_args['expert_model']
        tb_writer = tb_writer()
        best_model_wts = generator.state_dict()
        best_score = -1000
        best_epoch = -1
        demo_data_gen = rnn_args['demo_data_gen']
        unbiased_data_gen = rnn_args['unbiased_data_gen']
        prior_data_gen = rnn_args['prior_data_gen']
        score_exp_avg = ExpAverage(beta=0.6)
        exp_type = rnn_args['exp_type']

        if is_pretraining:
            num_batches = math.ceil(prior_data_gen.file_len /
                                    prior_data_gen.batch_size)
        else:
            num_batches = math.ceil(demo_data_gen.file_len /
                                    demo_data_gen.batch_size)
        n_epochs = math.ceil(n_iters / num_batches)
        grad_stats = GradStats(generator, beta=0.)

        # learning rate decay schedulers
        scheduler = sch.StepLR(optimizer, step_size=100, gamma=0.02)

        # pred_loss functions
        criterion = nn.CrossEntropyLoss(
            ignore_index=prior_data_gen.char2idx[prior_data_gen.pad_symbol])

        # sub-nodes of sim data resource
        loss_lst = []
        train_loss_node = DataNode(label="train_loss", data=loss_lst)

        # collect mean predictions
        unbiased_smiles_mean_pred, biased_smiles_mean_pred, gen_smiles_mean_pred = [], [], []
        unbiased_smiles_mean_pred_data_node = DataNode(
            'baseline_mean_vals', unbiased_smiles_mean_pred)
        biased_smiles_mean_pred_data_node = DataNode('biased_mean_vals',
                                                     biased_smiles_mean_pred)
        gen_smiles_mean_pred_data_node = DataNode('gen_mean_vals',
                                                  gen_smiles_mean_pred)
        if sim_data_node:
            sim_data_node.data = [
                train_loss_node, unbiased_smiles_mean_pred_data_node,
                biased_smiles_mean_pred_data_node,
                gen_smiles_mean_pred_data_node
            ]

        # load pretrained model
        if pretrained_net_path and pretrained_net_name:
            print('Loading pretrained model...')
            weights = RNNBaseline.load_model(pretrained_net_path,
                                             pretrained_net_name)
            generator.load_state_dict(weights)
            print('Pretrained model loaded successfully!')

        start = time.time()
        try:
            demo_score = np.mean(
                expert_model(
                    demo_data_gen.random_training_set_smiles(1000))[1])
            baseline_score = np.mean(
                expert_model(
                    unbiased_data_gen.random_training_set_smiles(1000))[1])
            step_idx = Count()
            gen_data = prior_data_gen if is_pretraining else demo_data_gen
            with TBMeanTracker(tb_writer, 1) as tracker:
                mode = 'Pretraining' if is_pretraining else 'Fine tuning'
                n_epochs = 30
                for epoch in range(n_epochs):
                    epoch_losses = []
                    epoch_mean_preds = []
                    epoch_per_valid = []
                    with grad_stats:
                        for b in trange(
                                0,
                                num_batches,
                                desc=
                                f'Epoch {epoch + 1}/{n_epochs}, {mode} in progress...'
                        ):
                            inputs, labels = gen_data.random_training_set()
                            optimizer.zero_grad()

                            predictions = generator(inputs)[0]
                            predictions = predictions.permute(1, 0, -1)
                            predictions = predictions.contiguous().view(
                                -1, predictions.shape[-1])
                            labels = labels.contiguous().view(-1)

                            # calculate loss
                            loss = criterion(predictions, labels)
                            epoch_losses.append(loss.item())

                            # backward pass
                            loss.backward()
                            if grad_clipping:
                                torch.nn.utils.clip_grad_norm_(
                                    generator.parameters(), grad_clipping)
                            optimizer.step()
                            # scheduler.step()

                            # for sim data resource
                            n_to_generate = 200
                            with torch.set_grad_enabled(False):
                                samples = generate_smiles(
                                    generator,
                                    demo_data_gen,
                                    rnn_args,
                                    num_samples=n_to_generate,
                                    max_len=smiles_max_len)
                            samples_pred = expert_model(samples)[1]

                            # metrics
                            eval_dict = {}
                            eval_score = RNNBaseline.evaluate(
                                eval_dict, samples,
                                demo_data_gen.random_training_set_smiles(1000))
                            # TBoard info
                            tracker.track('loss', loss.item(),
                                          step_idx.IncAndGet())
                            for k in eval_dict:
                                tracker.track(f'{k}', eval_dict[k], step_idx.i)
                            mean_preds = np.mean(samples_pred)
                            epoch_mean_preds.append(mean_preds)
                            per_valid = len(samples_pred) / n_to_generate
                            epoch_per_valid.append(per_valid)
                            if exp_type == 'drd2':
                                per_qualified = float(
                                    len([v for v in samples_pred if v >= 0.8
                                         ])) / len(samples_pred)
                                score = mean_preds
                            elif exp_type == 'logp':
                                per_qualified = np.sum(
                                    (samples_pred >= 1.0)
                                    & (samples_pred < 5.0)) / len(samples_pred)
                                score = mean_preds
                            elif exp_type == 'jak2_max':
                                per_qualified = np.sum(
                                    (samples_pred >=
                                     demo_score)) / len(samples_pred)
                                diff = mean_preds - demo_score
                                score = np.exp(diff)
                            elif exp_type == 'jak2_min':
                                per_qualified = np.sum(
                                    (samples_pred <=
                                     demo_score)) / len(samples_pred)
                                diff = demo_score - mean_preds
                                score = np.exp(diff)
                            else:  # pretraining
                                score = per_valid  # -loss.item()
                                per_qualified = 0.
                            unbiased_smiles_mean_pred.append(
                                float(baseline_score))
                            biased_smiles_mean_pred.append(float(demo_score))
                            gen_smiles_mean_pred.append(float(mean_preds))
                            tb_writer.add_scalars(
                                'qsar_score', {
                                    'sampled': mean_preds,
                                    'baseline': baseline_score,
                                    'demo_data': demo_score
                                }, step_idx.i)
                            tb_writer.add_scalars(
                                'SMILES stats', {
                                    'per. of valid': per_valid,
                                    'per. of qualified': per_qualified
                                }, step_idx.i)
                            avg_len = np.nanmean([len(s) for s in samples])
                            tracker.track('Average SMILES length', avg_len,
                                          step_idx.i)

                            score_exp_avg.update(score)
                            if score_exp_avg.value > best_score:
                                best_model_wts = copy.deepcopy(
                                    generator.state_dict())
                                best_score = score_exp_avg.value
                                best_epoch = epoch

                            if step_idx.i > 0 and step_idx.i % 1000 == 0:
                                smiles = generate_smiles(
                                    generator=generator,
                                    gen_data=gen_data,
                                    init_args=rnn_args,
                                    num_samples=3,
                                    max_len=smiles_max_len)
                                print(f'Sample SMILES = {smiles}')
                        # End of mini-batch iterations.
                        print(
                            f'{time_since(start)}: Epoch {epoch + 1}/{n_epochs}, loss={np.mean(epoch_losses)},'
                            f'Mean value of predictions = {np.mean(epoch_mean_preds)}, '
                            f'% of valid SMILES = {np.mean(epoch_per_valid)}')

        except RuntimeError as e:
            print(str(e))

        duration = time.time() - start
        print('Model training duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
        generator.load_state_dict(best_model_wts)
        return {
            'model': generator,
            'score': round(best_score, 3),
            'epoch': best_epoch
        }
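
Example #7 tracks the global step with a small Count helper (step_idx.IncAndGet() and step_idx.i). A minimal counter consistent with that usage, offered as an assumption rather than the project's actual class, would be:

class Count:
    """Step counter sketch matching the IncAndGet()/.i usage above (hypothetical)."""

    def __init__(self, i=0):
        self.i = i

    def IncAndGet(self):
        self.i += 1
        return self.i
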