    num_actions = 25
    if args.dem_context:
        train_buffer_file = '/scratch/ssd001/home/tkillian/ml4h2020_srl/raw_data_buffers/train_buffer'
        validation_buffer_file = '/scratch/ssd001/home/tkillian/ml4h2020_srl/raw_data_buffers/val_buffer'
    else:
        train_buffer_file = '/scratch/ssd001/home/tkillian/ml4h2020_srl/raw_data_buffers/train_noCntxt_buffer'
        validation_buffer_file = '/scratch/ssd001/home/tkillian/ml4h2020_srl/raw_data_buffers/val_noCntxt_buffer'

    storage_dir = '/scratch/ssd001/home/tkillian/ml4h2020_srl/BehavCloning/' + args.storage_folder + '/'

    if not os.path.exists(storage_dir):
        os.mkdir(storage_dir)

    # Initialize and load the training and validation buffers to populate dataloaders
    train_buffer = ReplayBuffer(input_dim, args.batch_size, 200000, device)
    train_buffer.load(train_buffer_file)
    states = train_buffer.state[:train_buffer.crt_size]
    actions = train_buffer.action[:train_buffer.crt_size]
    train_dataset = TensorDataset(
        torch.from_numpy(states).float(),
        torch.from_numpy(actions).long())
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True)

    val_buffer = ReplayBuffer(input_dim, args.batch_size, 50000, device)
    val_buffer.load(validation_buffer_file)
    val_states = val_buffer.state[:val_buffer.crt_size]
    val_actions = val_buffer.action[:val_buffer.crt_size]
    val_dataset = TensorDataset(
        torch.from_numpy(val_states).float(),
        torch.from_numpy(val_actions).long())
    # Mirror the training DataLoader for validation (no shuffling needed).
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False)

class DDQNLearner(DDQN):
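    """Deep Q-learning agent with experience replay, a separate target
    network, and linearly annealed epsilon-greedy exploration.

    The local model is trained on minibatches drawn from the replay buffer,
    its weights are periodically copied into the target network, and the
    weights, exploration factor, and buffer are checkpointed to the given
    save directories.
    """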
    def __init__(self,
                 env,
                 save_dirs,
                 save_freq=10000,
                 gamma=0.99,
                 batch_size=32,
                 learning_rate=0.0001,
                 buffer_size=10000,
                 learn_start=10000,
                 target_network_update_freq=1000,
                 train_freq=4,
                 epsilon_min=0.01,
                 exploration_fraction=0.1,
                 tot_steps=int(1e7)):
        super().__init__(env=env,
                         save_dirs=save_dirs,
                         learning_rate=learning_rate)

        self.gamma = gamma
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.buffer_size = buffer_size
        self.learn_start = learn_start
        self.target_network_update_freq = target_network_update_freq
        self.train_freq = train_freq
        self.epsilon_min = epsilon_min
        self.exploration_fraction = exploration_fraction
        self.tot_steps = tot_steps
        self.epsilon = 1.0
        self.exploration = LinearSchedule(
            schedule_timesteps=int(self.exploration_fraction * self.tot_steps),
            initial_p=self.epsilon,
            final_p=self.epsilon_min)

        self.save_freq = save_freq

        self.replay_buffer = ReplayBuffer(save_dirs=save_dirs,
                                          buffer_size=self.buffer_size,
                                          obs_shape=self.input_shape)

        self.exploration_factor_save_path = os.path.join(
            self.save_path, 'exploration-factor.npz')

        self.target_model_save_path = os.path.join(self.save_path,
                                                   'target-wts.h5')
        self.target_model = NeuralNet(input_shape=self.input_shape,
                                      num_actions=self.num_actions,
                                      learning_rate=learning_rate,
                                      blueprint=self.blueprint).model

        self.show_hyperparams()

        self.update_target()

        self.load()

    def update_exploration(self, t):
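        """Anneal epsilon according to the linear exploration schedule at step t."""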
        self.epsilon = self.exploration.value(t)

    def update_target(self):
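        """Copy the local model's weights into the target network."""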
        self.target_model.set_weights(self.local_model.get_weights())

    def remember(self, obs, action, rew, new_obs, done):
        self.replay_buffer.add(obs, action, rew, new_obs, done)

    def step_update(self, t):
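        """Run a training step every `train_freq` environment steps and sync
        the target network every `target_network_update_freq` steps, once
        `learn_start` steps have elapsed."""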
        hist = None

        if t <= self.learn_start:
            return hist
        if t % self.train_freq == 0:
            hist = self.learn()
        if t % self.target_network_update_freq == 0:
            self.update_target()
        return hist

    def act(self, obs):
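        """Epsilon-greedy action selection: explore with probability epsilon,
        otherwise act greedily on the local model's Q-values (observations are
        scaled from [0, 255] to [0, 1] before the forward pass)."""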
        if np.random.rand() < self.epsilon:
            return self.env.action_space.sample()
        q_vals = self.local_model.predict(
            np.expand_dims(obs, axis=0).astype(float) / 255, batch_size=1)
        return np.argmax(q_vals[0])

    def learn(self):
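        """Sample a minibatch from the replay buffer and fit the local model
        toward the TD targets r + gamma * max_a' Q_target(s', a'), or just r
        for terminal transitions."""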
        if self.replay_buffer.meta_data['fill_size'] < self.batch_size:
            return

        curr_obs, action, reward, next_obs, done = self.replay_buffer.get_minibatch(
            self.batch_size)
        target = self.local_model.predict(curr_obs.astype(float) / 255,
                                          batch_size=self.batch_size)

        done_mask = done.ravel()
        undone_mask = np.invert(done).ravel()

        # Terminal transitions: the target is just the observed reward.
        target[done_mask, action[done_mask].ravel()] = reward[done_mask].ravel()

        # Non-terminal transitions: bootstrap from the target network,
        # r + gamma * max_a' Q_target(s', a').
        Q_target = self.target_model.predict(next_obs.astype(float) / 255,
                                             batch_size=self.batch_size)
        Q_future = np.max(Q_target[undone_mask], axis=1)
        target[undone_mask, action[undone_mask].ravel()] = (
            reward[undone_mask].ravel() + self.gamma * Q_future)

        hist = self.local_model.fit(curr_obs.astype(float) / 255,
                                    target,
                                    batch_size=self.batch_size,
                                    verbose=0).history
        return hist

    def load_mdl(self):
        super().load_mdl()
        if os.path.isfile(self.target_model_save_path):
            self.target_model.load_weights(self.target_model_save_path)
            print('Loaded Target Model...')
        else:
            print('No existing Target Model found...')

    def save_mdl(self):
        self.local_model.save_weights(self.local_model_save_path)
        print('Local Model Saved...')
        self.target_model.save_weights(self.target_model_save_path)
        print('Target Model Saved...')

    def save_exploration(self):
        np.savez(self.exploration_factor_save_path, exploration=self.epsilon)
        print('Exploration Factor Saved...')

    def load_exploration(self):
        if os.path.isfile(self.exploration_factor_save_path):
            with np.load(self.exploration_factor_save_path) as f:
                self.epsilon = float(f['exploration'])
            print('Exploration Factor Loaded...')
        else:
            print('No existing Exploration Factor found...')

    def save(self, t, logger):
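        """Checkpoint the logger state, model weights, exploration factor, and
        replay buffer every `save_freq` steps, once learning has started and at
        least 100 episodes have elapsed, and only when the logger reports a new
        best score."""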
        ep = logger.data['episode']
        if (self.save_freq is not None and t > self.learn_start and ep > 100
                and t % self.save_freq == 0):
            if logger.update_best_score():
                logger.save_state()
                self.save_mdl()
                self.save_exploration()
                self.replay_buffer.save()

    def load(self):
        self.load_mdl()
        self.load_exploration()
        self.replay_buffer.load()

    def show_hyperparams(self):
        print('Discount Factor (gamma): {}'.format(self.gamma))
        print('Batch Size: {}'.format(self.batch_size))
        print('Replay Buffer Size: {}'.format(self.buffer_size))
        print('Training Frequency: {}'.format(self.train_freq))
        print('Target Network Update Frequency: {}'.format(
            self.target_network_update_freq))
        print('Replay Start Size: {}'.format(self.learn_start))
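
# A minimal sketch of the driver loop this learner expects (hypothetical; the
# actual entry point, env wrapper, and Logger API live elsewhere in the repo):
#
#     learner = DDQNLearner(env, save_dirs)
#     obs = env.reset()
#     for t in range(learner.tot_steps):
#         action = learner.act(obs)
#         new_obs, rew, done, _ = env.step(action)   # Gym-style step
#         learner.remember(obs, action, rew, new_obs, done)
#         learner.step_update(t)
#         learner.update_exploration(t)
#         learner.save(t, logger)                    # assumes a Logger instance
#         obs = env.reset() if done else new_obs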