Example #1
 def test_embeddings(self):
     x, y = gen_data.random_training_set(batch_size=bz)
     encoder = Encoder(gen_data.n_characters, 128,
                       gen_data.char2idx[gen_data.pad_symbol])
     x = encoder((x, ))
     assert (x.ndim == 3)
     print(f'x.shape = {x.shape}')
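The test above only checks that the encoder output is a 3-D tensor. A minimal, self-contained sketch of the underlying idea, assuming the encoder is essentially an embedding lookup with a padding index that maps a (seq_len, batch) index tensor to (seq_len, batch, d_model); this is not the project's Encoder class:

import torch
import torch.nn as nn

# Hypothetical stand-in: an embedding table with a padding index.
vocab_size, d_model, pad_idx = 45, 128, 0
embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)

tokens = torch.randint(1, vocab_size, (20, 4))   # (seq_len, batch) of character indices
embedded = embedding(tokens)                     # (seq_len, batch, d_model)
assert embedded.ndim == 3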
Example #2
 def test_molecule_mcts(self):
     d_model = 8
     hidden_size = 16
     num_layers = 2
     encoder = Encoder(gen_data.n_characters,
                       d_model,
                       gen_data.char2idx[gen_data.pad_symbol],
                       return_tuple=False)
     rnn = RewardNetRNN(d_model,
                        hidden_size,
                        num_layers,
                        bidirectional=True,
                        unit_type='gru')
     env = MoleculeEnv(
         gen_data,
         RewardFunction(reward_net=torch.nn.Sequential(encoder, rnn),
                        policy=lambda x: gen_data.all_characters[
                            np.random.randint(gen_data.n_characters)],
                        actions=gen_data.all_characters))
     rewards = []
     for i in range(5):
         env.render()
         action = env.action_space.sample()
         s_prime, reward, done, info = env.step(action)
         rewards.append(reward)
         if done:
             env.reset()
             break
     print(f'rewards: {rewards}')
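The policy passed to RewardFunction here simply samples a random character; in the Monte Carlo setting such a policy is used to complete a partial sequence before scoring it. A rough, generic sketch of that rollout idea (score_fn, end_token and the helper name are illustrative assumptions, not the project's API):

import numpy as np

def mc_rollout_reward(prefix, policy, score_fn, end_token='\n', max_len=100, n_sims=5):
    """Estimate the reward of a partial sequence by completing it with `policy`
    several times and averaging the scores (generic Monte Carlo rollout sketch)."""
    scores = []
    for _ in range(n_sims):
        seq = list(prefix)
        while len(seq) < max_len:
            a = policy(seq)          # sample the next character given the current state
            if a == end_token:
                break
            seq.append(a)
        scores.append(score_fn(''.join(seq)))
    return float(np.mean(scores))

# Toy usage with stand-in policy and scorer:
reward = mc_rollout_reward('CCO',
                           policy=lambda s: np.random.choice(list('CNO=()\n')),
                           score_fn=len)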
Example #3
 def test_mol_env(self):
     d_model = 8
     hidden_size = 16
     num_layers = 1
     encoder = Encoder(gen_data.n_characters,
                       d_model,
                       gen_data.char2idx[gen_data.pad_symbol],
                       return_tuple=True)
     rnn = RewardNetRNN(d_model,
                        hidden_size,
                        num_layers,
                        bidirectional=True,
                        unit_type='gru')
     reward_net = torch.nn.Sequential(encoder, rnn)
     env = MoleculeEnv(
         gen_data,
         RewardFunction(reward_net=reward_net,
                        policy=lambda x: gen_data.all_characters[
                            np.random.randint(gen_data.n_characters)],
                        actions=gen_data.all_characters))
     print(f'sample action: {env.action_space.sample()}')
     print(f'sample observation: {env.observation_space.sample()}')
     s = env.reset()
     for i in range(5):
         env.render()
         action = env.action_space.sample()
         print(f'action = {action}')
         s_prime, reward, done, info = env.step(action)
         if done:
             env.reset()
             break
Example #4
 def test_input_equals_output_embeddings(self):
     x, y = gen_data.random_training_set(batch_size=bz)
     encoder = Encoder(gen_data.n_characters, 128,
                       gen_data.char2idx[gen_data.pad_symbol])
     lin_out = LinearOut(encoder.embeddings_weight)
     x = encoder(x)
     x_out = lin_out(x)
     assert x.shape == x_out.shape
Example #5
    def test_stack_rnn_cell(self):
        x, y = gen_data.random_training_set(batch_size=bz)
        d_model = 128
        hidden_size = 16
        stack_width = 10
        stack_depth = 20
        num_layers = 1
        num_dir = 2
        encoder = Encoder(gen_data.n_characters, d_model,
                          gen_data.char2idx[gen_data.pad_symbol])
        x = encoder(x)
        rnn_cells = []
        in_dim = d_model
        cell_type = 'gru'
        for _ in range(num_layers):
            rnn_cells.append(
                StackRNNCell(in_dim,
                             hidden_size,
                             has_stack=True,
                             unit_type=cell_type,
                             stack_depth=stack_depth,
                             stack_width=stack_width))
            in_dim = hidden_size * num_dir
        rnn_cells = torch.nn.ModuleList(rnn_cells)

        h0 = init_hidden(num_layers=num_layers,
                         batch_size=bz,
                         hidden_size=hidden_size,
                         num_dir=num_dir)
        c0 = init_hidden(num_layers=num_layers,
                         batch_size=bz,
                         hidden_size=hidden_size,
                         num_dir=num_dir)
        s0 = init_stack(bz, stack_width, stack_depth)

        seq_length = x.shape[0]
        hidden_outs = torch.zeros(num_layers, num_dir, seq_length, bz,
                                  hidden_size)
        if cell_type == 'lstm':
            cell_outs = torch.zeros(num_layers, num_dir, seq_length, bz,
                                    hidden_size)
        assert 0 <= num_dir <= 2
        # Run each layer in both directions over the sequence, one time step at a time.
        for l in range(num_layers):
            for d in range(num_dir):
                h, c, stack = h0[l, d, :], c0[l, d, :], s0
                if d == 0:
                    indices = range(x.shape[0])  # forward direction
                else:
                    indices = reversed(range(x.shape[0]))  # reverse direction
                for i in indices:
                    x_t = x[i, :, :]
                    hx, stack = rnn_cells[l](x_t, h, c, stack)
                    if cell_type == 'lstm':
                        hidden_outs[l, d, i, :, :] = hx[0]
                        cell_outs[l, d, i, :, :] = hx[1]
                    else:
                        hidden_outs[l, d, i, :, :] = hx
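StackRNNCell couples an RNN cell with a differentiable stack. A minimal, self-contained sketch of the stack update itself, following the soft push/pop/no-op mixture of Joulin & Mikolov's Stack-RNN; the gate layout and tensor order here are assumptions, not the project's exact implementation:

import torch
import torch.nn.functional as F

def soft_stack_update(stack, push_val, action_logits):
    """stack: (batch, depth, width), push_val: (batch, width),
    action_logits: (batch, 3) -> soft mixture of push / pop / no-op."""
    a = F.softmax(action_logits, dim=-1)                     # (batch, 3)
    a_push, a_pop, a_noop = a[:, 0:1, None], a[:, 1:2, None], a[:, 2:3, None]
    pushed = torch.cat([push_val.unsqueeze(1), stack[:, :-1, :]], dim=1)              # shift down, write on top
    popped = torch.cat([stack[:, 1:, :], torch.zeros_like(stack[:, :1, :])], dim=1)   # shift up
    return a_push * pushed + a_pop * popped + a_noop * stack

# Toy usage consistent with the shapes in the test above:
bz_, stack_depth_, stack_width_ = 3, 20, 10
s = torch.zeros(bz_, stack_depth_, stack_width_)
s = soft_stack_update(s, torch.randn(bz_, stack_width_), torch.randn(bz_, 3))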
Example #6
 def test_positional_encodings(self):
     x, y = gen_data.random_training_set(batch_size=bz)
     encoder = Encoder(gen_data.n_characters, 128,
                       gen_data.char2idx[gen_data.pad_symbol])
     x = encoder(x)
     enc_shape = x.shape
     pe = PositionalEncoding(128, dropout=.2, max_len=500)
     x = pe(x)
     assert (x.shape == enc_shape)
     print(f'x.shape = {x.shape}')
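For reference, the sinusoidal positional encoding this test exercises can be sketched directly. The sketch follows the standard Transformer formulation (even d_model assumed) and is not necessarily identical to the project's PositionalEncoding:

import math
import torch

def sinusoidal_encoding(max_len, d_model):
    """Return a (max_len, 1, d_model) table of sin/cos positional encodings."""
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)            # (max_len, 1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                    * (-math.log(10000.0) / d_model))                      # (d_model/2,)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe.unsqueeze(1)   # broadcasts over the batch dimension of a (seq, batch, d_model) input

# Adding the encoding leaves the input shape unchanged, as the assert above expects:
x_ = torch.zeros(20, 4, 128)                  # (seq_len, batch, d_model)
x_ = x_ + sinusoidal_encoding(20, 128)
assert x_.shape == (20, 4, 128)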
Example #7
    def initialize(hparams, gen_data, *args, **kwargs):
        gen_data.set_batch_size(hparams['batch_size'])
        # Create main model
        encoder = Encoder(vocab_size=gen_data.n_characters,
                          d_model=hparams['d_model'],
                          padding_idx=gen_data.char2idx[gen_data.pad_symbol],
                          dropout=hparams['dropout'],
                          return_tuple=True)
        # Create RNN layers
        rnn_layers = []
        has_stack = True
        for i in range(1, hparams['num_layers'] + 1):
            rnn_layers.append(
                StackRNN(layer_index=i,
                         input_size=hparams['d_model'],
                         hidden_size=hparams['d_model'],
                         has_stack=has_stack,
                         unit_type=hparams['unit_type'],
                         stack_width=hparams['stack_width'],
                         stack_depth=hparams['stack_depth'],
                         k_mask_func=encoder.k_padding_mask))
            if hparams['num_layers'] > 1:
                rnn_layers.append(StackedRNNDropout(hparams['dropout']))
                rnn_layers.append(StackedRNNLayerNorm(hparams['d_model']))

        model = nn.Sequential(
            encoder,
            *rnn_layers,
            RNNLinearOut(
                out_dim=gen_data.n_characters,
                hidden_size=hparams['d_model'],
                bidirectional=False,
                # encoder=encoder,
                # dropout=hparams['dropout'],
                bias=True))
        if use_cuda:
            model = model.cuda()
        optimizer = parse_optimizer(hparams, model)
        rnn_args = {
            'num_layers': hparams['num_layers'],
            'hidden_size': hparams['d_model'],
            'num_dir': 1,
            'device': device,
            'has_stack': has_stack,
            'has_cell': hparams['unit_type'] == 'lstm',
            'stack_width': hparams['stack_width'],
            'stack_depth': hparams['stack_depth']
        }
        return model, optimizer, gen_data, rnn_args
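The rnn_args returned here mirror the arguments used by init_hidden and init_stack in the earlier cell test. A plausible sketch of such helpers, assuming the (num_layers, num_dir, batch, hidden) and (batch, depth, width) layouts implied by how h0[l, d, :] and s0 are indexed above; the project's actual signatures may differ:

import torch

def init_hidden(num_layers, batch_size, hidden_size, num_dir=1, device='cpu'):
    # One zero state per (layer, direction) pair, indexable as h0[layer, direction].
    return torch.zeros(num_layers, num_dir, batch_size, hidden_size, device=device)

def init_stack(batch_size, stack_width, stack_depth, device='cpu'):
    # An empty differentiable stack for each batch element.
    return torch.zeros(batch_size, stack_depth, stack_width, device=device)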
Example #8
 def test_reward_rnn(self):
     x, y = gen_data.random_training_set(batch_size=bz)
     d_model = 8
     hidden_size = 16
     num_layers = 2
     encoder = Encoder(gen_data.n_characters,
                       d_model,
                       gen_data.char2idx[gen_data.pad_symbol],
                       return_tuple=False)
     x = encoder([x])
     rnn = RewardNetRNN(d_model,
                        hidden_size,
                        num_layers,
                        bidirectional=True,
                        unit_type='lstm')
     r = rnn(x)
     print(f'reward: {r}')
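As a point of reference, a reward head of this kind can be built from a bidirectional GRU whose pooled final states feed a scalar output. The sketch below is a generic stand-in, not the project's RewardNetRNN (which additionally supports attention and a SMILES-validity flag):

import torch
import torch.nn as nn

class SimpleRewardRNN(nn.Module):
    """Map a (seq_len, batch, d_model) sequence of embeddings to one scalar per sequence."""
    def __init__(self, d_model, hidden_size, num_layers=1):
        super().__init__()
        self.rnn = nn.GRU(d_model, hidden_size, num_layers, bidirectional=True)
        self.out = nn.Linear(2 * hidden_size, 1)

    def forward(self, x):
        _, h_n = self.rnn(x)                       # h_n: (num_layers * 2, batch, hidden)
        last_fwd, last_bwd = h_n[-2], h_n[-1]      # final layer, both directions
        return self.out(torch.cat([last_fwd, last_bwd], dim=-1))   # (batch, 1)

r_net = SimpleRewardRNN(d_model=8, hidden_size=16, num_layers=2)
print(r_net(torch.randn(25, 4, 8)).shape)          # torch.Size([4, 1])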
Example #9
 def test_stack_decoder_layer(self):
     x, y = gen_data.random_training_set(batch_size=bz)
     d_model = 128
     d_hidden = 10
     s_width = 16
     s_depth = 20
     encoder = Encoder(gen_data.n_characters, 128,
                       gen_data.char2idx[gen_data.pad_symbol])
     x = encoder(x)
     pe = PositionalEncoding(d_model, dropout=.2, max_len=500)
     x = pe(x)
     h0 = init_hidden_2d(x.shape[1], x.shape[0], d_hidden)
     s0 = init_stack_2d(x.shape[1], x.shape[0], s_depth, s_width)
     stack_decoder = StackDecoderLayer(d_model=d_model,
                                       num_heads=1,
                                       stack_depth=s_depth,
                                       stack_width=s_width,
                                       dropout=.1)
     out = stack_decoder((x, s0))
     assert (len(out) == 3)
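Several of these examples pass encoder.k_padding_mask into the stack layers so that attention and recurrence can ignore padded positions. A small sketch of the masking idea with stock PyTorch attention; the shapes and names are illustrative and this is not the project's StackDecoderLayer:

import torch
import torch.nn as nn

pad_idx = 0
tokens = torch.tensor([[5, 7, 2, 0, 0],
                       [3, 1, 4, 6, 0]])              # (batch, seq_len), 0 = padding
key_padding_mask = tokens.eq(pad_idx)                 # True where attention should ignore keys

x = torch.randn(5, 2, 128)                            # (seq_len, batch, d_model) embeddings
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)  # block attention to future positions

attn = nn.MultiheadAttention(embed_dim=128, num_heads=1)
out, _ = attn(x, x, x, key_padding_mask=key_padding_mask, attn_mask=causal)
print(out.shape)                                      # torch.Size([5, 2, 128])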
Example #10
 def test_stack_rnn(self):
     x, y = gen_data.random_training_set(batch_size=bz)
     d_model = 12
     hidden_size = 16
     stack_width = 10
     stack_depth = 20
     unit_type = 'lstm'
     num_layers = 2
     hidden_states = [
         get_initial_states(bz, hidden_size, 1, stack_depth, stack_width,
                            unit_type) for _ in range(num_layers)
     ]
     encoder = Encoder(gen_data.n_characters, d_model,
                       gen_data.char2idx[gen_data.pad_symbol])
     x = encoder(x)
     stack_rnn_1 = StackRNN(1,
                            d_model,
                            hidden_size,
                            True,
                            'gru',
                            stack_width,
                            stack_depth,
                            k_mask_func=encoder.k_padding_mask)
     stack_rnn_2 = StackRNN(2,
                            hidden_size,
                            hidden_size,
                            True,
                            'gru',
                            stack_width,
                            stack_depth,
                            k_mask_func=encoder.k_padding_mask)
     outputs = stack_rnn_1([x] + hidden_states)
     outputs = stack_rnn_2(outputs)
     assert len(outputs) > 1
     linear = RNNLinearOut(
         4,
         hidden_size,
         bidirectional=False,
     )
     x = linear(outputs)
     print(x[0].shape)
Example #11
    def initialize(hparams, demo_data_gen, unbiased_data_gen, prior_data_gen,
                   *args, **kwargs):
        # Embeddings provider
        encoder = Encoder(
            vocab_size=demo_data_gen.n_characters,
            d_model=hparams['d_model'],
            padding_idx=demo_data_gen.char2idx[demo_data_gen.pad_symbol],
            dropout=hparams['dropout'],
            return_tuple=True)

        # Agent entities
        rnn_layers = []
        has_stack = True
        for i in range(1, hparams['agent_params']['num_layers'] + 1):
            rnn_layers.append(
                StackRNN(layer_index=i,
                         input_size=hparams['d_model'],
                         hidden_size=hparams['d_model'],
                         has_stack=has_stack,
                         unit_type=hparams['agent_params']['unit_type'],
                         stack_width=hparams['agent_params']['stack_width'],
                         stack_depth=hparams['agent_params']['stack_depth'],
                         k_mask_func=encoder.k_padding_mask))
            if hparams['agent_params']['num_layers'] > 1:
                rnn_layers.append(StackedRNNDropout(hparams['dropout']))
                rnn_layers.append(StackedRNNLayerNorm(hparams['d_model']))
        agent_net = nn.Sequential(
            encoder, *rnn_layers,
            RNNLinearOut(out_dim=demo_data_gen.n_characters,
                         hidden_size=hparams['d_model'],
                         bidirectional=False,
                         bias=True))
        agent_net = agent_net.to(device)
        optimizer_agent_net = parse_optimizer(hparams['agent_params'],
                                              agent_net)
        selector = MolEnvProbabilityActionSelector(
            actions=demo_data_gen.all_characters)
        probs_reg = StateActionProbRegistry()
        init_state_args = {
            'num_layers': hparams['agent_params']['num_layers'],
            'hidden_size': hparams['d_model'],
            'stack_depth': hparams['agent_params']['stack_depth'],
            'stack_width': hparams['agent_params']['stack_width'],
            'unit_type': hparams['agent_params']['unit_type']
        }
        agent = PolicyAgent(model=agent_net,
                            action_selector=selector,
                            states_preprocessor=seq2tensor,
                            initial_state=agent_net_hidden_states_func,
                            initial_state_args=init_state_args,
                            apply_softmax=True,
                            probs_registry=probs_reg,
                            device=device)
        drl_alg = REINFORCE(model=agent_net,
                            optimizer=optimizer_agent_net,
                            initial_states_func=agent_net_hidden_states_func,
                            initial_states_args=init_state_args,
                            prior_data_gen=prior_data_gen,
                            device=device,
                            xent_lambda=hparams['xent_lambda'],
                            gamma=hparams['gamma'],
                            grad_clipping=hparams['reinforce_max_norm'],
                            lr_decay_gamma=hparams['lr_decay_gamma'],
                            lr_decay_step=hparams['lr_decay_step_size'],
                            delayed_reward=not hparams['use_monte_carlo_sim'])

        # Reward function entities
        reward_net = nn.Sequential(
            encoder,
            RewardNetRNN(
                input_size=hparams['d_model'],
                hidden_size=hparams['reward_params']['d_model'],
                num_layers=hparams['reward_params']['num_layers'],
                bidirectional=hparams['reward_params']['bidirectional'],
                use_attention=hparams['reward_params']['use_attention'],
                dropout=hparams['dropout'],
                unit_type=hparams['reward_params']['unit_type'],
                use_smiles_validity_flag=hparams['reward_params']
                ['use_validity_flag']))
        reward_net = reward_net.to(device)

        expert_model = XGBPredictor(hparams['expert_model_dir'])
        true_reward_func = get_jak2_max_reward if hparams[
            'bias_mode'] == 'max' else get_jak2_min_reward
        reward_function = RewardFunction(
            reward_net,
            mc_policy=agent,
            actions=demo_data_gen.all_characters,
            device=device,
            use_mc=hparams['use_monte_carlo_sim'],
            mc_max_sims=hparams['monte_carlo_N'],
            expert_func=expert_model,
            no_mc_fill_val=hparams['no_mc_fill_val'],
            true_reward_func=true_reward_func,
            use_true_reward=hparams['use_true_reward'])
        optimizer_reward_net = parse_optimizer(hparams['reward_params'],
                                               reward_net)
        demo_data_gen.set_batch_size(
            hparams['reward_params']['demo_batch_size'])
        irl_alg = GuidedRewardLearningIRL(
            reward_net,
            optimizer_reward_net,
            demo_data_gen,
            k=hparams['reward_params']['irl_alg_num_iter'],
            agent_net=agent_net,
            agent_net_init_func=agent_net_hidden_states_func,
            agent_net_init_func_args=init_state_args,
            device=device)

        init_args = {
            'agent': agent,
            'probs_reg': probs_reg,
            'drl_alg': drl_alg,
            'irl_alg': irl_alg,
            'reward_func': reward_function,
            'gamma': hparams['gamma'],
            'episodes_to_train': hparams['episodes_to_train'],
            'expert_model': expert_model,
            'demo_data_gen': demo_data_gen,
            'unbiased_data_gen': unbiased_data_gen,
            'gen_args': {
                'num_layers': hparams['agent_params']['num_layers'],
                'hidden_size': hparams['d_model'],
                'num_dir': 1,
                'stack_depth': hparams['agent_params']['stack_depth'],
                'stack_width': hparams['agent_params']['stack_width'],
                'has_stack': has_stack,
                'has_cell': hparams['agent_params']['unit_type'] == 'lstm',
                'device': device
            }
        }
        return init_args
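For orientation, the core REINFORCE update driven by drl_alg above amounts to maximizing reward-weighted log-probabilities of the chosen actions. A minimal, generic sketch of that objective (the actual class also handles cross-entropy regularisation, gradient clipping and learning-rate decay, per the arguments above):

import torch

def reinforce_loss(logits, actions, returns):
    """logits: (T, n_actions) per-step action scores, actions: (T,) chosen indices,
    returns: (T,) discounted returns. Standard REINFORCE objective."""
    log_probs = torch.log_softmax(logits, dim=-1)
    chosen = log_probs[torch.arange(len(actions)), actions]
    return -(returns * chosen).mean()

# Toy usage:
logits = torch.randn(6, 10, requires_grad=True)
loss = reinforce_loss(logits, torch.randint(0, 10, (6,)), torch.rand(6))
loss.backward()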
Example #12
def initialize(hparams, demo_data_gen, unbiased_data_gen, has_critic):
    # Embeddings provider
    encoder = Encoder(vocab_size=demo_data_gen.n_characters, d_model=hparams['d_model'],
                      padding_idx=demo_data_gen.char2idx[demo_data_gen.pad_symbol],
                      dropout=hparams['dropout'], return_tuple=True).eval()

    # Agent entities
    rnn_layers = []
    has_stack = True
    for i in range(1, hparams['agent_params']['num_layers'] + 1):
        rnn_layers.append(StackRNN(layer_index=i,
                                   input_size=hparams['d_model'],
                                   hidden_size=hparams['d_model'],
                                   has_stack=has_stack,
                                   unit_type=hparams['agent_params']['unit_type'],
                                   stack_width=hparams['agent_params']['stack_width'],
                                   stack_depth=hparams['agent_params']['stack_depth'],
                                   k_mask_func=encoder.k_padding_mask))
        if hparams['agent_params']['num_layers'] > 1:
            rnn_layers.append(StackedRNNDropout(hparams['dropout']))
            rnn_layers.append(StackedRNNLayerNorm(hparams['d_model']))
    agent_net = nn.Sequential(encoder,
                              *rnn_layers,
                              RNNLinearOut(out_dim=demo_data_gen.n_characters,
                                           hidden_size=hparams['d_model'],
                                           bidirectional=False,
                                           bias=True))
    agent_net = agent_net.to(device).eval()
    init_state_args = {'num_layers': hparams['agent_params']['num_layers'],
                       'hidden_size': hparams['d_model'],
                       'stack_depth': hparams['agent_params']['stack_depth'],
                       'stack_width': hparams['agent_params']['stack_width'],
                       'unit_type': hparams['agent_params']['unit_type']}
    if has_critic:
        critic = nn.Sequential(encoder,
                               CriticRNN(hparams['d_model'], hparams['critic_params']['d_model'],
                                         unit_type=hparams['critic_params']['unit_type'],
                                         dropout=hparams['critic_params']['dropout'],
                                         num_layers=hparams['critic_params']['num_layers']))
        critic = critic.to(device).eval()
    else:
        critic = None

    # Reward function entities
    reward_net_rnn = RewardNetRNN(input_size=hparams['d_model'], hidden_size=hparams['reward_params']['d_model'],
                                  num_layers=hparams['reward_params']['num_layers'],
                                  bidirectional=hparams['reward_params']['bidirectional'],
                                  use_attention=hparams['reward_params']['use_attention'],
                                  dropout=hparams['reward_params']['dropout'],
                                  unit_type=hparams['reward_params']['unit_type'],
                                  use_smiles_validity_flag=hparams['reward_params']['use_validity_flag'])
    reward_net = nn.Sequential(encoder,
                               reward_net_rnn)
    reward_net = reward_net.to(device)
    # expert_model = RNNPredictor(hparams['expert_model_params'], device)
    demo_data_gen.set_batch_size(hparams['reward_params']['demo_batch_size'])

    init_args = {'agent_net': agent_net,
                 'critic_net': critic,
                 'reward_net': reward_net,
                 'reward_net_rnn': reward_net_rnn,
                 'encoder': encoder.eval(),
                 'gamma': hparams['gamma'],
                 # 'expert_model': expert_model,
                 'demo_data_gen': demo_data_gen,
                 'unbiased_data_gen': unbiased_data_gen,
                 'init_hidden_states_args': init_state_args,
                 'gen_args': {'num_layers': hparams['agent_params']['num_layers'],
                              'hidden_size': hparams['d_model'],
                              'num_dir': 1,
                              'stack_depth': hparams['agent_params']['stack_depth'],
                              'stack_width': hparams['agent_params']['stack_width'],
                              'has_stack': has_stack,
                              'has_cell': hparams['agent_params']['unit_type'] == 'lstm',
                              'device': device}}
    return init_args
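The optional critic here acts as a learned baseline for the policy gradient. A generic sketch of that idea, assuming a value head that scores each sequence so the advantage (return minus value) replaces the raw return; this is not the project's CriticRNN:

import torch
import torch.nn as nn

class SimpleCritic(nn.Module):
    """GRU value head: (seq_len, batch, d_model) embeddings -> one value per sequence."""
    def __init__(self, d_model, hidden_size):
        super().__init__()
        self.rnn = nn.GRU(d_model, hidden_size)
        self.v = nn.Linear(hidden_size, 1)

    def forward(self, x):
        _, h_n = self.rnn(x)
        return self.v(h_n[-1]).squeeze(-1)     # (batch,)

critic = SimpleCritic(d_model=8, hidden_size=16)
values = critic(torch.randn(25, 4, 8))         # baseline estimates
returns = torch.rand(4)                        # discounted returns for the same sequences
advantage = returns - values.detach()          # what the actor would be trained on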
Example #13
    def test_policy_net(self):
        d_model = 8
        hidden_size = 16
        num_layers = 1
        stack_width = 10
        stack_depth = 20
        unit_type = 'lstm'

        # Create a function to provide initial hidden states
        def hidden_states_func(batch_size=1):
            return [
                get_initial_states(batch_size, hidden_size, 1, stack_depth,
                                   stack_width, unit_type)
                for _ in range(num_layers)
            ]

        # Encoder to map character indices to embeddings
        encoder = Encoder(gen_data.n_characters,
                          d_model,
                          gen_data.char2idx[gen_data.pad_symbol],
                          return_tuple=True)

        # Create agent network
        stack_rnn = StackRNN(1,
                             d_model,
                             hidden_size,
                             True,
                             'lstm',
                             stack_width,
                             stack_depth,
                             k_mask_func=encoder.k_padding_mask)
        stack_linear = RNNLinearOut(gen_data.n_characters,
                                    hidden_size,
                                    bidirectional=False)
        agent_net = torch.nn.Sequential(encoder, stack_rnn, stack_linear)

        # Create agent
        selector = MolEnvProbabilityActionSelector(
            actions=gen_data.all_characters)
        probs_reg = StateActionProbRegistry()
        agent = PolicyAgent(model=agent_net,
                            action_selector=selector,
                            states_preprocessor=seq2tensor,
                            initial_state=hidden_states_func,
                            apply_softmax=True,
                            probs_registry=probs_reg,
                            device='cpu')

        # Reward function model
        rnn = RewardNetRNN(d_model,
                           hidden_size,
                           num_layers,
                           bidirectional=True,
                           unit_type='gru')
        reward_net = torch.nn.Sequential(encoder, rnn)
        reward_function = RewardFunction(reward_net=reward_net,
                                         mc_policy=agent,
                                         actions=gen_data.all_characters)

        # Create molecule generation environment
        env = MoleculeEnv(gen_data.all_characters, reward_function)

        # Ptan ops for aggregating experiences
        exp_source = ExperienceSourceFirstLast(env, agent, gamma=0.97)

        rl_alg = REINFORCE(agent_net, torch.optim.Adam(agent_net.parameters()),
                           hidden_states_func)
        gen_data.set_batch_size(1)
        irl_alg = GuidedRewardLearningIRL(reward_net,
                                          torch.optim.Adam(
                                              reward_net.parameters()),
                                          demo_gen_data=gen_data)

        # Begin simulation and training
        batch_states, batch_actions, batch_qvals = [], [], []
        traj_prob = 1.
        for step_idx, exp in enumerate(exp_source):
            batch_states.append(exp.state)
            batch_actions.append(exp.action)
            batch_qvals.append(exp.reward)
            traj_prob *= probs_reg.get(list(exp.state), exp.action)

            print(
                f'state = {exp.state}, action = {exp.action}, reward = {exp.reward}, next_state = {exp.last_state}'
            )
            if step_idx == 5:
                break
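Training on the trajectories gathered above typically involves discounting rewards into returns before computing the policy-gradient loss. A standard sketch of that computation, independent of the project's helpers (how much discounting ExperienceSourceFirstLast already applies depends on its configuration):

def discounted_returns(rewards, gamma=0.97):
    """Turn a list of per-step rewards into discounted returns."""
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    return list(reversed(returns))

print(discounted_returns([0.0, 0.0, 1.0]))   # approximately [0.9409, 0.97, 1.0]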
Example #14
    def initialize(hparams, gen_data, *args, **kwargs):
        gen_data.set_batch_size(hparams['batch_size'])

        # Create stack-augmented transformer (Decoder) layer(s)
        encoder = Encoder(vocab_size=gen_data.n_characters,
                          d_model=hparams['d_model'],
                          padding_idx=gen_data.char2idx[gen_data.pad_symbol],
                          dropout=hparams['dropout'],
                          return_tuple=True)
        attn_layers = []
        for i in range(hparams['attn_layers']):
            attn_layers.append(
                StackDecoderLayer(d_model=hparams['d_model'],
                                  num_heads=hparams['attn_heads'],
                                  stack_depth=hparams['stack_depth'],
                                  stack_width=hparams['stack_width'],
                                  d_ff=hparams['d_ff'],
                                  dropout=hparams['dropout'],
                                  k_mask_func=encoder.k_padding_mask,
                                  use_memory=hparams['has_stack']))

        # Create classifier layers (post-attention layers)
        classifier_layers = []
        p = hparams['d_model']
        for dim in hparams['lin_dims']:
            classifier_layers.append(nn.Linear(p, dim))
            classifier_layers.append(nn.LayerNorm(dim))
            classifier_layers.append(nn.ReLU())
            classifier_layers.append(nn.Dropout(hparams['dropout']))
            p = dim
        classifier_layers.append(nn.Linear(p, gen_data.n_characters))
        # classifier_layers.append(LinearOut(encoder.embeddings_weight, p, hparams['d_model'], hparams['dropout']))

        # Create main model
        model = nn.Sequential(
            encoder,
            PositionalEncoding(d_model=hparams['d_model'],
                               dropout=hparams['dropout']),
            # AttentionInitialize(d_hidden=hparams['d_model'],
            #                     s_width=hparams['stack_width'],
            #                     s_depth=hparams['stack_depth'],
            #                     dvc=f'{device}:{dvc_id}'),
            *attn_layers,
            AttentionTerminal(),
            *classifier_layers)
        if use_cuda:
            model = model.cuda()

        optimizer = parse_optimizer(hparams, model)
        # optimizer = get_std_opt(model, hparams['d_model'])
        # optimizer = AttentionOptimizer(model_size=hparams['d_model'],
        #                                factor=2,
        #                                warmup=4000,
        #                                optimizer=parse_optimizer(hparams, model))
        init_args = {
            'stack_width': hparams['stack_width'],
            'stack_depth': hparams['stack_depth'],
            'device': f'{device}:{dvc_id}',
            'has_stack': hparams['has_stack']
        }
        return model, optimizer, gen_data, init_args
Example #15
    def initialize(hparams, data_gens, *args, **kwargs):
        for k in data_gens:
            data_gens[k].set_batch_size(hparams['batch_size'])
        gen_data = data_gens['prior_data']
        # Create main model
        encoder = Encoder(vocab_size=gen_data.n_characters,
                          d_model=hparams['d_model'],
                          padding_idx=gen_data.char2idx[gen_data.pad_symbol],
                          dropout=hparams['dropout'],
                          return_tuple=True)
        # Create RNN layers
        rnn_layers = []
        has_stack = True
        for i in range(1, hparams['num_layers'] + 1):
            rnn_layers.append(
                StackRNN(layer_index=i,
                         input_size=hparams['d_model'],
                         hidden_size=hparams['d_model'],
                         has_stack=has_stack,
                         unit_type=hparams['unit_type'],
                         stack_width=hparams['stack_width'],
                         stack_depth=hparams['stack_depth'],
                         k_mask_func=encoder.k_padding_mask))
            if hparams['num_layers'] > 1:
                rnn_layers.append(StackedRNNDropout(hparams['dropout']))
                rnn_layers.append(StackedRNNLayerNorm(hparams['d_model']))

        model = nn.Sequential(
            encoder,
            *rnn_layers,
            RNNLinearOut(
                out_dim=gen_data.n_characters,
                hidden_size=hparams['d_model'],
                bidirectional=False,
                # encoder=encoder,
                # dropout=hparams['dropout'],
                bias=True))
        if use_cuda:
            model = model.cuda()
        optimizer = parse_optimizer(hparams, model)
        rnn_args = {
            'num_layers': hparams['num_layers'],
            'hidden_size': hparams['d_model'],
            'num_dir': 1,
            'device': device,
            'has_stack': has_stack,
            'has_cell': hparams['unit_type'] == 'lstm',
            'stack_width': hparams['stack_width'],
            'stack_depth': hparams['stack_depth'],
            'demo_data_gen': data_gens['demo_data'],
            'unbiased_data_gen': data_gens['unbiased_data'],
            'prior_data_gen': data_gens['prior_data'],
            'expert_model': {
                'pretraining': DummyPredictor(),
                'drd2': RNNPredictor(hparams['drd2'], device, True),
                'logp': RNNPredictor(hparams['logp'], device),
                'jak2_max': XGBPredictor(hparams['jak2']),
                'jak2_min': XGBPredictor(hparams['jak2'])
            }.get(hparams['exp_type']),
            'exp_type': hparams['exp_type'],
        }
        return model, optimizer, rnn_args