Example #1
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
    X = target_vars['X']
    Y = target_vars['Y']
    X_NOISE = target_vars['X_NOISE']
    train_op = target_vars['train_op']
    energy_pos = target_vars['energy_pos']
    energy_neg = target_vars['energy_neg']
    loss_energy = target_vars['loss_energy']
    loss_ml = target_vars['loss_ml']
    loss_total = target_vars['total_loss']
    gvs = target_vars['gvs']
    x_grad = target_vars['x_grad']
    x_grad_first = target_vars['x_grad_first']
    x_off = target_vars['x_off']
    temp = target_vars['temp']
    x_mod = target_vars['x_mod']
    LABEL = target_vars['LABEL']
    LABEL_POS = target_vars['LABEL_POS']
    weights = target_vars['weights']
    test_x_mod = target_vars['test_x_mod']
    eps = target_vars['eps_begin']
    label_ent = target_vars['label_ent']

    if FLAGS.use_attention:
        gamma = weights[0]['atten']['gamma']
    else:
        gamma = tf.zeros(1)

    val_output = [test_x_mod]

    gvs_dict = dict(gvs)

    log_output = [
        train_op,
        energy_pos,
        energy_neg,
        eps,
        loss_energy,
        loss_ml,
        loss_total,
        x_grad,
        x_off,
        x_mod,
        gamma,
        x_grad_first,
        label_ent,
        *gvs_dict.keys()]
    output = [train_op, x_mod]

    replay_buffer = ReplayBuffer(10000)
    itr = resume_iter
    x_mod = None
    gd_steps = 1

    dataloader_iterator = iter(dataloader)
    best_inception = 0.0

    for epoch in range(FLAGS.epoch_num):
        for data_corrupt, data, label in dataloader:
            data_corrupt = data_corrupt.numpy()
            data_corrupt_init = data_corrupt.copy()

            data = data.numpy()
            label = label.numpy()

            label_init = label.copy()

            if FLAGS.mixup:
                idx = np.random.permutation(data.shape[0])
                lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
                data = data * lam + data[idx] * (1 - lam)

            if FLAGS.replay_batch and (x_mod is not None):
                replay_buffer.add(compress_x_mod(x_mod))

                if len(replay_buffer) > FLAGS.batch_size:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (
                        np.random.uniform(
                            0,
                            FLAGS.rescale,
                            FLAGS.batch_size) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

            if FLAGS.pcd:
                if x_mod is not None:
                    data_corrupt = x_mod

            feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}

            if FLAGS.cclass:
                feed_dict[LABEL] = label
                feed_dict[LABEL_POS] = label_init

            if itr % FLAGS.log_interval == 0:
                _, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, * \
                    grads = sess.run(log_output, feed_dict)

                kvs = {}
                kvs['e_pos'] = e_pos.mean()
                kvs['e_pos_std'] = e_pos.std()
                kvs['e_neg'] = e_neg.mean()
                kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                kvs['e_neg_std'] = e_neg.std()
                kvs['temp'] = temp
                kvs['loss_e'] = loss_e.mean()
                kvs['eps'] = eps.mean()
                kvs['label_ent'] = label_ent
                kvs['loss_ml'] = loss_ml.mean()
                kvs['loss_total'] = loss_total.mean()
                kvs['x_grad'] = np.abs(x_grad).mean()
                kvs['x_grad_first'] = np.abs(x_grad_first).mean()
                kvs['x_off'] = x_off.mean()
                kvs['iter'] = itr
                kvs['gamma'] = gamma

                for v, k in zip(grads, [v.name for v in gvs_dict.values()]):
                    kvs[k] = np.abs(v).max()

                string = "Obtained a total of "
                for key, value in kvs.items():
                    string += "{}: {}, ".format(key, value)

                if hvd.rank() == 0:
                    print(string)
                    logger.writekvs(kvs)
            else:
                _, x_mod = sess.run(output, feed_dict)

            if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
                saver.save(
                    sess,
                    osp.join(
                        FLAGS.logdir,
                        FLAGS.exp,
                        'model_{}'.format(itr)))

            if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d':
                try_im = x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data)

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i

                    log_image(
                        new_im, logger, 'train_gen_{}'.format(itr), step=i)

                test_im = x_mod

                try:
                    data_corrupt, data, label = next(dataloader_iterator)
                except StopIteration:
                    dataloader_iterator = iter(dataloader)
                    data_corrupt, data, label = next(dataloader_iterator)

                data_corrupt = data_corrupt.numpy()

                if FLAGS.replay_batch and (
                        x_mod is not None) and len(replay_buffer) > 0:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (
                        np.random.uniform(
                            0, 1, (FLAGS.batch_size)) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.dataset == 'cifar10' or FLAGS.dataset == 'imagenet' or FLAGS.dataset == 'imagenetfull':
                    n = 128

                    if FLAGS.dataset == "imagenetfull":
                        n = 32

                    if len(replay_buffer) > n:
                        data_corrupt = decompress_x_mod(replay_buffer.sample(n))
                    elif FLAGS.dataset == 'imagenetfull':
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 128, 128, 3))
                    else:
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 32, 32, 3))

                    if FLAGS.dataset == 'cifar10':
                        label = np.eye(10)[np.random.randint(0, 10, (n))]
                    else:
                        label = np.eye(1000)[
                            np.random.randint(
                                0, 1000, (n))]

                feed_dict[X_NOISE] = data_corrupt

                feed_dict[X] = data

                if FLAGS.cclass:
                    feed_dict[LABEL] = label

                test_x_mod = sess.run(val_output, feed_dict)

                try_im = test_x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data.numpy())

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):

                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i
                    log_image(
                        new_im, logger, 'val_gen_{}'.format(itr), step=i)

                score, std = get_inception_score(list(try_im), splits=1)
                print(
                    "Inception score of {} with std of {}".format(
                        score, std))
                kvs = {}
                kvs['inception_score'] = score
                kvs['inception_score_std'] = std
                logger.writekvs(kvs)

                if score > best_inception:
                    best_inception = score
                    saver.save(
                        sess,
                        osp.join(
                            FLAGS.logdir,
                            FLAGS.exp,
                            'model_best'))

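            # Hard stop: MNIST runs are terminated after 60k iterations via the assert below.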
            if itr > 60000 and FLAGS.dataset == "mnist":
                assert False
            itr += 1

    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))
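
Examples 1 and 5 call several helpers (`ReplayBuffer`, `compress_x_mod`, `decompress_x_mod`) that are defined elsewhere in the training code. The sketch below is only an illustration of what they are assumed to do, under the assumption that samples live roughly in [0, 1] and that persistent negatives are stored as uint8; it is not the original implementation.

import numpy as np

class ReplayBuffer(object):
    # Minimal fixed-size buffer of past negative samples (illustrative sketch).

    def __init__(self, size):
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

    def __len__(self):
        return len(self._storage)

    def add(self, ims):
        # Store each image individually, overwriting the oldest entries first.
        for im in ims:
            if self._next_idx >= len(self._storage):
                self._storage.append(im)
            else:
                self._storage[self._next_idx] = im
            self._next_idx = (self._next_idx + 1) % self._maxsize

    def sample(self, batch_size):
        idxs = np.random.randint(0, len(self._storage), batch_size)
        return np.stack([self._storage[i] for i in idxs], axis=0)

def compress_x_mod(x_mod):
    # Quantize floats in [0, 1] to uint8 so the 10k-entry buffer stays small in memory.
    return (255 * np.clip(x_mod, 0., 1.)).astype(np.uint8)

def decompress_x_mod(x_mod):
    # Map back to floats and add a little dequantization noise.
    return x_mod / 256. + np.random.uniform(0, 1. / 256., x_mod.shape)
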
Example #2
def main(config):
    tf.compat.v1.reset_default_graph()
    env_name = config['run']['env']
    env = gym.make(env_name)
    np.random.seed(config['random_seed'])
    tf.compat.v1.set_random_seed(config['random_seed'])
    env.seed(config['random_seed'])

    batch_size = config['train']['batch_size']
    state_dim = env.observation_space.shape

    # Use action_dim[0]: (a_dim,) --> a_dim
    action_dim = env.action_space.shape[0]

    # Define action boundaries for continuous but bounded action space
    action_low = env.action_space.low
    action_high = env.action_space.high

    print(f'-------- {env_name} --------')
    print('STATE DIM: ', state_dim)
    print('ACTION DIM: ', action_dim)
    print('ACTION LOW: ', action_low)
    print('ACTION HIGH: ', action_high)
    print('----------------------------')
    
    # Initialize memory for experience replay
    replay_buffer = ReplayBuffer(config['train']['replay_buffer_size'], config['random_seed'])
        
    # Set up summary TF operations
    summary_ops, summary_vars = build_summaries()

    with tf.compat.v1.Session() as sess:
        # sess.run(tf.compat.v1.global_variables_initializer())
        writer = tf.compat.v1.summary.FileWriter(config['output']['summary_dir'], sess.graph)

        # Use agent_factory to build the agent using the algorithm specified in the config file.
        Agent = agent_factory(config['agent']['model'])
        agent = Agent(config, state_dim, action_dim, action_low, action_high, sess)

        sess.run(tf.compat.v1.global_variables_initializer())

        for i in range(int(config['train']['max_episodes'])):
            s = env.reset()
            episode_reward = 0
            episode_average_max_q = 0

            for j in range(int(config['train']['max_episode_len'])):
                if config['run']['render_env']:
                    env.render()

                # 1. Predict an action to take
                a = agent.actor.predict_action(np.expand_dims(s, 0))

                # 2. Use action to take step in environment and receive next step, reward, etc.
                s2, r, terminal, info = env.step(a[0])

                # 3. Update the replay buffer with the most recent experience
                replay_buffer.add(np.reshape(s, state_dim), np.reshape(a, action_dim), r,
                                  np.reshape(s2, state_dim), terminal)

                # 4. When there are enough experiences in the replay buffer, sample minibatches of training experiences
                if replay_buffer.size() > batch_size:
                    experience = replay_buffer.sample_batch(batch_size)

                    # Train current behavioural networks
                    predicted_Q_value = agent.train_networks(experience)

                    # Update for logging
                    episode_average_max_q += np.amax(predicted_Q_value)

                    # Update target networks
                    agent.update_target_networks()

                # Update information for next step
                s = s2
                episode_reward += r

                if terminal:
                    summary_str = sess.run(summary_ops, feed_dict={
                        summary_vars[0]: episode_reward,
                        summary_vars[1]: episode_average_max_q / float(j)
                        })

                    writer.add_summary(summary_str, i)
                    writer.flush()

                    print('| Reward: {:d} | Episode: {:d} | Qmax: {:.4f}'.format(int(episode_reward), i, (episode_average_max_q / float(j))))
                    
                    break

    if config['run']['use_gym_monitor']:
        env.monitor.close()
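
This example and Example 4 both depend on a `build_summaries()` helper that is not shown. A plausible minimal version for TF1-style graph mode (an assumption, not the project's actual code; it requires eager execution to be disabled, as the surrounding `tf.compat.v1.Session` usage already implies) creates two scalar summaries fed at the end of each episode:

import tensorflow as tf

def build_summaries():
    # Two scalars logged per episode: total reward and average max Q.
    episode_reward = tf.compat.v1.placeholder(tf.float32, shape=())
    tf.compat.v1.summary.scalar("Reward", episode_reward)
    episode_ave_max_q = tf.compat.v1.placeholder(tf.float32, shape=())
    tf.compat.v1.summary.scalar("Qmax", episode_ave_max_q)

    summary_vars = [episode_reward, episode_ave_max_q]
    summary_ops = tf.compat.v1.summary.merge_all()
    return summary_ops, summary_vars
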
Example #3
def training_loop(hyperparameters):
    print(f"Starting training with hyperparameters: {hyperparameters}")
    save_path = hyperparameters["save_path"]
    load_path = hyperparameters["load_path"]

    # create the save path and save hyperparameter configuration
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    else:
        a = input("Warning: directory already exists. Do you want to continue? [y/N] ")
        if a not in ["Y","y"]:
            raise Exception("Path already exists, please start with another path.")

    with open(save_path+ "/parameters.json", "w") as f:
        json.dump(hyperparameters, f)

    # general configurations
    state_dim=18
    action_dim=4
    max_action=1
    iterations=hyperparameters["max_iterations"]
    batch_size=hyperparameters["batch_size"]
    max_episodes=hyperparameters["max_episodes"]
    train_mode = hyperparameters["train_mode"]
    closeness_factor=hyperparameters["closeness_factor"]
    c = closeness_factor

    # init the agent
    agent1 = TD3Agent([state_dim + action_dim, 256, 256, 1],
                        [state_dim, 256, 256, action_dim],
                        optimizer=hyperparameters["optimizer"],
                        policy_noise=hyperparameters["policy_noise"],
                        policy_noise_clip=hyperparameters["policy_noise_clip"],
                        gamma=hyperparameters["gamma"],
                        delay=hyperparameters["delay"],
                        tau=hyperparameters["tau"],
                        lr=hyperparameters["lr"],
                        max_action=max_action,
                        weight_decay=hyperparameters["weight_decay"])

    # load the agent if given
    loaded_state=False
    if load_path:
        agent1.load(load_path)
        loaded_state=True

    # define opponent
    if hyperparameters["self_play"]:
        agent2=agent1
    else:
        agent2 = h_env.BasicOpponent(weak=hyperparameters["weak_agent"])

    # load environment and replay buffer
    replay_buffer = ReplayBuffer(state_dim, action_dim)

    if train_mode == "defense":
        env = h_env.HockeyEnv(mode=h_env.HockeyEnv.TRAIN_DEFENSE)
    elif train_mode == "shooting":
        env = h_env.HockeyEnv(mode=h_env.HockeyEnv.TRAIN_SHOOTING)
    else:
        env = h_env.HockeyEnv()


    # add figure to plot later
    if hyperparameters["plot_performance"]:
        fig, (ax_loss, ax_reward) = plt.subplots(2)
        ax_loss.set_xlim(0, max_episodes)
        ax_loss.set_ylim(0, 20)
        ax_reward.set_xlim(0, max_episodes)
        ax_reward.set_ylim(-30, 20)

    # First, sample enough data to start training:
    with HiddenPrints():
        obs_last = env.reset()
        for i in range(batch_size*100):
            a1 = env.action_space.sample()[:4] if not loaded_state else agent1.act(env.obs_agent_two())
            a2 = agent2.act(env.obs_agent_two())
            obs, r, d, info = env.step(np.hstack([a1,a2]))
            done = 1 if d else 0
            replay_buffer.add(obs_last, a1, obs, r, done)
            obs_last=obs
            if d:
                obs_last = env.reset()

    print("Finished collection of data prior to training")

    # tracking of performance
    episode_critic_loss=[]
    episode_rewards=[]
    win_count=[]
    if not os.path.isfile(save_path + "/performance.csv"):
        pd.DataFrame(data={"Episode_rewards":[], "Episode_critic_loss":[], "Win/Loss":[]}).to_csv(save_path + "/performance.csv", sep=",", index=False)

    # Then start training
    for episode_count in range(max_episodes+1):
        obs_last = env.reset()
        total_reward=0
        critic_loss=[]

        for i in range(iterations):
            # run the environment
            with HiddenPrints():
                with torch.no_grad():
                    a1 =  agent1.act(env.obs_agent_two()) + np.random.normal(loc=0, scale=hyperparameters["exploration_noise"], size=action_dim)
                a2 = agent2.act(env.obs_agent_two())
                obs, r, d, info = env.step(np.hstack([a1,a2]))
            total_reward+=r
            done = 1 if d else 0

            # modify reward with closeness-to-puck reward
            if hyperparameters["closeness_decay"]:
                c = closeness_factor *(1 - episode_count/max_episodes)
            newreward = r + c * info["reward_closeness_to_puck"] 

            # add to replaybuffer
            replay_buffer.add(obs_last, a1, obs, newreward, done)
            obs_last=obs
            
            # sample minibatch and train
            states, actions, next_states, reward, done = replay_buffer.sample(batch_size)
            loss = agent1.train(states, actions, next_states, reward, done)
            critic_loss.append(loss.detach().numpy())

            # if done, finish episode
            if d:
                episode_rewards.append(total_reward)
                episode_critic_loss.append(np.mean(critic_loss))
                win_count.append(info["winner"])
                print(f"Episode {episode_count} finished after {i} steps with a total reward of {total_reward}")
                
                # Online plotting
                if hyperparameters["plot_performance"] and episode_count>40 :
                    ax_loss.plot(list(range(-1, episode_count-29)), moving_average(episode_critic_loss, 30), 'r-')
                    ax_reward.plot(list(range(-1, episode_count-29)), moving_average(episode_rewards, 30), "r-")
                    plt.draw()
                    plt.pause(1e-17)

                break
        
        # Intermediate evaluation of win/loss and saving of model
        if episode_count % 500 == 0 and episode_count != 0:
            print(f"The agent's win ratio in the last 500 episodes was {win_count[-500:].count(1)/500}")
            print(f"The agent's loss ratio in the last 500 episodes was {win_count[-500:].count(-1)/500}")
            try:
                agent1.save(save_path)
                print("saved model")
            except Exception:
                print("Saving the model failed")
            pd.DataFrame(data={"Episode_rewards": episode_rewards[-500:], "Episode_critic_loss": episode_critic_loss[-500:], "Win/Loss": win_count[-500:]}).to_csv(save_path + "/performance.csv", sep=",", index=False, mode="a", header=False)
                    
    print(f"Finished training with a final mean reward of {np.mean(episode_rewards[-500:])}")
    # plot the performance summary
    if hyperparameters["plot_performance_summary"]:
        try:
            fig, (ax1, ax2) = plt.subplots(2)
            x = list(range(len(episode_critic_loss)))
            coef = np.polyfit(x, episode_critic_loss, 1)
            poly1d_fn = np.poly1d(coef)
            ax1.plot(episode_critic_loss)
            ax1.plot(poly1d_fn(x))

            x = list(range(len(episode_rewards)))
            coef = np.polyfit(x, episode_rewards, 1)
            poly1d_fn = np.poly1d(coef)
            ax2.plot(episode_rewards)
            ax2.plot(poly1d_fn(x))
            fig.show()
            fig.savefig(save_path + "/performance.png", bbox_inches="tight")
        except Exception:
            print("Failed to save the performance figure")
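
This example relies on two small utilities, `moving_average` (for the live curves) and `HiddenPrints` (to silence the hockey environment's console output), that are not shown. The sketches below are consistent with how they are used above but are offered as assumptions rather than the original code:

import os
import sys

import numpy as np

def moving_average(values, window):
    # Rolling mean over `window` entries; the output is window - 1 entries shorter.
    return np.convolve(values, np.ones(window) / window, mode="valid")

class HiddenPrints:
    # Temporarily redirect stdout to os.devnull while the environment steps.
    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout.close()
        sys.stdout = self._original_stdout
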
Example #4
def main(config):
    env_name = config['run']['env']
    env = gym.make(env_name)
    np.random.seed(config['random_seed'])
    tf.set_random_seed(config['random_seed'])
    env.seed(config['random_seed'])

    batch_size = config['train']['batch_size']
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    # Define action boundaries for continuous but bounded action space
    action_bound = env.action_space.high

    print(f'-------- {env_name} --------')
    print('ACTION SPACE: ', action_dim)
    print('ACTION BOUND: ', action_bound)
    print('STATE SPACE: ', state_dim)
    print('------------------------')


    # TODO (20190831, JP): add normalization for envs that require it.
    # Ensure action bound is symmetric - important
    assert np.all(env.action_space.high == -env.action_space.low)

    # Use agent_factory to build the agent using the algorithm specified in the config file.
    Agent = agent_factory(config['agent']['model'])
    agent = Agent(config, state_dim, action_dim, action_bound)

    # Initialize memory for experience replay
    replay_buffer = ReplayBuffer(config['train']['replay_buffer_size'], config['random_seed'])

    print(replay_buffer)
        
    # Set up summary TF operations
    summary_ops, summary_vars = build_summaries()

    with tf.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        writer = tf.summary.FileWriter(config['output']['summary_dir'], sess.graph)

        # Initialize target network weights
        agent.update_target_networks(sess)

        for i in range(int(config['train']['max_episodes'])):
            s = env.reset()

            episode_reward = 0
            episode_average_max_q = 0

            for j in range(int(config['train']['max_episode_len'])):
                if config['run']['render_env']:
                    env.render()

                # 1. Predict an action to take
                a = agent.actor_predict_action(np.reshape(s, (1, state_dim)), sess)

                # 2. Use action to take step in environment and receive next step, reward, etc.
                s2, r, terminal, info = env.step(a[0])

                # 3. Update the replay buffer with the most recent experience
                replay_buffer.add(np.reshape(s, (state_dim,)), np.reshape(a, (action_dim,)), r,
                                  np.reshape(s2, (state_dim,)), terminal)

                # 4. When there are enough experiences in the replay buffer, sample minibatches of training experiences
                if replay_buffer.size() > batch_size:
                    s_batch, a_batch, r_batch, s2_batch, t_batch = replay_buffer.sample_batch(batch_size)

                    # Train current behavioural networks
                    predicted_Q_value = agent.train_networks(s_batch, a_batch, r_batch, s2_batch, t_batch, sess)

                    # Update for logging
                    episode_average_max_q += np.amax(predicted_Q_value)

                    # Update target networks
                    agent.update_target_networks(sess)

                # Update information for next step
                s = s2
                episode_reward += r

                # TODO (20190815, JP): as this could be different for each agent, do
                # agent.summarize_episode(summary_ops, summary_vars, episode_reward, sess) for when each agent requires own summaries?
                if terminal:
                    summary_str = sess.run(summary_ops, feed_dict={
                        summary_vars[0]: episode_reward,
                        summary_vars[1]: episode_average_max_q / float(j)
                        })

                    writer.add_summary(summary_str, i)
                    writer.flush()

                    print('| Reward: {:d} | Episode: {:d} | Qmax: {:.4f}'.format(int(episode_reward), i, (episode_average_max_q / float(j))))
                    
                    break

    if config['run']['use_gym_monitor']:
        env.monitor.close()
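
Examples 2 and 4 resolve the agent class through `agent_factory(...)`, which is not included here. One common shape for such a factory is a simple name-to-class registry; the module and class names below are purely illustrative assumptions:

import importlib

def agent_factory(model_name):
    # Map a config string to the class implementing that algorithm.
    registry = {
        'ddpg': ('agents.ddpg', 'DDPGAgent'),   # hypothetical module/class names
        'td3': ('agents.td3', 'TD3Agent'),
    }
    module_name, class_name = registry[model_name.lower()]
    return getattr(importlib.import_module(module_name), class_name)
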
Example #5
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
    X = target_vars['X']
    X_NOISE = target_vars['X_NOISE']
    train_op = target_vars['train_op']
    energy_pos = target_vars['energy_pos']
    energy_neg = target_vars['energy_neg']
    loss_energy = target_vars['loss_energy']
    loss_ml = target_vars['loss_ml']
    loss_total = target_vars['total_loss']
    gvs = target_vars['gvs']
    x_off = target_vars['x_off']
    x_grad = target_vars['x_grad']
    x_mod = target_vars['x_mod']
    LABEL = target_vars['LABEL']
    HIER_LABEL = target_vars['HIER_LABEL']
    LABEL_POS = target_vars['LABEL_POS']
    eps = target_vars['eps_begin']
    ATTENTION_MASK = target_vars['ATTENTION_MASK']
    attention_mask = target_vars['attention_mask']
    attention_grad = target_vars['attention_grad']

    if FLAGS.prelearn_model or FLAGS.prelearn_model_shape:
        models_pretrain = target_vars['models_pretrain']

    if not FLAGS.comb_mask:
        attention_mask = tf.zeros(1)
        attention_grad = tf.zeros(1)

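    # NOTE: 'weights' is not defined in this variant of train(); the FLAGS.use_attention
    # branch below assumes it is available (cf. target_vars['weights'] in Example 1).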
    if FLAGS.use_attention:
        gamma = weights['atten']['gamma']
    else:
        gamma = tf.zeros(1)

    gvs_dict = dict(gvs)

    log_output = [
        train_op, energy_pos, energy_neg, eps, loss_energy, loss_ml,
        loss_total, x_grad, x_off, x_mod, attention_mask, attention_grad,
        *gvs_dict.keys()
    ]
    output = [train_op, x_mod]
    print("log_output ", log_output)

    replay_buffer = ReplayBuffer(10000)
    itr = resume_iter
    x_mod = None
    gd_steps = 1

    dataloader_iterator = iter(dataloader)
    best_inception = 0.0

    for epoch in range(FLAGS.epoch_num):
        for data_corrupt, data, label in dataloader:
            data_corrupt = data_corrupt.numpy()
            data_corrupt_init = data_corrupt.copy()

            data = data.numpy()

            if FLAGS.mixup:
                idx = np.random.permutation(data.shape[0])
                lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
                data = data * lam + data[idx] * (1 - lam)

            if FLAGS.replay_batch and (
                    x_mod is not None) and not FLAGS.joint_baseline:
                replay_buffer.add(compress_x_mod(x_mod))

                if len(replay_buffer) > FLAGS.batch_size:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (np.random.uniform(
                        0, FLAGS.rescale, FLAGS.batch_size) > FLAGS.keep_ratio)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

            if FLAGS.pcd:
                if x_mod is not None:
                    data_corrupt = x_mod

            attention_mask = np.random.uniform(
                -1., 1., (data.shape[0], 64, 64, int(FLAGS.cond_func)))
            feed_dict = {
                X_NOISE: data_corrupt,
                X: data,
                ATTENTION_MASK: attention_mask
            }

            if FLAGS.joint_baseline:
                feed_dict[target_vars['NOISE']] = np.random.uniform(
                    -1., 1., (data.shape[0], 128))

            if FLAGS.prelearn_model or FLAGS.prelearn_model_shape:
                _, _, labels = zip(*models_pretrain)
                labels = [LABEL, LABEL_POS] + list(labels)
                for lp, l in zip(labels, label):
                    # print("lp, l ", lp, l)
                    # print("l shape ", l.shape)
                    feed_dict[lp] = l
            else:
                label = label.numpy()
                label_init = label.copy()
                if FLAGS.cclass:
                    feed_dict[LABEL] = label
                    feed_dict[LABEL_POS] = label_init

            if FLAGS.heir_mask:
                feed_dict[HIER_LABEL] = label

            if itr % FLAGS.log_interval == 0:
                # print(feed_dict.keys())
                # print(feed_dict)
                _, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, attention_mask, attention_grad, * \
                    grads = sess.run(log_output, feed_dict)

                kvs = {}
                kvs['e_pos'] = e_pos.mean()
                kvs['e_pos_std'] = e_pos.std()
                kvs['e_neg'] = e_neg.mean()
                kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                kvs['e_neg_std'] = e_neg.std()
                kvs['loss_e'] = loss_e.mean()
                kvs['loss_ml'] = loss_ml.mean()
                kvs['loss_total'] = loss_total.mean()
                kvs['x_grad'] = np.abs(x_grad).mean()
                kvs['attention_grad'] = np.abs(attention_grad).mean()
                kvs['x_off'] = x_off.mean()
                kvs['iter'] = itr

                for v, k in zip(grads, [v.name for v in gvs_dict.values()]):
                    kvs[k] = np.abs(v).max()

                string = "Obtained a total of "
                for key, value in kvs.items():
                    string += "{}: {}, ".format(key, value)

                if kvs['e_diff'] < -0.5:
                    print("Training is unstable")
                    assert False

                print(string)
                logger.writekvs(kvs)
            else:
                _, x_mod = sess.run(output, feed_dict)

            if itr % FLAGS.save_interval == 0:
                saver.save(
                    sess,
                    osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))

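            # Hard stop: this variant ends training after 30k iterations via the assert below.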
            if itr > 30000:
                assert False

            # For some reason conditioning on position fails earlier
            # if FLAGS.cond_pos and itr > 30000:
            #     assert False

            if itr % FLAGS.test_interval == 0 and not FLAGS.joint_baseline and FLAGS.dataset != 'celeba':
                try_im = x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data)

                if not FLAGS.comb_mask:
                    attention_mask = np.random.uniform(
                        -1., 1., (data.shape[0], 64, 64, int(FLAGS.cond_func)))

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()
                attention_mask = rescale_im(attention_mask)

                for i, (im, t_im, actual_im_i, attention_im) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im,
                            attention_mask)):
                    im, t_im, actual_im_i, attention_im = (
                        im[::-1], t_im[::-1], actual_im_i[::-1], attention_im[::-1])
                    shape = orig_im.shape[1:]
                    new_im = np.zeros(
                        (shape[0], shape[1] * (3 + FLAGS.cond_func),
                         *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:3 * size] = actual_im_i

                    for k in range(FLAGS.cond_func):
                        new_im[:, (3 + k) * size:(4 + k) * size] = np.tile(
                            attention_im[:, :, k:k + 1], (1, 1, 3))

                    log_image(new_im,
                              logger,
                              'train_gen_{}'.format(itr),
                              step=i)

                test_im = x_mod

                try:
                    data_corrupt, data, label = next(dataloader_iterator)
                except StopIteration:
                    dataloader_iterator = iter(dataloader)
                    data_corrupt, data, label = next(dataloader_iterator)

                data_corrupt = data_corrupt.numpy()

            itr += 1

    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))
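
Examples 1 and 5 also use `rescale_im` and `log_image` for visual logging. The versions below are only a guess at their behavior (mapping floats to uint8 pixels and forwarding an image to the logger); in particular, the `logger.log_image` call is a hypothetical interface, not a confirmed one:

import numpy as np

def rescale_im(im):
    # Map float images (roughly in [0, 1]) to uint8 pixel values for logging.
    return np.clip(im * 256, 0, 255).astype(np.uint8)

def log_image(im, logger, tag, step=0):
    # Forward a single HxWxC uint8 image to the experiment logger.
    logger.log_image(tag, im, step=step)  # hypothetical logger method
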
Example #6
class MADDPG:
    def __init__(self, state_size, action_size, seed = 42):
        super(MADDPG, self).__init__()

        self.agents = [Agent(state_size, action_size, lr_actor=LR_ACTOR, lr_critic=LR_CRITIC, agent_number=0, epsilon=EPSILON,
                             epsilon_decay=EPSILON_DECAY, weight_decay=WEIGHT_DECAY, clipgrad=CLIPGRAD), 
                       Agent(state_size, action_size, lr_actor=LR_ACTOR, lr_critic=LR_CRITIC, agent_number=1, epsilon=EPSILON,
                             epsilon_decay=EPSILON_DECAY, weight_decay=WEIGHT_DECAY, clipgrad=CLIPGRAD)]
        
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        
        # Init tracking of params (the same hyperparameter dict is logged to both wandb and jovian)
        hyperparams = {"buffer_size": BUFFER_SIZE,
                       "batch_size": BATCH_SIZE,
                       "learn_every": LEARN_EVERY,
                       "learn_number": LEARN_NUMBER,
                       "lr_actor": LR_ACTOR,
                       "lr_critic": LR_CRITIC,
                       "gamma": GAMMA,
                       "tau": TAU,
                       "epsilon": EPSILON,
                       "epsilon_decay": EPSILON_DECAY,
                       "weight_decay": WEIGHT_DECAY,
                       "clipgrad": CLIPGRAD}

        wandb.login()
        wandb.init(project=project_name, name=name, config=hyperparams)

        jovian.log_hyperparams(project=project_name, name=name, config=hyperparams)

    def act(self, observations):
        """get actions from all agents in the MADDPG object"""

        actions = [agent.act(obs) for agent, obs in zip(self.agents,observations)]
        return actions

    def step(self, states, actions, rewards, next_states, dones, timestamp):
        """Save experience in replay memory, and use random sample from buffer to learn."""
        # Save experience / reward
        for state, action, reward, next_state, done in zip(states, actions, rewards, next_states, dones):
            self.memory.add(state, action, reward, next_state, done)

            
        # Learn, if enough samples are available in memory
        if len(self.memory) > BATCH_SIZE and timestamp % LEARN_EVERY == 0:
            for agent in self.agents:
                for _ in range(LEARN_NUMBER):
                    experiences = self.memory.sample()
                    agent.learn(experiences)
                
    def save(self):
        for agent in self.agents:
            agent.save_models()
            
    def get_project_name(self):
        return project_name
    
    def get_model_name(self):
        return name
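
Example 6 reads a set of module-level constants (`BUFFER_SIZE`, `BATCH_SIZE`, `LR_ACTOR`, ..., plus `project_name` and `name` for the wandb/jovian run) that are defined elsewhere. The block below only illustrates the expected shape; every value and name here is an assumption, not the project's actual configuration:

BUFFER_SIZE = int(1e6)    # replay buffer capacity
BATCH_SIZE = 256          # minibatch size
LEARN_EVERY = 1           # learn every N environment steps
LEARN_NUMBER = 3          # gradient updates per learning step
LR_ACTOR = 1e-4           # actor learning rate
LR_CRITIC = 1e-3          # critic learning rate
GAMMA = 0.99              # discount factor
TAU = 1e-3                # soft-update coefficient
EPSILON = 1.0             # initial exploration noise scale
EPSILON_DECAY = 1e-6      # exploration decay per step
WEIGHT_DECAY = 0.0        # critic L2 regularization
CLIPGRAD = 1.0            # gradient-norm clipping threshold

project_name = "maddpg-collab-compet"   # hypothetical wandb/jovian project
name = "maddpg-run-001"                 # hypothetical run name
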
Example #7
class MADDPG:
    def __init__(self, action_size, discount_factor=0.95, tau=0.02):
        super(MADDPG, self).__init__()

        # Create the multi agent as a list of ddpg agents
        self.maddpg_agents = [AgentDDPG(24, 2, 0), AgentDDPG(24, 2, 0)]

        self.discount_factor = discount_factor
        self.tau = tau
        self.iter = 0
        self.total_reward = 0.0
        self.count = 0
        self.update_every = 1
        self.batch_size = 128
        self.agent_number = len(self.maddpg_agents)
        self.t_step = 0
        # Initialize the Replay Memory
        self.buffer_size = 1000000
        self.memory = ReplayBuffer(self.buffer_size, self.batch_size)
        self.action_size = action_size
        self.total_reward = np.zeros((1, 2))

        # Initialize the Gaussian Noise process
        self.exploration_mu = 0
        self.exploration_theta = 0.15
        self.exploration_sigma = 0.2
        self.noise = OUNoise(self.action_size, self.exploration_mu,
                             self.exploration_theta, self.exploration_sigma)

    def get_actors_local(self):
        """get actors of all the agents in the MADDPG object"""
        actors = [ddpg_agt.actor_local for ddpg_agt in self.maddpg_agents]
        return actors

    def get_actors_target(self):
        """get target_actors of all the agents in the MADDPG object"""
        target_actors = [
            ddpg_agt.actor_target for ddpg_agt in self.maddpg_agents
        ]
        return target_actors

    def acts_local(self, states, noise_factor=0.0):
        """get actions from all actor_local agents in the MADDPG object"""
        """agents acts based on their individual observations"""
        actions = [
            agent.act_local(state).numpy() +
            noise_factor * self.noise.sample()
            for agent, state in zip(self.maddpg_agents, states)
        ]
        return actions

    def acts_target(self, states, noise_factor=0.0):
        """get actions from all actor_target agents in the MADDPG object"""
        target_actions = [
            agent.act_target(state).numpy() +
            noise_factor * self.noise.sample()
            for agent, state in zip(self.maddpg_agents, states)
        ]
        return target_actions

    # Actor interact with the environment through the step
    def step(self, states, actions, rewards, next_states, dones):
        # Add to the total reward the reward of this time step
        self.total_reward += rewards
        # Increase your count based on the number of rewards
        # received in the episode
        self.count += 1
        # Stored experience tuple in the replay buffer
        self.memory.add(states, actions, rewards, next_states, dones)

        # Learn every update_times time steps.
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:

            # Check to see if you have enough to produce a batch
            # and learn from it

            if len(self.memory) > self.batch_size:
                for agent_idx in range(len(self.maddpg_agents)):
                    experiences = self.memory.sample()
                    # Train the networks using the experiences
                    self.learn(experiences, agent_idx)

    def learn(self, experiences, agent_idx):
        """update the critics and actors of all the agents """

        # Reshape the experience tuples in separate arrays of states, actions
        # rewards, next_state, done
        # Your are converting every member of the tuple in a column or vector
        states = np.vstack([e.state for e in experiences if e is not None])
        actions = np.array([e.action for e in experiences
                            if e is not None]).astype(np.float32).reshape(
                                -1, self.action_size)
        rewards = np.array([e.reward for e in experiences if e is not None
                            ]).astype(np.float32).reshape(-1, 1)
        dones = np.array([e.done for e in experiences
                          if e is not None]).astype(np.uint8).reshape(-1, 1)
        next_states = np.vstack(
            [e.next_state for e in experiences if e is not None])

        # Convert the numpy arrays to tensors
        states = torch.from_numpy(states).float().unsqueeze(0).to(device)
        actions = torch.from_numpy(actions).float().unsqueeze(0).to(device)
        next_states = torch.from_numpy(next_states).float().unsqueeze(0).to(
            device)

        # Get the agent to train
        agent = self.maddpg_agents[agent_idx]

        # Using the buffer get ALL the actor_target actions based on ALL
        # the agents next states. The acts_target function returns a list of
        # detached tensor returned from the ddpg agent act_target
        next_state_actions = self.acts_target(next_states)
        next_state_actions = torch.from_numpy(
            np.squeeze(
                np.array(next_state_actions))).float().unsqueeze(0).to(device)

        # Using ALL the actor_targets next_state_actions based on the next states
        # and ALL the next_states, calculate the "current agent" target critic
        # q_target_next_state_action_values. The critic sees all the agents actions in
        # each of their state to calculate the q_value of the current agent.
        agent.critic_target.eval()
        with torch.no_grad():
            q_target_next_state_action_value = agent.critic_target(
                next_states, next_state_actions).detach()
        agent.critic_target.train()

        # Calculate the current agent q_targets value
        q_targets = torch.from_numpy(rewards[agent_idx] +
                                     self.discount_factor *
                                     q_target_next_state_action_value.numpy() *
                                     (1 - dones[agent_idx]))

        # --- Optimize the current critic_local agent --- #

        # Using ALL the agents actions and ALL states where they took
        # you calculate the critic_local q_expected values
        q_expected = agent.critic_local(states, actions)

        # Set the grad to zero, for the critic agent optimizer selected before
        agent.critic_optimizer.zero_grad()

        # Calculate the critic loss (smooth L1 / Huber)
        critic_loss = F.smooth_l1_loss(q_expected, q_targets)
        critic_loss.backward(retain_graph=True)

        # Clip the gradient to improve training
        torch.nn.utils.clip_grad_norm_(agent.critic_local.parameters(), 1)

        # Optimize the critic_local model using the optimizer defined
        # in its init function
        agent.critic_optimizer.step()

        # --- Optimize the current actor_local model --- #

        # Get the current actor_local actions using only the corresponding
        # state, and for other actor states you detach the tensor from the
        # calculation to save computation time
        actor_actions = [
            self.maddpg_agents[i].actor_local(state) if i == agent_idx else
            self.maddpg_agents[i].actor_local(state).detach()
            for i, state in enumerate(states)
        ][0]

        loss_actor = -1 * torch.sum(
            agent.critic_local.forward(states, actor_actions))

        # Set the grad to zero, for the actor local current agent optimization
        agent.actor_optimizer.zero_grad()
        loss_actor.backward()

        # Optimize the actor_local current agent
        agent.actor_optimizer.step()

        # Soft update the target and local models
        self.soft_update(agent.critic_local, agent.critic_target)
        self.soft_update(agent.actor_local, agent.actor_target)

    def soft_update(self, local_model, target_model):
        # Soft update targets
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data +
                                    (1.0 - self.tau) * target_param.data)

    def save_models_weights(self):
        for idx, agent in enumerate(self.maddpg_agents):
            file_name = "./checkpoints_agent_" + str(idx) + ".pkl"
            torch.save(agent.actor_local.state_dict(), file_name)
Example #8
class MADDPGAgent():
    def __init__(self, config, file_prefix=None):
        self.buffer_size = config.hyperparameters.buffer_size
        self.batch_size = config.hyperparameters.batch_size
        self.update_frequency = config.hyperparameters.update_frequency
        self.gamma = config.hyperparameters.gamma
        self.number_of_agents = config.environment.number_of_agents
        self.noise_weight = config.hyperparameters.noise_start
        self.noise_decay = config.hyperparameters.noise_decay
        self.memory = ReplayBuffer(config)
        self.t = 0

        self.agents = [
            DDPGAgent(index, config) for index in range(self.number_of_agents)
        ]

        if file_prefix:
            for i, to_load in enumerate(self.agents):
                # f"{os.getcwd()}/models/by_score/{file_prefix}_actor_{i}.weights"  (unused alternative path)
                actor_file = torch.load(
                    f"{os.getcwd()}/models/{file_prefix}_actor_{i}.weights",
                    map_location='cpu')
                critic_file = torch.load(
                    f"{os.getcwd()}/models/{file_prefix}_critic_{i}.weights",
                    map_location='cpu')
                to_load.actor_local.load_state_dict(actor_file)
                to_load.actor_target.load_state_dict(actor_file)
                to_load.critic_local.load_state_dict(critic_file)
                to_load.critic_target.load_state_dict(critic_file)
            print(f'Files loaded with prefix {file_prefix}')

    def step(self, all_states, all_actions, all_rewards, all_next_states,
             all_dones):
        all_states = all_states.reshape(1, -1)
        all_next_states = all_next_states.reshape(1, -1)
        self.memory.add(all_states, all_actions, all_rewards, all_next_states,
                        all_dones)
        self.t = (self.t + 1) % self.update_frequency
        if self.t == 0 and (len(self.memory) > self.batch_size):
            experiences = [
                self.memory.sample() for _ in range(self.number_of_agents)
            ]
            self.learn(experiences, self.gamma)

    def act(self, all_states, add_noise=True, random=0.0):
        all_actions = []
        for agent, state in zip(self.agents, all_states):
            action = agent.act(state,
                               noise=self.noise_weight if add_noise else 0.0,
                               random=random)
            self.noise_weight *= self.noise_decay
            all_actions.append(action)
        return np.array(all_actions).reshape(1, -1)

    def learn(self, experiences, gamma):
        all_actions = []
        all_next_actions = []
        for i, agent in enumerate(self.agents):
            states, _, _, next_states, _ = experiences[i]
            agent_id = torch.tensor([i]).to(device)
            state = states.reshape(-1, 2, 24).index_select(1,
                                                           agent_id).squeeze(1)
            next_state = next_states.reshape(-1, 2, 24).index_select(
                1, agent_id).squeeze(1)
            all_actions.append(agent.actor_local(state))
            all_next_actions.append(agent.actor_target(next_state))
        for i, agent in enumerate(self.agents):
            agent.learn(i, experiences[i], gamma, all_next_actions,
                        all_actions)
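
MADDPGAgent reads its settings through attribute access (`config.hyperparameters.batch_size`, `config.environment.number_of_agents`), so the config is some namespace-like object. A minimal illustration of that shape, with made-up values and only the fields read by `__init__` itself (the nested `DDPGAgent` and `ReplayBuffer` will expect more), could look like this:

from types import SimpleNamespace

config = SimpleNamespace(
    hyperparameters=SimpleNamespace(
        buffer_size=int(1e6),     # illustrative values only
        batch_size=256,
        update_frequency=2,
        gamma=0.99,
        noise_start=1.0,
        noise_decay=0.9999,
    ),
    environment=SimpleNamespace(number_of_agents=2),
)

# agent = MADDPGAgent(config)   # requires the project's DDPGAgent and ReplayBuffer classes
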
Example #9
class Experiment(object):
    def __init__(self,
                 domain,
                 train_data_file,
                 validation_data_file,
                 test_data_file,
                 minibatch_size,
                 rng,
                 device,
                 behav_policy_file_wDemo,
                 behav_policy_file,
                 context_input=False,
                 context_dim=0,
                 drop_smaller_than_minibatch=True,
                 folder_name='/Name',
                 autoencoder_saving_period=20,
                 resume=False,
                 sided_Q='negative',
                 autoencoder_num_epochs=50,
                 autoencoder_lr=0.001,
                 autoencoder='AIS',
                 hidden_size=16,
                 ais_gen_model=1,
                 ais_pred_model=1,
                 embedding_dim=4,
                 state_dim=42,
                 num_actions=25,
                 corr_coeff_param=10,
                 dst_hypers={},
                 cde_hypers={},
                 odernn_hypers={},
                 **kwargs):
        '''
        We assume discrete actions and scalar rewards!
        '''

        self.rng = rng
        self.device = device
        self.train_data_file = train_data_file
        self.validation_data_file = validation_data_file
        self.test_data_file = test_data_file
        self.minibatch_size = minibatch_size
        self.drop_smaller_than_minibatch = drop_smaller_than_minibatch
        self.autoencoder_num_epochs = autoencoder_num_epochs
        self.autoencoder = autoencoder
        self.autoencoder_lr = autoencoder_lr
        self.saving_period = autoencoder_saving_period
        self.resume = resume
        self.sided_Q = sided_Q
        self.num_actions = num_actions
        self.state_dim = state_dim
        self.corr_coeff_param = corr_coeff_param

        self.context_input = context_input  # Check to see if we'll one-hot encode the categorical contextual input
        self.context_dim = context_dim  # Check to see if we'll remove the context from the input and only use it for decoding
        self.hidden_size = hidden_size

        if self.context_input:
            self.input_dim = self.state_dim + self.context_dim + self.num_actions
        else:
            self.input_dim = self.state_dim + self.num_actions

        self.autoencoder_lower = self.autoencoder.lower()
        self.data_folder = folder_name + f'/{self.autoencoder_lower}_data'
        self.checkpoint_file = folder_name + f'/{self.autoencoder_lower}_checkpoints/checkpoint.pt'
        if not os.path.exists(folder_name +
                              f'/{self.autoencoder_lower}_checkpoints'):
            os.mkdir(folder_name + f'/{self.autoencoder_lower}_checkpoints')
        if not os.path.exists(folder_name + f'/{self.autoencoder_lower}_data'):
            os.mkdir(folder_name + f'/{self.autoencoder_lower}_data')
        self.store_path = folder_name
        self.gen_file = folder_name + f'/{self.autoencoder_lower}_data/{self.autoencoder_lower}_gen.pt'
        self.pred_file = folder_name + f'/{self.autoencoder_lower}_data/{self.autoencoder_lower}_pred.pt'

        if self.autoencoder == 'AIS':
            self.container = AIS.ModelContainer(device, ais_gen_model,
                                                ais_pred_model)
            self.gen = self.container.make_encoder(
                self.hidden_size,
                self.state_dim,
                self.num_actions,
                context_input=self.context_input,
                context_dim=self.context_dim)
            self.pred = self.container.make_decoder(self.hidden_size,
                                                    self.state_dim,
                                                    self.num_actions)

        elif self.autoencoder == 'AE':
            self.container = AE.ModelContainer(device)
            self.gen = self.container.make_encoder(
                self.hidden_size,
                self.state_dim,
                self.num_actions,
                context_input=self.context_input,
                context_dim=self.context_dim)
            self.pred = self.container.make_decoder(self.hidden_size,
                                                    self.state_dim,
                                                    self.num_actions)

        elif self.autoencoder == 'DST':
            self.dst_hypers = dst_hypers
            self.container = DST.ModelContainer(device)
            self.gen = self.container.make_encoder(
                self.input_dim,
                self.hidden_size,
                gru_n_layers=self.dst_hypers['gru_n_layers'],
                augment_chs=self.dst_hypers['augment_chs'])
            self.pred = self.container.make_decoder(
                self.hidden_size, self.state_dim,
                self.dst_hypers['decoder_hidden_units'])

        elif self.autoencoder == 'DDM':
            self.container = DDM.ModelContainer(device)

            self.gen = self.container.make_encoder(
                self.state_dim,
                self.hidden_size,
                context_input=self.context_input,
                context_dim=self.context_dim)
            self.pred = self.container.make_decoder(self.state_dim,
                                                    self.hidden_size)
            self.dyn = self.container.make_dyn(self.num_actions,
                                               self.hidden_size)
            self.all_params = chain(self.gen.parameters(),
                                    self.pred.parameters(),
                                    self.dyn.parameters())

            self.inv_loss_coef = 10
            self.dec_loss_coef = 0.1
            self.max_grad_norm = 50

            self.dyn_file = folder_name + '/ddm_data/ddm_dyn.pt'

        elif self.autoencoder == 'RNN':
            self.container = RNN.ModelContainer(device)

            self.gen = self.container.make_encoder(
                self.hidden_size,
                self.state_dim,
                self.num_actions,
                context_input=self.context_input,
                context_dim=self.context_dim)
            self.pred = self.container.make_decoder(self.hidden_size,
                                                    self.state_dim,
                                                    self.num_actions)

        elif self.autoencoder == 'CDE':
            self.cde_hypers = cde_hypers

            self.container = CDE.ModelContainer(device)
            self.gen = self.container.make_encoder(
                self.input_dim + 1,
                self.hidden_size,
                hidden_hidden_channels=self.cde_hypers['encoder_hidden_hidden_channels'],
                num_hidden_layers=self.cde_hypers['encoder_num_hidden_layers'])
            self.pred = self.container.make_decoder(
                self.hidden_size, self.state_dim,
                self.cde_hypers['decoder_num_layers'],
                self.cde_hypers['decoder_num_units'])

        elif self.autoencoder == 'ODERNN':
            self.odernn_hypers = odernn_hypers
            self.container = ODERNN.ModelContainer(device)

            self.gen = self.container.make_encoder(self.input_dim,
                                                   self.hidden_size,
                                                   self.odernn_hypers)
            self.pred = self.container.make_decoder(
                self.hidden_size, self.state_dim,
                self.odernn_hypers['decoder_n_layers'],
                self.odernn_hypers['decoder_n_units'])
        else:
            raise NotImplementedError

        self.buffer_save_file = self.data_folder + '/ReplayBuffer'
        self.next_obs_pred_errors_file = self.data_folder + '/test_next_obs_pred_errors.pt'
        self.test_representations_file = self.data_folder + '/test_representations.pt'
        self.test_correlations_file = self.data_folder + '/test_correlations.pt'
        self.policy_eval_save_file = self.data_folder + '/dBCQ_policy_eval'
        self.policy_save_file = self.data_folder + '/dBCQ_policy'
        self.behav_policy_file_wDemo = behav_policy_file_wDemo
        self.behav_policy_file = behav_policy_file

        # Read in the data csv files
        assert (domain == 'sepsis')
        self.train_demog, self.train_states, self.train_interventions, self.train_lengths, self.train_times, self.acuities, self.rewards = torch.load(
            self.train_data_file)
        train_idx = torch.arange(self.train_demog.shape[0])
        self.train_dataset = TensorDataset(self.train_demog, self.train_states,
                                           self.train_interventions,
                                           self.train_lengths,
                                           self.train_times, self.acuities,
                                           self.rewards, train_idx)

        self.train_loader = DataLoader(self.train_dataset,
                                       batch_size=self.minibatch_size,
                                       shuffle=True)

        self.val_demog, self.val_states, self.val_interventions, self.val_lengths, self.val_times, self.val_acuities, self.val_rewards = torch.load(
            self.validation_data_file)
        val_idx = torch.arange(self.val_demog.shape[0])
        self.val_dataset = TensorDataset(self.val_demog, self.val_states,
                                         self.val_interventions,
                                         self.val_lengths, self.val_times,
                                         self.val_acuities, self.val_rewards,
                                         val_idx)

        self.val_loader = DataLoader(self.val_dataset,
                                     batch_size=self.minibatch_size,
                                     shuffle=False)

        self.test_demog, self.test_states, self.test_interventions, self.test_lengths, self.test_times, self.test_acuities, self.test_rewards = torch.load(
            self.test_data_file)
        test_idx = torch.arange(self.test_demog.shape[0])
        self.test_dataset = TensorDataset(self.test_demog, self.test_states,
                                          self.test_interventions,
                                          self.test_lengths, self.test_times,
                                          self.test_acuities,
                                          self.test_rewards, test_idx)

        self.test_loader = DataLoader(self.test_dataset,
                                      batch_size=self.minibatch_size,
                                      shuffle=False)

        # encode CDE data first to save time
        if self.autoencoder == 'CDE':
            self.train_coefs = load_cde_data('train', self.train_dataset,
                                             self.cde_hypers['coefs_folder'],
                                             self.context_input, device)
            self.val_coefs = load_cde_data('val', self.val_dataset,
                                           self.cde_hypers['coefs_folder'],
                                           self.context_input, device)
            self.test_coefs = load_cde_data('test', self.test_dataset,
                                            self.cde_hypers['coefs_folder'],
                                            self.context_input, device)
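            # Assumption: load_cde_data caches the CDE interpolation coefficients
            # (e.g. spline coefficients for the controlled path) under
            # cde_hypers['coefs_folder'], so they only have to be computed once.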

    def load_model_from_checkpoint(self, checkpoint_file_path):
        checkpoint = torch.load(checkpoint_file_path)
        self.gen.load_state_dict(checkpoint['{}_gen_state_dict'.format(
            self.autoencoder.lower())])
        self.pred.load_state_dict(checkpoint['{}_pred_state_dict'.format(
            self.autoencoder.lower())])
        if self.autoencoder == 'DDM':
            self.dyn.load_state_dict(checkpoint['{}_dyn_state_dict'.format(
                self.autoencoder.lower())])
        print("Experiment: generator, predictor (and dynamics, for DDM) models loaded.")

    def train_autoencoder(self):
        print('Experiment: training autoencoder')
        device = self.device

        if self.autoencoder != 'DDM':
            self.optimizer = torch.optim.Adam(list(self.gen.parameters()) +
                                              list(self.pred.parameters()),
                                              lr=self.autoencoder_lr,
                                              amsgrad=True)
        else:
            self.optimizer = torch.optim.Adam(list(self.gen.parameters()) +
                                              list(self.pred.parameters()) +
                                              list(self.dyn.parameters()),
                                              lr=self.autoencoder_lr,
                                              amsgrad=True)

        self.autoencoding_losses = []
        self.autoencoding_losses_validation = []

        if self.resume:  # Need to rebuild this to resume training for 400 additional epochs if feasible...
            try:
                checkpoint = torch.load(self.checkpoint_file)
                self.gen.load_state_dict(checkpoint['gen_state_dict'])
                self.pred.load_state_dict(checkpoint['pred_state_dict'])
                if self.autoencoder == 'DDM':
                    self.dyn.load_state_dict(checkpoint['dyn_state_dict'])

                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])

                epoch_0 = checkpoint['epoch'] + 1
                self.autoencoding_losses = checkpoint['loss']
                self.autoencoding_losses_validation = checkpoint[
                    'validation_loss']
                print('Starting from epoch: {0} and continuing up to epoch {1}'
                      .format(epoch_0, self.autoencoder_num_epochs))
            except Exception as e:
                epoch_0 = 0
                print('Error loading checkpoint ({}); training from the default '
                      'setting. epoch_0 = 0'.format(e))
        else:
            epoch_0 = 0

        for epoch in range(epoch_0, self.autoencoder_num_epochs):
            epoch_loss = []
            print('Experiment: autoencoder {0}: training epoch {1} out of {2} '
                  'epochs'.format(self.autoencoder, epoch + 1,
                                  self.autoencoder_num_epochs))

            # Loop through the data using the data loader
            for ii, (dem, ob, ac, l, t, scores, rewards,
                     idx) in enumerate(self.train_loader):
                # print("Batch {}".format(ii),end='')
                # dem: 5-dimensional demographic vector (Gender, Ventilation
                # status, Re-admission status, Age, Weight)
                dem = dem.to(device)
                # ob: 33-dimensional vector of time-varying measurements
                ob = ob.to(device)
                ac = ac.to(device)
                l = l.to(device)
                t = t.to(device)
                scores = scores.to(device)
                idx = idx.to(device)
                loss_pred = 0

                # Cut tensors down to the batch's largest sequence length... Trying to speed things up a bit...
                max_length = int(l.max().item())

                # The following losses are for DDM and will not be modified by any other approach
                train_loss, dec_loss, inv_loss = 0, 0, 0
                model_loss, recon_loss, forward_loss = 0, 0, 0

                self.gen.train()
                self.pred.train()
                ob = ob[:, :max_length, :]
                dem = dem[:, :max_length, :]
                ac = ac[:, :max_length, :]
                scores = scores[:, :max_length, :]

                if self.autoencoder == 'CDE':
                    loss_pred, mse_loss, _ = self.container.loop(
                        ob,
                        dem,
                        ac,
                        scores,
                        l,
                        max_length,
                        self.context_input,
                        corr_coeff_param=self.corr_coeff_param,
                        device=device,
                        coefs=self.train_coefs,
                        idx=idx)
                else:
                    loss_pred, mse_loss, _ = self.container.loop(
                        ob,
                        dem,
                        ac,
                        scores,
                        l,
                        max_length,
                        self.context_input,
                        corr_coeff_param=self.corr_coeff_param,
                        device=device,
                        autoencoder=self.autoencoder)
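                # container.loop is expected to return a loss term (a tuple of
                # loss components for DDM), an MSE term, and the representations;
                # the representations are not needed during training.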

                self.optimizer.zero_grad()

                if self.autoencoder != 'DDM':
                    loss_pred.backward()
                    self.optimizer.step()
                    epoch_loss.append(loss_pred.detach().cpu().numpy())
                else:
                    (train_loss, dec_loss, inv_loss, model_loss, recon_loss,
                     forward_loss, corr_loss, loss_pred) = loss_pred
                    train_loss = (forward_loss +
                                  self.inv_loss_coef * inv_loss +
                                  self.dec_loss_coef * dec_loss -
                                  self.corr_coeff_param * corr_loss.sum())
                    train_loss.backward()
                    # clip_grad_norm_ is the in-place (non-deprecated) variant
                    torch.nn.utils.clip_grad_norm_(self.all_params,
                                                   self.max_grad_norm)
                    self.optimizer.step()
                    epoch_loss.append(loss_pred.detach().cpu().numpy())
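                    # Note: the quantity logged in epoch_loss is loss_pred, not
                    # the weighted train_loss that was just optimized.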

            self.autoencoding_losses.append(epoch_loss)
            # Run validation and also save a checkpoint
            if (epoch + 1) % self.saving_period == 0:

                # Computing validation loss
                epoch_validation_loss = []
                with torch.no_grad():
                    for jj, (dem, ob, ac, l, t, scores, rewards,
                             idx) in enumerate(self.val_loader):

                        dem = dem.to(device)
                        ob = ob.to(device)
                        ac = ac.to(device)
                        l = l.to(device)
                        t = t.to(device)
                        idx = idx.to(device)
                        scores = scores.to(device)
                        loss_val = 0

                        # Cut tensors down to the batch's largest sequence length... Trying to speed things up a bit...
                        max_length = int(l.max().item())

                        ob = ob[:, :max_length, :]
                        dem = dem[:, :max_length, :]
                        ac = ac[:, :max_length, :]
                        scores = scores[:, :max_length, :]

                        self.gen.eval()
                        self.pred.eval()

                        if self.autoencoder == 'CDE':
                            loss_val, mse_loss, _ = self.container.loop(
                                ob,
                                dem,
                                ac,
                                scores,
                                l,
                                max_length,
                                self.context_input,
                                corr_coeff_param=0,
                                device=device,
                                coefs=self.val_coefs,
                                idx=idx)
                        else:
                            loss_val, mse_loss, _ = self.container.loop(
                                ob,
                                dem,
                                ac,
                                scores,
                                l,
                                max_length,
                                self.context_input,
                                corr_coeff_param=0,
                                device=device,
                                autoencoder=self.autoencoder)

                        if self.autoencoder in ['DST', 'ODERNN', 'CDE']:
                            epoch_validation_loss.append(mse_loss)
                        elif self.autoencoder == 'DDM':
                            epoch_validation_loss.append(
                                loss_val[-1].detach().cpu().numpy())
                        else:
                            epoch_validation_loss.append(
                                loss_val.detach().cpu().numpy())

                self.autoencoding_losses_validation.append(
                    epoch_validation_loss)

                save_dict = {
                    'epoch': epoch,
                    'gen_state_dict': self.gen.state_dict(),
                    'pred_state_dict': self.pred.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': self.autoencoding_losses,
                    'validation_loss': self.autoencoding_losses_validation
                }

                if self.autoencoder == 'DDM':
                    save_dict['dyn_state_dict'] = self.dyn.state_dict()

                try:
                    torch.save(save_dict, self.checkpoint_file)
                    # torch.save(save_dict, self.checkpoint_file[:-3] + str(epoch) +'_.pt')
                    np.save(
                        self.data_folder +
                        '/{}_losses.npy'.format(self.autoencoder.lower()),
                        np.array(self.autoencoding_losses))
                except Exception as e:
                    print(e)

                try:
                    np.save(
                        self.data_folder + '/{}_validation_losses.npy'.format(
                            self.autoencoder.lower()),
                        np.array(self.autoencoding_losses_validation))
                except Exception as e:
                    print(e)

            # Final epoch checkpoint
            try:
                save_dict = {
                    'epoch': self.autoencoder_num_epochs - 1,
                    'gen_state_dict': self.gen.state_dict(),
                    'pred_state_dict': self.pred.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': self.autoencoding_losses,
                    'validation_loss': self.autoencoding_losses_validation,
                }
                if self.autoencoder == 'DDM':
                    save_dict['dyn_state_dict'] = self.dyn.state_dict()
                    torch.save(self.dyn.state_dict(), self.dyn_file)
                torch.save(self.gen.state_dict(), self.gen_file)
                torch.save(self.pred.state_dict(), self.pred_file)
                torch.save(save_dict, self.checkpoint_file)
                np.save(
                    self.data_folder +
                    '/{}_losses.npy'.format(self.autoencoder.lower()),
                    np.array(self.autoencoding_losses))
            except Exception as e:
                print(e)

    def evaluate_trained_model(self):
        '''After training, this method can be called to use the trained autoencoder to embed all the data in the representation space.
        We encode all data subsets (train, validation and test) separately and save them off as independent tuples. We then
        combine these subsets to populate a replay buffer used to train a policy.

        This method also evaluates the decoder's ability to correctly predict the next observation from the learned
        representation, and evaluates the correlation between the trained representations and the acuity scores.
        '''

        # Initialize the replay buffer
        self.replay_buffer = ReplayBuffer(
            self.hidden_size,
            self.minibatch_size,
            350000,
            self.device,
            encoded_state=True,
            obs_state_dim=self.state_dim +
            (self.context_dim if self.context_input else 0))
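        # Buffer of latent transitions (state width hidden_size, capacity 350000);
        # with encoded_state=True the raw observation (plus demographics when
        # context_input is set) is stored alongside each transition so it can be
        # used for off-policy evaluation later on.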

        errors = []
        correlations = torch.Tensor()
        test_representations = torch.Tensor()
        print('Encoding the Training and Validation Data.')
        ## LOOP THROUGH THE DATA
        # -----------------------------------------------
        # For Training and Validation sets (Encode the observations only, add all data to the experience replay buffer)
        # For the Test set:
        # - Encode the observations
        # - Save off the data (as test tuples and place in the experience replay buffer)
        # - Evaluate accuracy of predicting the next observation using the decoder module of the model
        # - Evaluate the correlation coefficient between the learned representations and the acuity scores
        with torch.no_grad():
            for i_set, loader in enumerate(
                [self.train_loader, self.val_loader, self.test_loader]):
                if i_set == 2:
                    print(
                        'Encoding the Test Data. Evaluating prediction accuracy. Calculating Correlation Coefficients.'
                    )
                for dem, ob, ac, l, t, scores, rewards, idx in loader:
                    dem = dem.to(self.device)
                    ob = ob.to(self.device)
                    ac = ac.to(self.device)
                    l = l.to(self.device)
                    t = t.to(self.device)
                    scores = scores.to(self.device)
                    rewards = rewards.to(self.device)

                    max_length = int(l.max().item())

                    ob = ob[:, :max_length, :]
                    dem = dem[:, :max_length, :]
                    ac = ac[:, :max_length, :]
                    scores = scores[:, :max_length, :]
                    rewards = rewards[:, :max_length]

                    cur_obs, next_obs = ob[:, :-1, :], ob[:, 1:, :]
                    cur_dem, next_dem = dem[:, :-1, :], dem[:, 1:, :]
                    cur_actions = ac[:, :-1, :]
                    cur_rewards = rewards[:, :-1]
                    cur_scores = scores[:, :-1, :]
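                    # Padded timesteps were zero-filled, so a step whose observation
                    # is all zeros is treated as padding and masked out below.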
                    mask = (cur_obs == 0).all(dim=2)

                    self.gen.eval()
                    self.pred.eval()

                    if self.autoencoder in ['AE', 'AIS', 'RNN']:
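                        # For AE/AIS/RNN, the encoder input per step is the current
                        # observation (plus demographics when context_input is set)
                        # concatenated with the action from the preceding step; a
                        # zero action is prepended so the first step is well defined.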

                        prev_actions = torch.cat(
                            (torch.zeros((ob.shape[0], 1, ac.shape[-1])).to(
                                self.device), ac[:, :-2, :]),
                            dim=1)
                        if self.context_input:
                            representations = self.gen(
                                torch.cat((cur_obs, cur_dem, prev_actions),
                                          dim=-1))
                        else:
                            representations = self.gen(
                                torch.cat((cur_obs, prev_actions), dim=-1))

                        if self.autoencoder == 'RNN':
                            pred_obs = self.pred(representations)
                        else:
                            pred_obs = self.pred(
                                torch.cat((representations, cur_actions),
                                          dim=-1))

                        pred_error = F.mse_loss(next_obs[~mask],
                                                pred_obs[~mask])

                    elif self.autoencoder == 'DDM':
                        # Initialize hidden states for the LSTM layer
                        cx_d = torch.zeros(1, ob.shape[0],
                                           self.hidden_size).to(self.device)
                        hx_d = torch.zeros(1, ob.shape[0],
                                           self.hidden_size).to(self.device)

                        if self.context_input:
                            representations = self.gen(
                                torch.cat((cur_obs, cur_dem), dim=-1))
                            z_prime = self.gen(
                                torch.cat((next_obs, next_dem), dim=-1))
                        else:
                            representations = self.gen(cur_obs)
                            z_prime = self.gen(next_obs)

                        s_hat = self.pred(representations)
                        z_prime_hat, a_hat, _ = self.dyn(
                            (representations, z_prime, cur_actions, (hx_d,
                                                                     cx_d)))
                        s_prime_hat = self.pred(z_prime_hat)

                        __, pred_error, __, __, __ = get_dynamics_losses(
                            cur_obs[~mask],
                            s_hat[~mask],
                            next_obs[~mask],
                            s_prime_hat[~mask],
                            z_prime[~mask],
                            z_prime_hat[~mask],
                            a_hat[~mask],
                            cur_actions[~mask],
                            discrete=False)
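                        # Only the second value returned by get_dynamics_losses is
                        # kept; it is used below as pred_error for the test-set
                        # prediction-error statistics.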

                    elif self.autoencoder in ['DST', 'ODERNN']:
                        _, pred_error, representations = self.container.loop(
                            ob,
                            dem,
                            ac,
                            scores,
                            l,
                            max_length,
                            self.context_input,
                            corr_coeff_param=0,
                            device=self.device)
                        # Drop the latent of the last time step (it has no prediction target)
                        representations = representations[:, :-1, :].detach()

                    elif self.autoencoder == 'CDE':
                        i_coefs = (self.train_coefs, self.val_coefs,
                                   self.test_coefs)[i_set]
                        _, pred_error, representations = self.container.loop(
                            ob,
                            dem,
                            ac,
                            scores,
                            l,
                            max_length,
                            self.context_input,
                            corr_coeff_param=0,
                            device=self.device,
                            coefs=i_coefs,
                            idx=idx)
                        representations = representations[:, :-1, :].detach()

                    if i_set == 2:  # If we're evaluating the models on the test set...
                        # Compute the Pearson correlation of the learned representations and the acuity scores
                        corr = torch.zeros(
                            (cur_obs.shape[0], representations.shape[-1],
                             cur_scores.shape[-1]))
                        for i in range(cur_obs.shape[0]):
                            corr[i] = pearson_correlation(
                                representations[i][~mask[i]],
                                cur_scores[i][~mask[i]],
                                device=self.device)

                        # Concatenate this batch's correlations with the larger tensor
                        correlations = torch.cat((correlations, corr), dim=0)

                        # Concatenate the batch's representations with the larger tensor
                        test_representations = torch.cat(
                            (test_representations, representations.cpu()),
                            dim=0)

                        # Append the batch's prediction errors to the list
                        if torch.is_tensor(pred_error):
                            errors.append(pred_error.item())
                        else:
                            errors.append(pred_error)

                    # Remove values with the computed mask and add data to the experience replay buffer
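                    # cur_rep/next_rep are the representation sequence shifted by
                    # one step (zero-padded at the end) so that each transition has
                    # a (current, next) latent pair of matching length before masking.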
                    cur_rep = torch.cat(
                        (representations[:, :-1, :],
                         torch.zeros(
                             (cur_obs.shape[0], 1, self.hidden_size)).to(
                                 self.device)),
                        dim=1)
                    next_rep = torch.cat(
                        (representations[:, 1:, :],
                         torch.zeros(
                             (cur_obs.shape[0], 1, self.hidden_size)).to(
                                 self.device)),
                        dim=1)
                    cur_rep = cur_rep[~mask].cpu()
                    next_rep = next_rep[~mask].cpu()
                    cur_actions = cur_actions[~mask].cpu()
                    cur_rewards = cur_rewards[~mask].cpu()
                    # Keep the raw observations that produced each representation
                    # (needed for the downstream WIS evaluation)
                    cur_obs = cur_obs[~mask].cpu()
                    next_obs = next_obs[~mask].cpu()
                    cur_dem = cur_dem[~mask].cpu()
                    next_dem = next_dem[~mask].cpu()

                    # Loop over all transitions and add them to the replay buffer
                    for i_trans in range(cur_rep.shape[0]):
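                        # Assumes rewards are only nonzero on the final transition
                        # of a trajectory, so a nonzero reward marks the terminal step.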
                        done = cur_rewards[i_trans] != 0
                        if self.context_input:
                            self.replay_buffer.add(
                                cur_rep[i_trans].numpy(),
                                cur_actions[i_trans].argmax().item(),
                                next_rep[i_trans].numpy(),
                                cur_rewards[i_trans].item(), done.item(),
                                torch.cat((cur_obs[i_trans], cur_dem[i_trans]),
                                          dim=-1).numpy(),
                                torch.cat(
                                    (next_obs[i_trans], next_dem[i_trans]),
                                    dim=-1).numpy())
                        else:
                            self.replay_buffer.add(
                                cur_rep[i_trans].numpy(),
                                cur_actions[i_trans].argmax().item(),
                                next_rep[i_trans].numpy(),
                                cur_rewards[i_trans].item(), done.item(),
                                cur_obs[i_trans].numpy(),
                                next_obs[i_trans].numpy())

            ## SAVE OFF DATA
            # --------------
            self.replay_buffer.save(self.buffer_save_file)
            torch.save(errors, self.next_obs_pred_errors_file)
            torch.save(test_representations, self.test_representations_file)
            torch.save(correlations, self.test_correlations_file)

    def train_dBCQ_policy(self, pol_learning_rate=1e-3):

        # Initialize parameters for policy learning
        params = {
            "eval_freq": 500,
            "discount": 0.99,
            "buffer_size": 350000,
            "batch_size": self.minibatch_size,
            "optimizer": "Adam",
            "optimizer_parameters": {
                "lr": pol_learning_rate
            },
            "train_freq": 1,
            "polyak_target_update": True,
            "target_update_freq": 1,
            "tau": 0.01,
            "max_timesteps": 5e5,
            "BCQ_threshold": 0.3,
            "buffer_dir": self.buffer_save_file,
            "policy_file": self.policy_save_file + f'_l{pol_learning_rate}.pt',
            "pol_eval_file":
                self.policy_eval_save_file + f'_l{pol_learning_rate}.npy',
        }
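        # Hyperparameters forwarded to dBCQ_utils.train_dBCQ below; BCQ_threshold
        # is the standard discrete-BCQ action-filtering threshold.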

        # Initialize a dataloader for policy evaluation (will need representations, observations, demographics, rewards and actions from the test dataset)
        test_representations = torch.load(
            self.test_representations_file)  # Load the test representations
        pol_eval_dataset = TensorDataset(test_representations,
                                         self.test_states,
                                         self.test_interventions,
                                         self.test_demog, self.test_rewards)
        pol_eval_dataloader = DataLoader(pol_eval_dataset,
                                         batch_size=self.minibatch_size,
                                         shuffle=False)

        # Initialize the experience replay buffer corresponding to the current
        # settings of rand_num, hidden_size, etc. (it is presumably populated from
        # params["buffer_dir"] inside train_dBCQ)
        replay_buffer = ReplayBuffer(
            self.hidden_size,
            self.minibatch_size,
            350000,
            self.device,
            encoded_state=True,
            obs_state_dim=self.state_dim +
            (self.context_dim if self.context_input else 0))

        # Load the pretrained behavior policy, choosing the variant that matches
        # whether demographic context was used when training the representations
        behav_input = self.state_dim + (self.context_dim
                                        if self.context_input else 0)
        behav_pol = FC_BC(behav_input, self.num_actions, 64).to(self.device)
        if self.context_input:
            behav_pol.load_state_dict(torch.load(self.behav_policy_file_wDemo))
        else:
            behav_pol.load_state_dict(torch.load(self.behav_policy_file))
        behav_pol.eval()
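        # The behavior-cloning policy is only used for off-policy evaluation of the
        # learned policy (e.g. the WIS estimates referenced above), hence eval mode.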

        # Run dBCQ_utils.train_dBCQ
        train_dBCQ(replay_buffer, self.num_actions, self.hidden_size,
                   self.device, params, behav_pol, pol_eval_dataloader,
                   self.context_input)