示例#1
0
                    help='force two leg together')
parser.add_argument('--start-epoch', type=int, default=0, help='start-epoch')

# Script entry point: build the actor-critic network, its optimizer and the
# shared helper objects, optionally restoring them from a checkpoint.
# NOTE(review): this excerpt ends after the resume branch; the code that
# consumes these objects is not visible here.
if __name__ == '__main__':
    args = parser.parse_args()
    # One OpenMP thread per process; presumably because several worker
    # processes share the CPU — TODO confirm.
    os.environ['OMP_NUM_THREADS'] = '1'
    # Seed torch's RNG before the network below is constructed.
    torch.manual_seed(args.seed)

    num_inputs = args.feature  # input size, taken from --feature
    num_actions = 18  # hard-coded action count passed to ActorCritic

    # NOTE(review): TrafficLight / Counter are project helpers, presumably
    # used to coordinate worker processes — verify against their definitions.
    traffic_light = TrafficLight()
    counter = Counter()

    ac_net = ActorCritic(num_inputs, num_actions)
    opt_ac = optim.Adam(ac_net.parameters(), lr=args.lr)

    # Project helpers built from the network / input size; their semantics
    # are not visible in this excerpt.
    shared_grad_buffers = Shared_grad_buffers(ac_net)
    shared_obs_stats = Shared_obs_stats(num_inputs)

    if args.resume:
        print("=> loading checkpoint ")
        # NOTE(review): checkpoint path is hard-coded relative to the
        # current working directory.
        checkpoint = torch.load('../../7.87.t7')
        #checkpoint = torch.load('../../best.t7')
        args.start_epoch = checkpoint['epoch']
        #best_prec1 = checkpoint['best_prec1']
        ac_net.load_state_dict(checkpoint['state_dict'])
        opt_ac.load_state_dict(checkpoint['optimizer'])
        # Re-wrap the optimizer state as defaultdict(dict) so parameters with
        # no saved state get an empty dict rather than a KeyError (presumably
        # a workaround for the PyTorch version in use — confirm).
        opt_ac.state = defaultdict(dict, opt_ac.state)
        #print(opt_ac)
        # Replaces the freshly built Shared_obs_stats with the checkpointed one.
        shared_obs_stats = checkpoint['obs']
示例#2
0
parser.add_argument('--use-joint-pol-val',
                    action='store_true',
                    help='whether to use combined policy and value nets')
args = parser.parse_args()

env = gym.make(args.env_name)

# Network sizes come from the environment's spaces. Indexing
# action_space.shape[0] presumably assumes a continuous (Box) action space.
num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]

env.seed(args.seed)
torch.manual_seed(args.seed)

# Either one joint actor-critic network with a single optimizer, or separate
# policy (GRU) and value networks, each with its own Adam optimizer.
if args.use_joint_pol_val:
    ac_net = ActorCritic(num_inputs, num_actions)
    opt_ac = optim.Adam(ac_net.parameters(), lr=0.0003)
else:
    policy_net = GRU(num_inputs, num_actions)
    # NOTE(review): old_policy_net gets no optimizer here; presumably a
    # snapshot of policy_net used elsewhere — confirm.
    old_policy_net = GRU(num_inputs, num_actions)
    value_net = Value(num_inputs)
    opt_policy = optim.Adam(policy_net.parameters(), lr=0.0003)
    opt_value = optim.Adam(value_net.parameters(), lr=0.0003)


def create_batch_inputs(batch_states_list, batch_actions_list,
                        batch_advantages_list, batch_targets_list):
    """Build batched inputs from per-episode lists (truncated excerpt).

    NOTE(review): the rest of this function is not shown here; the visible
    code only computes the longest first-dimension length among the states.
    """
    # First-dimension length of each states tensor (presumably the sequence
    # length — TODO confirm).
    lengths = []
    for states in batch_states_list:
        lengths.append(states.size(0))

    max_length = max(lengths)
示例#3
0
                    help='Imitation learning epochs')
parser.add_argument('--imitation-replay-size',
                    type=int,
                    default=1,
                    metavar='IRS',
                    help='Imitation learning trajectory replay size')
args = parser.parse_args()
torch.manual_seed(args.seed)
# Output directory; created if it does not already exist.
os.makedirs('results', exist_ok=True)

# Set up environment and models
env = CartPoleEnv()
env.seed(args.seed)
# Actor-critic sized from the observation vector length and the number of
# discrete actions (action_space.n).
agent = ActorCritic(env.observation_space.shape[0], env.action_space.n,
                    args.hidden_size)
agent_optimiser = optim.RMSprop(agent.parameters(), lr=args.learning_rate)
if args.imitation:
    # Set up expert trajectories dataset
    expert_trajectories = torch.load('expert_trajectories.pth')
    expert_trajectories = {
        k: torch.cat([trajectory[k] for trajectory in expert_trajectories],
                     dim=0)
        for k in expert_trajectories[0].keys()
    }  # Flatten expert trajectories
    expert_trajectories = TransitionDataset(expert_trajectories)
    # Set up discriminator
    if args.imitation in ['AIRL', 'GAIL']:
        if args.imitation == 'AIRL':
            discriminator = AIRLDiscriminator(env.observation_space.shape[0],
                                              env.action_space.n,
                                              args.hidden_size,
示例#4
0
#num_steps        = 20
#mini_batch_size  = 5
#ppo_epochs       = 4

for c in range(num_classes):

    print("Learning Policy for class:", c)

    envs = [
        make_env(num_features, blackbox_model, c, max_nodes, min_nodes)
        for i in range(num_envs)
    ]
    envs = SubprocVecEnv(envs)

    model = ActorCritic(num_features, embedding_size)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    max_frames = 5000
    frame_idx = 0
    test_rewards = []

    state = envs.reset()

    early_stop = False

    #Save mean rewards per episode
    env_0_mean_rewards = []
    env_0_rewards = []

    while frame_idx < max_frames and not early_stop:
示例#5
0
class ActorCriticAgentUsingICM:
    """Actor-critic agent trained jointly with an Intrinsic Curiosity Module.

    Transitions are stored with :meth:`cache` and consumed by
    :meth:`update_model`, which performs one combined gradient step on the
    actor-critic network and the ICM and then clears the stored transitions.
    """

    def __init__(self, nb_actions, learning_rate, gamma, hidden_size,
                 model_input_size, entropy_coeff_start, entropy_coeff_end,
                 entropy_coeff_anneal, continuous):

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.num_actions = nb_actions

        self.gamma = gamma

        self.continuous = continuous

        self.learning_rate = learning_rate

        # The entropy coefficient is linearly annealed from start to end over
        # `entropy_coeff_anneal` calls to select_action.
        self.entropy_coefficient_start = entropy_coeff_start
        self.entropy_coefficient_end = entropy_coeff_end
        self.entropy_coefficient_anneal = entropy_coeff_anneal

        self.step_no = 0
        if self.continuous:
            self.model = ActorCriticContinuous(hidden_size=hidden_size,
                                               inputs=model_input_size,
                                               outputs=nb_actions).to(
                                                   self.device)
        else:
            self.model = ActorCritic(hidden_size=hidden_size,
                                     inputs=model_input_size,
                                     outputs=nb_actions).to(self.device)

        self.hidden_size = hidden_size
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.learning_rate)

        self.loss_function = torch.nn.MSELoss()

        # Entries: (action_prob, reward, state_value, entropy, state,
        #           next_state, action) — see cache().
        self.memory = []

        self.ICM = ICM(model_input_size, nb_actions)
        self.ICM.train()

    # Get the current entropy coefficient value according to the start/end and annealing values
    def get_entropy_coefficient(self):
        """Return the linearly annealed entropy coefficient for step_no."""
        entropy = self.entropy_coefficient_end
        if self.step_no < self.entropy_coefficient_anneal:
            entropy = self.entropy_coefficient_start - self.step_no * \
                ((self.entropy_coefficient_start - self.entropy_coefficient_end) /
                 self.entropy_coefficient_anneal)
        return entropy

    # select an action with policy
    def select_action(self, state):
        """Run the policy on `state`; return (action distribution, state value).

        A Normal distribution is built for continuous control, a Categorical
        one otherwise. Each call advances the entropy-annealing step counter.
        """
        self.step_no += 1

        if self.continuous:
            action_mean, action_dev, state_value = self.model(state)
            action_dist = Normal(action_mean, action_dev)
        else:
            action_probs, state_value = self.model(state)
            action_dist = Categorical(action_probs)

        return action_dist, state_value

    def _normalized_returns(self):
        """Return standardized discounted returns for the stored transitions.

        Returns are accumulated backwards over ``self.memory`` and created on
        ``self.device`` so they can be compared against the model's value
        estimates without a device mismatch. With a single transition the
        sample standard deviation is undefined (NaN), so only the mean is
        subtracted in that case.
        """
        running = 0.0
        discounted = []
        for (_, reward, _, _, _, _, _) in reversed(self.memory):
            # float() accepts Python numbers as well as one-element tensors.
            running = float(reward) + self.gamma * running
            discounted.append(running)
        discounted.reverse()

        returns = torch.tensor(discounted, device=self.device)
        returns = returns - returns.mean()
        if returns.numel() > 1:
            returns = returns / (returns.std() + 1e-8)
        return returns

    def update_model(self):
        """Do one joint update of the actor-critic and the ICM.

        The total loss is ``ICM.lambda_weight * (policy + value + entropy)``
        plus the ICM loss, a ``beta``-weighted mix of the forward and inverse
        model losses. The stored transitions are cleared afterwards.

        Returns:
            The total loss as a Python float, or 0.0 (without any update)
            when no transitions are stored.
        """
        if not self.memory:
            # torch.stack([]) would raise; there is nothing to learn from.
            return 0.0

        returns = self._normalized_returns()

        policy_losses = []
        forward_losses = []
        inverse_losses = []
        value_losses = []
        entropy_loss = []

        for (action_prob, _, state_value, entropy, state, next_state,
             action), Gt in zip(self.memory, returns):

            # Advantage is a plain float, so no gradient flows through it.
            advantage = Gt.item() - state_value.item()

            # calculate actor (policy) loss
            policy_losses.append((-action_prob * advantage).mean())

            # calculate critic (value) loss using model loss function
            value_losses.append(
                self.loss_function(state_value, Gt.unsqueeze(0)))

            entropy_loss.append(-entropy)

            forward_losses.append(
                self.ICM.get_forward_loss(state, action, next_state))
            inverse_losses.append(
                self.ICM.get_inverse_loss(state, action, next_state))

        # reset gradients
        self.optimizer.zero_grad()
        self.ICM.optimizer.zero_grad()

        icm_loss = (1 - self.ICM.beta) * torch.stack(inverse_losses).mean(
        ) + self.ICM.beta * torch.stack(forward_losses).mean()

        loss = self.ICM.lambda_weight * (
            torch.stack(policy_losses).mean() +
            torch.stack(value_losses).mean() +
            self.get_entropy_coefficient() *
            torch.stack(entropy_loss).mean()) + icm_loss

        loss.backward()

        self.optimizer.step()
        self.ICM.optimizer.step()
        self.memory = []

        return loss.item()

    # save model
    def save(self, path, name):
        """Save the actor-critic weights to <this file's dir>/<path>/<name>.pt."""
        dirname = os.path.dirname(__file__)
        filename = os.path.join(dirname, os.path.join(path, name + ".pt"))
        torch.save(self.model.state_dict(), filename)

    # load a model
    def load(self, path):
        """Load actor-critic weights from <this file's dir>/<path>.

        map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        machine, and vice versa.
        """
        dirname = os.path.dirname(__file__)
        filename = os.path.join(dirname, path)
        self.model.load_state_dict(
            torch.load(filename, map_location=self.device))

    def cache(self, action_prob, reward, state_value, entropy, state,
              next_state, action):
        """Store one transition for the next update_model call."""
        self.memory.append((action_prob, reward, state_value, entropy, state,
                            next_state, action))
示例#6
0
parser.add_argument('--test', action='store_true',
                    help='test ')
parser.add_argument('--feature', type=int, default=96, 
                    help='features num')


# Entry point: build the actor-critic and a SharedAdam optimizer, optionally
# restore the network weights, then move both into shared memory.
# NOTE(review): this excerpt ends here; the code that consumes ac_net /
# opt_ac is not visible.
if __name__ == '__main__':
    args = parser.parse_args()
    # NOTE(review): presumably one OpenMP thread per worker process — confirm.
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.manual_seed(args.seed)

    num_inputs = args.feature
    num_actions = 9  # hard-coded action count passed to ActorCritic

    ac_net = ActorCritic(num_inputs, num_actions)
    opt_ac = my_optim.SharedAdam(ac_net.parameters(), lr=args.lr)

    if args.resume:
        print("=> loading checkpoint ")
        checkpoint = torch.load('../models/kankan/best.t7')
        #args.start_epoch = checkpoint['epoch']
        #best_prec1 = checkpoint['best_prec1']
        # Only the network weights are restored; the optimizer state stays
        # fresh because its load below is commented out.
        ac_net.load_state_dict(checkpoint['state_dict'])
        #opt_ac.load_state_dict(checkpoint['optimizer'])
        print(ac_net)
        print("=> loaded checkpoint  (epoch {})"
                .format(checkpoint['epoch']))

    # Move the network parameters into shared memory; SharedAdam.share_memory
    # is a project helper, presumably doing the same for the optimizer state.
    ac_net.share_memory()
    #opt_ac = my_optim.SharedAdam(ac_net.parameters(), lr=args.lr)
    opt_ac.share_memory()
示例#7
0
# Run-time switches: `write` enables the TensorBoard SummaryWriter and `load`
# restores weights from "weights.pt"; `save` is not read in the visible code.
write = True
save = True
load = False
num_steps = 200  # presumably rollout length per update — confirm
env_num = 128  # passed to SubprocWrapper as its second argument
worker_num = 2  # number of make_env factories handed to SubprocWrapper

mini_batch_size = 64
ppo_epochs = 4
if write:
    writer = SummaryWriter()
model = ActorCritic()
if load:
    model.load_state_dict(torch.load("weights.pt"))
model = model.cuda()
# AdamW with PyTorch's default hyper-parameters.
opt = torch.optim.AdamW(model.parameters())

envs_fns = [make_env for _ in range(worker_num)]
envs = SubprocWrapper(envs_fns, env_num)

step = 0
# Main training loop.
# NOTE(review): this excerpt is truncated; only the per-iteration buffer
# resets are visible, not the rollout or update code.
while True:
    log_probs = []
    values = []
    states = []
    invalids = []
    rewards = []
    actions = []
    terminals = []
    time_steps = []
示例#8
0
def main(args):
    """Launch asynchronous actor-critic training on an Atari environment.

    Builds a shared ``ActorCritic`` model and ``SharedAdam`` optimizer
    (optionally restored from a checkpoint), starts the training worker
    processes plus one test process, and waits for all of them to finish.

    Args:
        args: parsed options. Reads uuid, env_name, model_id, debug, seed,
            lr, load_model, algorithm, num_processes and non_sample; when
            resuming, overwrites model_id and start_step from the checkpoint.
    """
    print(f" Session ID: {args.uuid}")

    # logging
    log_dir = f'logs/{args.env_name}/{args.model_id}/{args.uuid}/'
    args_logger = setup_logger('args', log_dir, 'args.log')
    env_logger = setup_logger('env', log_dir, 'env.log')

    if args.debug:
        debug.packages()
    os.environ['OMP_NUM_THREADS'] = "1"
    if torch.cuda.is_available():
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        devices = ",".join([str(i) for i in range(torch.cuda.device_count())])
        os.environ["CUDA_VISIBLE_DEVICES"] = devices

    args_logger.info(vars(args))
    # vars(os.environ) returns the _Environ object's private attributes
    # (_data, encodekey, ...) rather than the environment itself.
    env_logger.info(dict(os.environ))

    # Seed before the shared model is constructed so its initial weights are
    # reproducible (the seed used to be set only after construction).
    torch.manual_seed(args.seed)

    env = create_atari_environment(args.env_name)

    shared_model = ActorCritic(env.observation_space.shape[0],
                               env.action_space.n)

    if torch.cuda.is_available():
        shared_model = shared_model.cuda()

    optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)

    # Restore the checkpoint BEFORE moving state into shared memory:
    # Optimizer.load_state_dict replaces the optimizer's state tensors with
    # the loaded ones, which would discard sharing done beforehand.
    if args.load_model:
        checkpoint_file = f"{args.env_name}/{args.model_id}_{args.algorithm}_params.tar"
        checkpoint = restore_checkpoint(checkpoint_file)
        assert args.env_name == checkpoint['env'], \
            "Checkpoint is for different environment"
        args.model_id = checkpoint['id']
        args.start_step = checkpoint['step']
        print("Loading model from checkpoint...")
        print(f"Environment: {args.env_name}")
        print(f"      Agent: {args.model_id}")
        print(f"      Start: Step {args.start_step}")
        shared_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    else:
        print(f"Environment: {args.env_name}")
        print(f"      Agent: {args.model_id}")

    shared_model.share_memory()
    optimizer.share_memory()

    print(
        FontColor.BLUE + \
        f"CPUs:    {mp.cpu_count(): 3d} | " + \
        f"GPUs: {None if not torch.cuda.is_available() else torch.cuda.device_count()}" + \
        FontColor.END
    )

    processes = []

    counter = mp.Value('i', 0)
    lock = mp.Lock()

    # Queue training processes. With more than one process requested, one
    # slot is left for the test process started below.
    num_processes = args.num_processes
    no_sample = args.non_sample  # count of non-sampling processes

    if args.num_processes > 1:
        num_processes = args.num_processes - 1

    samplers = num_processes - no_sample

    for rank in range(0, num_processes):
        device = 'cpu'
        if torch.cuda.is_available():
            device = 0  # TODO: Need to move to distributed to handle multigpu
        if rank < samplers:  # random action
            p = mp.Process(
                target=train,
                args=(rank, args, shared_model, counter, lock, optimizer,
                      device),
            )
        else:  # best action
            p = mp.Process(
                target=train,
                args=(rank, args, shared_model, counter, lock, optimizer,
                      device, False),
            )
        p.start()
        time.sleep(1.)
        processes.append(p)

    # Queue test process
    p = mp.Process(target=test,
                   args=(args.num_processes, args, shared_model, counter, 0))

    p.start()
    processes.append(p)

    for p in processes:
        p.join()
示例#9
0
def train(rank,
          args,
          shared_model,
          counter,
          lock,
          optimizer=None,
          device='cpu',
          select_sample=True):
    """Run one asynchronous actor-critic worker that updates ``shared_model``.

    Each iteration copies the shared weights into a local model, rolls out up
    to ``args.num_steps`` environment steps, computes the loss with ``gae``,
    copies local gradients to the shared model via ``ensure_shared_grads``
    and steps ``optimizer``. A checkpoint is saved every
    ``args.save_interval`` iterations. The loop never exits on its own.

    Args:
        rank: worker index (used for logging).
        args: parsed options (env_name, model_id, uuid, lr, start_step,
            save_interval, num_steps, greedy_eps, max_episode_length,
            max_grad_norm, ...).
        shared_model: model whose parameters are shared between workers.
        counter: shared step counter, incremented under ``lock``.
        lock: lock guarding ``counter``.
        optimizer: optimizer over ``shared_model``'s parameters; a new
            ``optim.Adam`` is created when None.
        device: value assigned to ``model.device`` when CUDA is available.
        select_sample: sample actions (optionally epsilon-greedy) when True,
            otherwise always take the highest-probability action.
    """
    # torch.manual_seed(args.seed + rank)

    # logging
    log_dir = f'logs/{args.env_name}/{args.model_id}/{args.uuid}/'
    loss_logger = setup_logger('loss', log_dir, 'loss.log')
    # action_logger = setup_logger('actions', log_dir, f'actions.log')

    text_color = FontColor.RED if select_sample else FontColor.GREEN
    print(
        text_color +
        f"Process: {rank: 3d} | {'Sampling' if select_sample else 'Decision'} | Device: {str(device).upper()}",
        FontColor.END)

    env = create_atari_environment(args.env_name)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n

    # env.seed(args.seed + rank)

    model = ActorCritic(observation_space, action_space)
    if torch.cuda.is_available():
        model = model.cuda()
        model.device = device

    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()

    state = env.reset()
    state = torch.from_numpy(state)
    done = True
    # Steps taken in the current episode. It persists across rollouts and is
    # reset only when an episode ends; resetting it per rollout meant
    # max_episode_length could never trigger when num_steps was smaller.
    episode_length = 0

    for t in count(start=args.start_step):
        if t % args.save_interval == 0 and t > 0:
            save_checkpoint(shared_model, optimizer, args, t)

        # Sync shared model
        model.load_state_dict(shared_model.state_dict())

        # Fresh LSTM state at episode start; otherwise cut the graph so
        # backprop does not extend into the previous rollout.
        if done:
            cx = torch.zeros(1, 512)
            hx = torch.zeros(1, 512)
        else:
            cx = cx.detach()
            hx = hx.detach()

        values = []
        log_probs = []
        rewards = []
        entropies = []

        for _ in range(args.num_steps):
            episode_length += 1

            value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))

            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(-1, keepdim=True)
            entropies.append(entropy)

            reason = ''

            if select_sample:
                rand = random.random()
                epsilon = get_epsilon(t)
                if rand < epsilon and args.greedy_eps:
                    action = torch.randint(0, action_space, (1, 1))
                    reason = 'uniform'

                else:
                    action = prob.multinomial(1)
                    reason = 'multinomial'

            else:
                action = prob.max(-1, keepdim=True)[1]
                reason = 'choice'

            # action_logger.info({
            #     'rank': rank,
            #     'action': action.item(),
            #     'reason': reason,
            #     })

            if torch.cuda.is_available():
                action = action.cuda()
                value = value.cuda()

            log_prob = log_prob.gather(-1, action)

            # action_out = ACTIONS[args.move_set][action.item()]

            state, reward, done, info = env.step(action.item())

            done = done or episode_length >= args.max_episode_length
            reward = max(min(reward, 50), -50)  # h/t @ArvindSoma

            with lock:
                counter.value += 1

            if done:
                episode_length = 0
                state = env.reset()

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        # Bootstrap the return from the critic if the rollout was cut short.
        R = torch.zeros(1, 1)
        if not done:
            value, _, _ = model((state.unsqueeze(0), (hx, cx)))
            R = value.data

        values.append(R)

        loss = gae(R, rewards, values, log_probs, entropies, args)

        loss_logger.info({
            'episode': t,
            'rank': rank,
            'sampling': select_sample,
            'loss': loss.item()
        })

        optimizer.zero_grad()

        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        ensure_shared_grads(model, shared_model)

        optimizer.step()
示例#10
0
def main(dataset,
         pretrain_max_epoch,
         max_epoch,
         learning_rate,
         weight_decay,
         max_pretrain_grad_norm,
         max_grad_norm,
         batch_size,
         embedding_size,
         rnn_input_size,
         rnn_hidden_size,
         hidden_size,
         bottleneck_size,
         entropy_penalty,
         gamma,
         alpha,
         nonlinear_func='tanh',
         value_weight=0.5,
         reward_function='exf1',
         label_order='freq2rare',
         input_dropout_prob=0.2,
         dropout_prob=0.5,
         num_layers=1,
         cv_fold=0,
         seed=None,
         fixed_label_seq_pretrain=False):
    """Train the actor-critic multi-label sequence model in two phases.

    Phase 1 runs ``pretrain_max_epoch`` epochs of supervised training with
    teacher forcing. Labels are in a fixed order, or, when
    ``label_order == 'mblp'``, in an order derived from the model's own
    greedy predictions.

    Phase 2 runs ``opt_config['max_epoch']`` epochs of joint training. The
    loss is ``alpha_e * policy_loss + (1 - alpha_e) * supervised_loss``,
    where ``alpha_e`` is either the constant ``alpha`` or, when
    ``alpha == 'auto'``, a sigmoid schedule over epochs.

    After every epoch the model is evaluated on the sub-train/valid/test
    loaders, metrics go to TensorBoard, and separate checkpoints are saved
    whenever the validation 'example f1 score' or the last 'nDCG_k' value
    improves.
    """

    data_loaders, configs = prepare_exp(dataset,
                                        max_epoch,
                                        learning_rate,
                                        weight_decay,
                                        batch_size,
                                        embedding_size,
                                        rnn_input_size,
                                        rnn_hidden_size,
                                        hidden_size,
                                        bottleneck_size,
                                        nonlinear_func=nonlinear_func,
                                        dropout_prob=dropout_prob,
                                        num_layers=num_layers,
                                        label_order=label_order,
                                        entropy_penalty=entropy_penalty,
                                        value_weight=value_weight,
                                        reward_function=reward_function,
                                        gamma=gamma,
                                        alpha=alpha,
                                        cv_fold=cv_fold,
                                        seed=seed)

    train_loader, sub_train_loader, valid_loader, test_loader = data_loaders
    opt_config, data_config, model_config = configs

    BOS_ID = train_loader.dataset.get_start_label_id()
    EOS_ID = train_loader.dataset.get_stop_label_id()
    is_sparse_data = train_loader.dataset.is_sparse_dataset()

    # Per-token NLL; label id 0 is the padding value and is ignored.
    criterion = nn.NLLLoss(ignore_index=0, reduction='none')
    model = ActorCritic(model_config)
    if device.type == 'cuda':
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(),
                           lr=opt_config['learning_rate'],
                           weight_decay=weight_decay)
    env = Environment(model_config)
    bipartition_eval_functions, ranking_evaluation_functions = load_evaluation_functions(
    )

    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    model_arch_info = 'emb_{}_rnn_{}_hid_{}_bot_{}_inpdp_{}_dp_{}_{}'.format(
        embedding_size, rnn_hidden_size, hidden_size, bottleneck_size,
        input_dropout_prob, dropout_prob, nonlinear_func)
    rl_info = 'alpha_{}_gamma_{}_vw_{}_reward_{}_ent_{}'.format(
        alpha, gamma, value_weight, reward_function, entropy_penalty)
    optim_info = 'lr_{}_decay_{}_norm_{}-{}_bs_{}_epoch_{}-{}_fold_{}'.format(
        learning_rate, weight_decay, max_pretrain_grad_norm, max_grad_norm,
        batch_size, pretrain_max_epoch, max_epoch, cv_fold)

    if fixed_label_seq_pretrain and max_epoch == 0:
        # baseline models
        summary_comment = '_'.join([
            current_time, 'baseline', label_order, model_arch_info, optim_info
        ])
    else:
        summary_comment = '_'.join(
            [current_time, 'proposed', model_arch_info, rl_info, optim_info])

    summary_log_dir = os.path.join('runs', dataset, summary_comment)
    bipartition_model_save_path = os.path.join(
        'models', dataset, summary_comment + '_bipartition.pth')
    ranking_model_save_path = os.path.join('models', dataset,
                                           summary_comment + '_ranking.pth')

    writer = SummaryWriter(log_dir=summary_log_dir)

    num_param_updates = 0
    best_bipartition_valid_score = -np.inf
    best_ranking_valid_score = -np.inf
    max_epoch = opt_config['max_epoch']

    input_dropout = nn.Dropout(p=input_dropout_prob)

    # pretrain or only supervised learning with a fixed label ordering
    for epoch in range(pretrain_max_epoch):

        print('==== {} ===='.format(epoch))
        avg_rewards = []
        for batch_idx, (data, targets) in enumerate(train_loader):

            data, targets = prepare_minibatch(
                data,
                targets,
                train_loader.dataset.get_feature_dim(),
                is_sparse_data,
                drop_EOS=label_order == 'mblp')
            batch_size = len(targets)
            assert data.shape[0] == batch_size, '{}\t{}'.format(
                data.shape[0], batch_size)

            data = input_dropout(data)

            if label_order != 'mblp':
                # Pad label sequences into a (max_length, batch_size) matrix;
                # unused positions stay 0 and are ignored by the criterion.
                target_length = np.array(list(map(len, targets)))
                max_length = int(np.max(target_length))
                targets_ = np.zeros((max_length, batch_size), dtype=np.int64)

                for i in range(batch_size):
                    targets_[:len(targets[i]), i] = targets[i]

                targets = torch.tensor(targets_,
                                       dtype=torch.int64,
                                       device=device,
                                       requires_grad=False)
            else:
                # Roll out the model (eval mode) and order the ground-truth
                # labels according to its predictions.
                max_target_length = np.max(np.array(list(map(len, targets))))
                max_sampling_steps = int(max_target_length * 1.5)

                env.clear_episode_temp_data()
                gen_actions_per_episode = []
                rewards_per_episode = []

                model = model.eval()
                prev_states = model.init_hidden(data, device)
                prev_actions = torch.tensor([BOS_ID] * batch_size,
                                            dtype=torch.int64,
                                            device=device)

                for t in range(
                        max_sampling_steps):  # no infinite loop while learning

                    model_outputs, states = model(data,
                                                  prev_actions,
                                                  prev_states,
                                                  state_value_grad=False)
                    gen_actions, _, done = env.step(model_outputs)

                    gen_actions_per_episode.append(
                        gen_actions.data.cpu().numpy())

                    if done:
                        break

                    prev_actions = gen_actions
                    prev_states = states

                # gen_actions_per_episode: (batch_size, max_trials) # cols can be smaller.
                gen_actions_per_episode = np.array(gen_actions_per_episode).T

                # sort labels according to model predictions
                targets_ = convert_labelset2seq(targets,
                                                gen_actions_per_episode,
                                                EOS_ID)
                targets = torch.tensor(targets_,
                                       dtype=torch.int64,
                                       device=device,
                                       requires_grad=False)

                del gen_actions_per_episode

            model = model.train()
            prev_states = model.init_hidden(data, device)
            prev_actions = torch.tensor([BOS_ID] * batch_size,
                                        dtype=torch.int64,
                                        device=device,
                                        requires_grad=False)
            dropout_masks = create_dropout_mask(
                model_config.dropout_prob, batch_size,
                model_config.embedding_size * 2, model_config.rnn_hidden_size)

            # Teacher forcing: feed the ground-truth label of step t as the
            # previous action for step t + 1.
            losses = []
            for t in range(targets.size(0)):  # no infinite loop while learning
                model_outputs, states = model(data,
                                              prev_actions,
                                              prev_states,
                                              dropout_masks=dropout_masks,
                                              state_value_grad=False)

                logits = model_outputs[0]
                log_probs = F.log_softmax(logits, dim=-1)
                target_t = targets[t]

                losses.append(criterion(log_probs, target_t))

                prev_actions = target_t
                prev_states = states

            # loss: (seq_len, batch_size)
            loss = torch.stack(losses, dim=0)
            loss = torch.sum(loss, dim=0).mean()

            optimizer.zero_grad()
            loss.backward()

            output_str = '{}/Before gradient norm'.format(dataset)
            print_param_norm(model.parameters(), writer, output_str,
                             num_param_updates)

            # torch.nn.utils.clip_grad_value_(model.parameters(), 100)
            if max_pretrain_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_pretrain_grad_norm)

            output_str = '{}/After gradient norm'.format(dataset)
            print_param_norm(model.parameters(), writer, output_str,
                             num_param_updates)

            optimizer.step()

            num_param_updates += 1

        results = evaluation(
            OrderedDict([('sub_train', sub_train_loader),
                         ('valid', valid_loader), ('test', test_loader)]),
            model, env, bipartition_eval_functions,
            ranking_evaluation_functions, model_config.max_trials)

        print_result_summary(results, writer, dataset, epoch)

        # Keep the best-so-far validation scores up to date so a checkpoint
        # is written only on improvement (compare strings with ==, not is).
        for split_name, scores in results.items():
            if split_name == 'valid':
                if scores['example f1 score'] > best_bipartition_valid_score:
                    best_bipartition_valid_score = scores['example f1 score']
                    save_model(epoch, model, optimizer,
                               bipartition_model_save_path)

                if scores['nDCG_k'][-1] > best_ranking_valid_score:
                    best_ranking_valid_score = scores['nDCG_k'][-1]
                    save_model(epoch, model, optimizer,
                               ranking_model_save_path)

    def update_alpha(epoch, xlimit=6, alpha_max=1):
        """Sigmoid schedule rising from ~sigmoid(-xlimit) to ~sigmoid(xlimit)."""
        updated_alpha = 1 / (
            1 + float(np.exp(xlimit - 2 * xlimit / float(max_epoch) * epoch)))
        updated_alpha = min(updated_alpha, alpha_max)
        return updated_alpha

    del optimizer

    # joint learning
    rl_optimizer = optim.Adam(model.parameters(),
                              lr=opt_config['learning_rate'],
                              weight_decay=weight_decay)
    for epoch in range(max_epoch):
        if alpha == 'auto':
            alpha_e = update_alpha(epoch)
        else:
            assert float(alpha) >= 0 and float(alpha) <= 1
            alpha_e = float(alpha)

        print('==== {} ===='.format(epoch + pretrain_max_epoch))
        avg_rewards = []
        for batch_idx, (data, targets) in enumerate(train_loader):

            model = model.train()
            data, targets = prepare_minibatch(
                data, targets, train_loader.dataset.get_feature_dim(),
                is_sparse_data)
            batch_size = len(targets)
            assert data.shape[0] == batch_size, '{}\t{}'.format(
                data.shape[0], batch_size)

            data = input_dropout(data)

            dropout_masks = create_dropout_mask(
                model_config.dropout_prob, batch_size,
                model_config.embedding_size * 2, model_config.rnn_hidden_size)
            prev_states = model.init_hidden(data, device)
            prev_actions = torch.tensor([BOS_ID] * batch_size,
                                        dtype=torch.int64,
                                        device=device,
                                        requires_grad=False)

            max_target_length = np.max(np.array(list(map(len, targets))))
            max_sampling_steps = int(max_target_length * 1.5)

            env.clear_episode_temp_data()
            gen_actions_per_episode = []
            rewards_per_episode = []

            # Sample an episode from the policy; env.step scores the
            # generated labels against the targets.
            for t in range(
                    max_sampling_steps):  # no infinite loop while learning

                model_outputs, states = model(data,
                                              prev_actions,
                                              prev_states,
                                              dropout_masks=dropout_masks)
                gen_actions, rewards, done = env.step(model_outputs, targets)

                gen_actions_per_episode.append(gen_actions.data.cpu().numpy())
                rewards_per_episode.append(rewards)

                if done:
                    break

                prev_actions = gen_actions
                prev_states = states

            num_non_empty = np.array([len(t) > 0 for t in targets]).sum()
            r = np.stack(rewards_per_episode,
                         axis=1).sum(1).sum() / num_non_empty
            avg_rewards.append(r)

            ps_loss, adv_collection = calculate_loss(env, model_config)
            writer.add_scalar('{}/avg_advantages'.format(dataset),
                              adv_collection.mean().data.cpu().numpy(),
                              num_param_updates)

            # gen_actions_per_episode: (batch_size, max_trials) # cols can be smaller.
            gen_actions_per_episode = np.array(gen_actions_per_episode).T

            # sort labels according to model predictions
            targets_ = convert_labelset2seq(targets, gen_actions_per_episode,
                                            EOS_ID)
            targets = torch.tensor(targets_,
                                   dtype=torch.int64,
                                   device=device,
                                   requires_grad=False)

            del gen_actions_per_episode

            prev_states = model.init_hidden(data, device)
            prev_actions = torch.tensor([BOS_ID] * batch_size,
                                        dtype=torch.int64,
                                        device=device,
                                        requires_grad=False)
            dropout_masks = create_dropout_mask(
                model_config.dropout_prob, batch_size,
                model_config.embedding_size * 2, model_config.rnn_hidden_size)

            losses = []
            for t in range(targets.size(0)):  # no infinite loop while learning
                model_outputs, states = model(data,
                                              prev_actions,
                                              prev_states,
                                              dropout_masks=dropout_masks,
                                              state_value_grad=False)
                logits = model_outputs[0]
                log_probs = F.log_softmax(logits, dim=-1)
                target_t = targets[t]

                losses.append(criterion(log_probs, target_t))

                prev_actions = target_t
                prev_states = states

            # loss: (seq_len, batch_size)
            sup_loss = torch.stack(losses, dim=0).sum(0)

            loss = alpha_e * ps_loss + (1 - alpha_e) * sup_loss
            loss = loss.mean()

            rl_optimizer.zero_grad()
            loss.backward()

            output_str = '{}/Before gradient norm'.format(dataset)
            print_param_norm(model.parameters(), writer, output_str,
                             num_param_updates)

            # Same guard as in pretraining: a non-positive norm would zero
            # (or flip) every gradient instead of disabling clipping.
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)

            output_str = '{}/After gradient norm'.format(dataset)
            print_param_norm(model.parameters(), writer, output_str,
                             num_param_updates)

            rl_optimizer.step()

            num_param_updates += 1

        results = evaluation(
            OrderedDict([('sub_train', sub_train_loader),
                         ('valid', valid_loader), ('test', test_loader)]),
            model, env, bipartition_eval_functions,
            ranking_evaluation_functions, model_config.max_trials)

        print_result_summary(results, writer, dataset,
                             epoch + pretrain_max_epoch)

        for split_name, scores in results.items():
            if split_name == 'valid':
                if scores['example f1 score'] > best_bipartition_valid_score:
                    best_bipartition_valid_score = scores['example f1 score']
                    save_model(epoch + pretrain_max_epoch, model, rl_optimizer,
                               bipartition_model_save_path)

                if scores['nDCG_k'][-1] > best_ranking_valid_score:
                    best_ranking_valid_score = scores['nDCG_k'][-1]
                    save_model(epoch + pretrain_max_epoch, model, rl_optimizer,
                               ranking_model_save_path)

    writer.close()