Example No. 1
class RLSVIIncrementalTDAgent(Agent):
    def __init__(self,
                 action_set,
                 reward_function,
                 prior_variance,
                 noise_variance,
                 feature_extractor,
                 prior_network,
                 num_ensemble,
                 hidden_dims=[10, 10],
                 learning_rate=5e-4,
                 buffer_size=50000,
                 batch_size=64,
                 num_batches=100,
                 starts_learning=5000,
                 discount=0.99,
                 target_freq=10,
                 verbose=False,
                 print_every=1,
                 test_model_path=None):
        Agent.__init__(self, action_set, reward_function)

        self.prior_variance = prior_variance
        self.noise_variance = noise_variance

        self.feature_extractor = feature_extractor
        self.feature_dim = self.feature_extractor.dimension

        dims = [self.feature_dim] + hidden_dims + [len(self.action_set)]

        self.prior_network = prior_network
        self.num_ensemble = num_ensemble  # number of models in ensemble

        self.index = np.random.randint(self.num_ensemble)

        # build Q network
        # we use a multilayer perceptron

        if test_model_path is None:
            self.test_mode = False
            self.learning_rate = learning_rate
            self.buffer_size = buffer_size
            self.batch_size = batch_size
            self.num_batches = num_batches
            self.starts_learning = starts_learning
            self.discount = discount
            self.timestep = 0

            self.buffer = Buffer(self.buffer_size)
            self.models = []
            for i in range(self.num_ensemble):
                if self.prior_network:
                    '''
                    The second network is a prior network whose weights are fixed;
                    the first network is the learned "difference" network, i.e. its
                    weights are trainable.
                    '''
                    self.models.append(
                        DQNWithPrior(dims, scale=np.sqrt(
                            self.prior_variance)).to(device))
                else:
                    self.models.append(MLP(dims).to(device))
                self.models[i].initialize()
            '''
            The prior network's weights are immutable, so only the difference
            network needs to be kept in sync with a target copy.
            '''
            self.target_nets = []
            for i in range(self.num_ensemble):
                if self.prior_network:
                    self.target_nets.append(
                        DQNWithPrior(dims, scale=np.sqrt(
                            self.prior_variance)).to(device))
                else:
                    self.target_nets.append(MLP(dims).to(device))
                self.target_nets[i].load_state_dict(
                    self.models[i].state_dict())
                self.target_nets[i].eval()

            self.target_freq = target_freq  # target network updated every target_freq episodes
            self.num_episodes = 0

            self.optimizer = []
            for i in range(self.num_ensemble):
                self.optimizer.append(
                    torch.optim.Adam(self.models[i].parameters(),
                                     lr=self.learning_rate))

            # for debugging purposes
            self.verbose = verbose
            self.running_loss = 1.
            self.print_every = print_every

        else:
            self.models = []
            self.test_mode = True
            if self.prior_network:
                self.models.append(
                    DQNWithPrior(dims, scale=np.sqrt(self.prior_variance)))
            else:
                self.models.append(MLP(dims))
            self.models[0].load_state_dict(torch.load(test_model_path))
            self.models[0].eval()
            self.index = 0

    def __str__(self):
        return 'rlsvi_incremental_TD_' + str(self.num_ensemble) + 'models'

    def update_buffer(self, observation_history, action_history):
        """
        update buffer with data collected from current episode
        """
        reward_history = self.get_episode_reward(observation_history,
                                                 action_history)
        self.cummulative_reward += np.sum(reward_history)

        tau = len(action_history)
        feature_history = np.zeros((tau + 1, self.feature_extractor.dimension))
        for t in range(tau + 1):
            feature_history[t] = self.feature_extractor.get_feature(
                observation_history[:t + 1])

        for t in range(tau - 1):
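            # draw an independent Gaussian perturbation per ensemble member;
            # it is added to that member's TD target in learn_from_buffer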
            perturbations = np.random.randn(self.num_ensemble) * np.sqrt(
                self.noise_variance)
            self.buffer.add(
                (feature_history[t], action_history[t], reward_history[t],
                 feature_history[t + 1], perturbations))
        done = observation_history[tau][1]
        if done:
            feat_next = None
        else:
            feat_next = feature_history[tau]
        perturbations = np.random.randn(self.num_ensemble) * np.sqrt(
            self.noise_variance)
        self.buffer.add((feature_history[tau - 1], action_history[tau - 1],
                         reward_history[tau - 1], feat_next, perturbations))

    def learn_from_buffer(self):
        """
        update Q network by applying TD steps
        """
        if self.timestep < self.starts_learning:
            return

        loss_ensemble = 0

        for _ in range(self.num_batches):
            for sample_num in range(self.num_ensemble):
                minibatch = self.buffer.sample(batch_size=self.batch_size)

                feature_batch = torch.zeros(self.batch_size,
                                            self.feature_dim,
                                            device=device)
                action_batch = torch.zeros(self.batch_size,
                                           1,
                                           dtype=torch.long,
                                           device=device)
                reward_batch = torch.zeros(self.batch_size, 1, device=device)
                perturb_batch = torch.zeros(self.batch_size,
                                            self.num_ensemble,
                                            device=device)
                non_terminal_idxs = []
                next_feature_batch = []

                for i, d in enumerate(minibatch):
                    s, a, r, s_next, perturb = d
                    feature_batch[i] = torch.from_numpy(s)
                    action_batch[i] = torch.tensor(a, dtype=torch.long)
                    reward_batch[i] = r
                    perturb_batch[i] = torch.from_numpy(perturb)
                    if s_next is not None:
                        non_terminal_idxs.append(i)
                        next_feature_batch.append(s_next)
                model_estimates = (
                    self.models[sample_num](feature_batch)).gather(
                        1, action_batch).float()

                future_values = torch.zeros(self.batch_size, device=device)
                if non_terminal_idxs != []:
                    next_feature_batch = torch.tensor(next_feature_batch,
                                                      dtype=torch.float,
                                                      device=device)
                    future_values[non_terminal_idxs] = (
                        self.target_nets[sample_num](next_feature_batch)
                    ).max(1)[0].detach()
                future_values = future_values.unsqueeze(1)
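                # randomized TD target: reward + discounted max of the target
                # network, plus this ensemble member's stored Gaussian perturbation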
                target_values = reward_batch + self.discount * future_values \
                                + perturb_batch[:, sample_num].unsqueeze(1)

                assert (model_estimates.shape == target_values.shape)

                loss = nn.functional.mse_loss(model_estimates, target_values)

                self.optimizer[sample_num].zero_grad()
                loss.backward()
                self.optimizer[sample_num].step()
                loss_ensemble += loss.item()
        self.running_loss = 0.99 * self.running_loss + 0.01 * loss_ensemble

        self.num_episodes += 1

        # resample which ensemble member acts greedily during the next episode
        # (approximate posterior sampling for exploration)
        self.index = np.random.randint(self.num_ensemble)

        if self.verbose and (self.num_episodes % self.print_every == 0):
            print("rlsvi ep %d, running loss %.2f, reward %.3f, index %d" %
                  (self.num_episodes, self.running_loss,
                   self.cummulative_reward, self.index))

        if self.num_episodes % self.target_freq == 0:
            for sample_num in range(self.num_ensemble):
                self.target_nets[sample_num].load_state_dict(
                    self.models[sample_num].state_dict())
            # if self.verbose:
            #     print("rlsvi via ensemble sampling ep %d update target network" % self.num_episodes)

    def act(self, observation_history, action_history):
        """ select action according to an epsilon greedy policy with respect to
        the Q network """
        feature = self.feature_extractor.get_feature(observation_history)
        with torch.no_grad():
            if str(device) == "cpu":
                action_values = (self.models[self.index](
                    torch.tensor(feature).float())).numpy()
            else:
                out = (self.models[self.index](
                    torch.tensor(feature).float().to(device)))
                action_values = (out.to("cpu")).numpy()

            action = self._random_argmax(action_values)
        if not self.test_mode:
            self.timestep += 1
        return action

    def save(self, path=None):
        if path is None:
            path = './' + self.__str__() + '.pt'
        torch.save(self.models[self.index].state_dict(), path)
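
DQNWithPrior is referenced above but not defined in this example. As a rough sketch of the randomized-prior construction the comments describe (a frozen, randomly initialized prior network whose scaled output is added to a trainable difference network), something like the following would be compatible with how it is used here; the layer layout, the initialize() method, and the weight initialization are assumptions, not the repository's actual implementation.

import torch
import torch.nn as nn


class DQNWithPrior(nn.Module):
    """Trainable 'difference' network plus a frozen, randomly initialized
    prior network whose scaled output is added on (sketch only)."""

    def __init__(self, dims, scale=1.0):
        super().__init__()

        def make_mlp():
            layers = []
            for d_in, d_out in zip(dims[:-1], dims[1:]):
                layers += [nn.Linear(d_in, d_out), nn.ReLU()]
            return nn.Sequential(*layers[:-1])  # no activation on the output layer

        self.difference = make_mlp()  # trainable part
        self.prior = make_mlp()       # frozen part
        self.scale = scale
        for p in self.prior.parameters():
            p.requires_grad = False

    def initialize(self):
        # draw fresh random weights; the prior stays fixed from here on
        for m in list(self.difference.modules()) + list(self.prior.modules()):
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.1)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.difference(x) + self.scale * self.prior(x)

Because the prior term never receives gradients, only the difference network changes during training, which matches the "enough to keep the difference network" comment above.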
Example No. 2
def main(args):

    torch.manual_seed(args.seed)

    # start simulators
    mp.set_start_method('spawn')

    episode_q = Queue()
    player_qs = []
    simulators = []
    for si in range(args.n_simulators):
        player_qs.append(Queue())
        simulators.append(
            mp.Process(target=simulator,
                       args=(
                           si,
                           player_qs[-1],
                           episode_q,
                           args,
                           False,
                       )))
        simulators[-1].start()

    return_q = Queue()
    valid_q = Queue()
    valid_simulator = mp.Process(target=simulator,
                                 args=(
                                     args.n_simulators,
                                     valid_q,
                                     return_q,
                                     args,
                                     True,
                                 ))
    valid_simulator.start()

    env = gym.make(args.env)
    # env = gym.make('Assault-ram-v0')

    n_frames = args.n_frames

    # initialize replay buffer
    replay_buffer = Buffer(max_items=args.buffer_size,
                           n_frames=n_frames,
                           priority_ratio=args.priority_ratio,
                           store_ratio=args.store_ratio)

    n_iter = args.n_iter
    init_collect = args.init_collect
    n_collect = args.n_collect
    n_value = args.n_value
    n_policy = args.n_policy
    n_hid = args.n_hid

    critic_aware = args.critic_aware

    update_every = args.update_every

    disp_iter = args.disp_iter
    val_iter = args.val_iter
    save_iter = args.save_iter

    max_len = args.max_len
    batch_size = args.batch_size
    max_collected_frames = args.max_collected_frames

    clip_coeff = args.grad_clip
    ent_coeff = args.ent_coeff
    discount_factor = args.discount_factor

    value_loss = -numpy.Inf
    entropy = -numpy.Inf
    valid_ret = -numpy.Inf
    ess = -numpy.Inf
    n_collected_frames = 0

    offset = 0

    return_history = []

    if args.nn == "ff":
        # create a policy
        player = ff.Player(n_in=128 * n_frames, n_hid=args.n_hid,
                           n_out=6).to(args.device)
        if args.player_coeff > 0.:
            player_old = ff.Player(n_in=128 * n_frames,
                                   n_hid=args.n_hid,
                                   n_out=6).to(args.device)
        player_copy = ff.Player(n_in=128 * n_frames, n_hid=args.n_hid,
                                n_out=6).to('cpu')

        # create a value estimator
        value = ff.Value(n_in=128 * n_frames, n_hid=args.n_hid).to(args.device)
        value_old = ff.Value(n_in=128 * n_frames,
                             n_hid=args.n_hid).to(args.device)

        for m in player.parameters():
            m.data.normal_(0., 0.01)
        for m in value.parameters():
            m.data.normal_(0., 0.01)
    elif args.nn == "conv":
        # create a policy
        player = conv.Player(n_frames=n_frames,
                             n_hid=args.n_hid).to(args.device)
        if args.player_coeff > 0.:
            player_old = conv.Player(n_frames=n_frames,
                                     n_hid=args.n_hid).to(args.device)
        player_copy = conv.Player(n_frames=n_frames,
                                  n_hid=args.n_hid).to('cpu')

        # create a value estimator
        value = conv.Value(n_frames, n_hid=args.n_hid).to(args.device)
        value_old = conv.Value(n_frames, n_hid=args.n_hid).to(args.device)
    else:
        raise Exception('Unknown type')

    if args.cont:
        files = glob.glob("{}*th".format(args.saveto))
        iterations = [
            int(".".join(f.split('.')[:-1]).split('_')[-1].strip())
            for f in files
        ]
        last_iter = numpy.max(iterations)
        offset = last_iter - 1
        print('Reloading from {}_{}.th'.format(args.saveto, last_iter))
        checkpoint = torch.load("{}_{}.th".format(args.saveto, last_iter))
        player.load_state_dict(checkpoint['player'])
        value.load_state_dict(checkpoint['value'])
        return_history = checkpoint['return_history']
        n_collected_frames = checkpoint['n_collected_frames']

    copy_params(value, value_old)
    if args.player_coeff > 0.:
        copy_params(player, player_old)

    # push the initial player parameters to the already-running simulators
    player.to('cpu')
    copy_params(player, player_copy)
    for si in range(args.n_simulators):
        player_qs[si].put(
            [copy.deepcopy(p.data.numpy()) for p in player_copy.parameters()] +
            [copy.deepcopy(p.data.numpy()) for p in player_copy.buffers()])
    valid_q.put(
        [copy.deepcopy(p.data.numpy()) for p in player_copy.parameters()] +
        [copy.deepcopy(p.data.numpy()) for p in player_copy.buffers()])
    player.to(args.device)

    if args.device == 'cuda':
        torch.set_num_threads(1)

    initial = True
    pre_filled = 0

    for ni in range(n_iter):
        # re-initialize optimizers
        opt_player = eval(args.optimizer_player)(player.parameters(),
                                                 lr=args.lr,
                                                 weight_decay=args.l2)
        opt_value = eval(args.optimizer_value)(value.parameters(),
                                               lr=args.lr,
                                               weight_decay=args.l2)

        try:
            if not initial:
                lr = args.lr / (1 + (ni - pre_filled + 1) * args.lr_factor)
                ent_coeff = args.ent_coeff / (
                    1 + (ni - pre_filled + 1) * args.ent_factor)
                print('lr', lr, 'ent_coeff', ent_coeff)

                for param_group in opt_player.param_groups:
                    param_group['lr'] = lr
                for param_group in opt_value.param_groups:
                    param_group['lr'] = lr

            if numpy.mod((ni - pre_filled + 1), save_iter) == 0:
                torch.save(
                    {
                        'n_iter': n_iter,
                        'n_collect': n_collect,
                        'n_value': n_value,
                        'n_policy': n_policy,
                        'max_len': max_len,
                        'n_hid': n_hid,
                        'batch_size': batch_size,
                        'player': player.state_dict(),
                        'value': value.state_dict(),
                        'return_history': return_history,
                        'n_collected_frames': n_collected_frames,
                    }, '{}_{}.th'.format(args.saveto,
                                         (ni - pre_filled + 1) + offset + 1))

            player.eval()

            ret_ = -numpy.Inf
            while True:
                try:
                    ret_ = return_q.get_nowait()
                except queue.Empty:
                    break
            if ret_ != -numpy.Inf:
                return_history.append(ret_)
                if valid_ret == -numpy.Inf:
                    valid_ret = ret_
                else:
                    valid_ret = 0.9 * valid_ret + 0.1 * ret_
                print('Valid run', ret_, valid_ret)

            #st = time.time()

            player.to('cpu')
            copy_params(player, player_copy)
            for si in range(args.n_simulators):
                while True:
                    try:
                        # empty the queue, as the new one has arrived
                        player_qs[si].get_nowait()
                    except queue.Empty:
                        break

                player_qs[si].put([
                    copy.deepcopy(p.data.numpy())
                    for p in player_copy.parameters()
                ] + [
                    copy.deepcopy(p.data.numpy())
                    for p in player_copy.buffers()
                ])
            while True:
                try:
                    # empty the queue, as the new one has arrived
                    valid_q.get_nowait()
                except queue.Empty:
                    break
            valid_q.put([
                copy.deepcopy(p.data.numpy())
                for p in player_copy.parameters()
            ] + [copy.deepcopy(p.data.numpy()) for p in player_copy.buffers()])

            player.to(args.device)

            #print('model push took', time.time()-st)

            #st = time.time()

            n_collected_frames_ = 0
            while True:
                try:
                    epi = episode_q.get_nowait()
                    replay_buffer.add(epi[0], epi[1], epi[2], epi[3])
                    n_collected_frames_ = n_collected_frames_ + len(epi[0])
                except queue.Empty:
                    break
                if n_collected_frames_ >= max_collected_frames \
                        and (len(replay_buffer.buffer) + len(replay_buffer.priority_buffer)) > 0:
                    break
            n_collected_frames = n_collected_frames + n_collected_frames_

            if len(replay_buffer.buffer) + len(
                    replay_buffer.priority_buffer) < 1:
                continue

            if len(replay_buffer.buffer) + len(
                    replay_buffer.priority_buffer) < args.initial_buffer:
                if initial:
                    print(
                        'Pre-filling the buffer...',
                        len(replay_buffer.buffer) +
                        len(replay_buffer.priority_buffer))
                    continue
            else:
                if initial:
                    pre_filled = ni
                    initial = False

            #print('collection took', time.time()-st)

            #print('Buffer size', len(replay_buffer.buffer) + len(replay_buffer.priority_buffer))

            # fit a value function
            # TD(0)
            #st = time.time()

            value.train()
            for vi in range(n_value):
                if numpy.mod(vi, update_every) == 0:
                    #print(vi, 'zeroing gradient')
                    opt_player.zero_grad()
                    opt_value.zero_grad()

                batch = replay_buffer.sample(batch_size)

                batch_x = torch.from_numpy(
                    numpy.stack([ex.current_['obs'] for ex in batch
                                 ]).astype('float32')).to(args.device)
                batch_r = torch.from_numpy(
                    numpy.stack([ex.current_['rew'] for ex in batch
                                 ]).astype('float32')).to(args.device)
                batch_xn = torch.from_numpy(
                    numpy.stack([ex.next_['obs'] for ex in batch
                                 ]).astype('float32')).to(args.device)
                pred_y = value(batch_x)
                pred_next = value_old(batch_xn).clone().detach()
                batch_pi = player(batch_x)

                # squared TD(0) error against the slowly updated value_old target:
                # (r + discount * V_old(s') - V(s))^2
                loss_ = ((batch_r + discount_factor * pred_next.squeeze() -
                          pred_y.squeeze())**2)

                batch_a = torch.from_numpy(
                    numpy.stack([ex.current_['act'] for ex in batch
                                 ]).astype('float32')[:, None]).to(args.device)
                batch_q = torch.from_numpy(
                    numpy.stack([ex.current_['prob'] for ex in batch
                                 ]).astype('float32')).to(args.device)
                logp = torch.log(batch_pi.gather(1, batch_a.long()) + 1e-8)

                # (clipped) importance weight:
                # because the policy may have changed since the tuple was collected.
                log_iw = logp.squeeze().clone().detach() - torch.log(
                    batch_q.squeeze() + 1e-8)
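                # effective sample size diagnostic: 1 / sum_i w_i^2 where
                # w_i = exp(log_iw_i), computed stably via logsumexp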
                ess_ = torch.exp(-torch.logsumexp(2 * log_iw, dim=0)).item()
                iw = torch.exp(log_iw.clamp(max=0.))

                if args.iw:
                    loss = iw * loss_
                else:
                    loss = loss_

                loss = loss.mean()

                loss.backward()

                if numpy.mod(vi, update_every) == (update_every - 1):
                    #print(vi, 'making an update')
                    if clip_coeff > 0.:
                        nn.utils.clip_grad_norm_(value.parameters(),
                                                 clip_coeff)
                    opt_value.step()

            copy_params(value, value_old)

            if value_loss < 0.:
                value_loss = loss_.mean().item()
            else:
                value_loss = 0.9 * value_loss + 0.1 * loss_.mean().item()

            if numpy.mod((ni - pre_filled + 1), disp_iter) == 0:
                print('# frames', n_collected_frames, 'value_loss', value_loss,
                      'entropy', -entropy, 'ess', ess)

            #print('value update took', time.time()-st)

            # fit a policy
            #st = time.time()

            value.eval()
            player.train()
            if args.player_coeff > 0.:
                player_old.eval()

            for pi in range(n_policy):
                if numpy.mod(pi, update_every) == 0:
                    opt_player.zero_grad()
                    opt_value.zero_grad()

                #st = time.time()

                batch = replay_buffer.sample(batch_size)

                #print('batch collection took', time.time()-st)

                #st = time.time()

                #batch_x = [ex.current_['obs'] for ex in batch]
                #batch_xn = [ex.next_['obs'] for ex in batch]
                #batch_r = [ex.current_['rew'] for ex in batch]

                #print('list construction took', time.time()-st)

                #st = time.time()

                batch_x = numpy.zeros(
                    tuple([len(batch)] + list(batch[0].current_['obs'].shape)),
                    dtype='float32')
                batch_xn = numpy.zeros(
                    tuple([len(batch)] + list(batch[0].current_['obs'].shape)),
                    dtype='float32')
                batch_r = numpy.zeros((len(batch)), dtype='float32')[:, None]

                for ei, ex in enumerate(batch):
                    batch_x[ei, :] = ex.current_['obs']
                    batch_xn[ei, :] = ex.next_['obs']
                    batch_r[ei, 0] = ex.current_['rew']

                #batch_x = numpy.stack(batch_x).astype('float32')
                #batch_xn = numpy.stack(batch_xn).astype('float32')
                #batch_r = numpy.stack(batch_r).astype('float32')[:,None]

                #print('batch stack for value took', time.time()-st)

                #st = time.time()

                batch_x = torch.from_numpy(batch_x).to(args.device)
                batch_xn = torch.from_numpy(batch_xn).to(args.device)
                batch_r = torch.from_numpy(batch_r).to(args.device)

                #print('batch push for value took', time.time()-st)

                #st = time.time()

                batch_v = value(batch_x).clone().detach()
                batch_vn = value(batch_xn).clone().detach()

                #print('value forward pass took', time.time()-st)

                #st = time.time()

                batch_a = torch.from_numpy(
                    numpy.stack([ex.current_['act'] for ex in batch
                                 ]).astype('float32')[:, None]).to(args.device)
                batch_q = torch.from_numpy(
                    numpy.stack([ex.current_['prob'] for ex in batch
                                 ]).astype('float32')).to(args.device)

                batch_pi = player(batch_x)
                logp = torch.log(batch_pi.gather(1, batch_a.long()) + 1e-8)

                if args.player_coeff > 0.:
                    batch_pi_old = player_old(batch_x).clone().detach()

                #print('policy computation took', time.time()-st)

                #st = time.time()

                # entropy regularization
                ent = -(batch_pi * torch.log(batch_pi + 1e-8)).sum(1)
                if entropy == -numpy.Inf:
                    entropy = ent.mean().item()
                else:
                    entropy = 0.9 * entropy + 0.1 * ent.mean().item()

                #print('entropy computation took', time.time()-st)

                #st = time.time()

                # advantage: r(s,a) + \gamma * V(s') - V(s)
                adv = batch_r + discount_factor * batch_vn - batch_v
                #adv = adv / adv.abs().max().clamp(min=1.)

                loss = -(adv * logp).squeeze()

                loss = loss - ent_coeff * ent

                #print('basic loss computation took', time.time()-st)

                #st = time.time()

                # (clipped) importance weight:
                log_iw = logp.squeeze().clone().detach() - torch.log(batch_q +
                                                                     1e-8)
                iw = torch.exp(log_iw.clamp(max=0.))

                ess_ = torch.exp(-torch.logsumexp(2 * log_iw, dim=0)).item()
                if ess == -numpy.Inf:
                    ess = ess_
                else:
                    ess = 0.9 * ess + 0.1 * ess_

                if args.iw:
                    loss = iw * loss
                else:
                    loss = loss

                #print('importance weighting took', time.time()-st)

                if critic_aware:
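                    # down-weight policy-gradient terms where the critic's own TD
                    # error is large: weight_i = exp(-(r + gamma * V(s') - V(s))^2)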
                    #st = time.time()

                    pred_y = value(batch_x).squeeze()
                    pred_next = value(batch_xn).squeeze()
                    critic_loss_ = -(
                        (batch_r.squeeze() + discount_factor * pred_next -
                         pred_y)**2).clone().detach()

                    critic_loss_ = torch.exp(critic_loss_)
                    loss = loss * critic_loss_

                    #print('critic aware weighting took', time.time()-st)

                loss = loss.mean()

                if args.player_coeff > 0.:
                    #st = time.time()

                    loss_old = -(batch_pi_old *
                                 torch.log(batch_pi + 1e-8)).sum(1).mean()
                    loss = (1. - args.player_coeff
                            ) * loss + args.player_coeff * loss_old

                    #print('player interpolation took', time.time()-st)

                #st = time.time()
                loss.backward()
                if numpy.mod(pi, update_every) == (update_every - 1):
                    if clip_coeff > 0.:
                        nn.utils.clip_grad_norm_(player.parameters(),
                                                 clip_coeff)
                    opt_player.step()
                #print('backward computation and update took', time.time()-st)

            if args.player_coeff > 0.:
                copy_params(player, player_old)

            ##print('policy update took', time.time()-st)

        except KeyboardInterrupt:
            print('Terminating...')
            break

    for si in range(args.n_simulators):
        player_qs[si].put("END")
    valid_q.put("END")

    print('Waiting for the simulators...')

    for si in range(args.n_simulators):
        simulators[si].join()
    valid_simulator.join()

    print('Done')
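
Both training loops above rely on a clipped importance weight between the current policy and the behavior probability stored with each transition, together with an effective-sample-size diagnostic. Below is a minimal, self-contained sketch of those two quantities using the same formulas as the loops above; the function name and the stand-in tensors are hypothetical.

import torch


def clipped_iw_and_ess(logp_current, behavior_prob, eps=1e-8):
    """logp_current: log pi(a|s) under the current policy, shape (B,).
       behavior_prob: probability the behavior policy assigned to a, shape (B,)."""
    log_iw = logp_current.detach() - torch.log(behavior_prob + eps)
    # clamp(max=0) caps every weight at 1, so stale samples can only be
    # down-weighted, never amplified
    iw = torch.exp(log_iw.clamp(max=0.))
    # effective sample size diagnostic: 1 / sum_i w_i^2, via logsumexp for stability
    ess = torch.exp(-torch.logsumexp(2 * log_iw, dim=0)).item()
    return iw, ess


# usage with random stand-in values
logp = torch.log(torch.rand(64).clamp(min=1e-3))
behavior = torch.rand(64).clamp(min=1e-3)
weights, ess = clipped_iw_and_ess(logp, behavior)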
Example No. 3
class DQNAgent(Agent):
    def __init__(self, action_set, reward_function, feature_extractor, 
        hidden_dims=[50, 50], learning_rate=5e-4, buffer_size=50000, 
        batch_size=64, num_batches=100, starts_learning=5000, final_epsilon=0.02, 
        discount=0.99, target_freq=10, verbose=False, print_every=1, 
        test_model_path=None):

        Agent.__init__(self, action_set, reward_function)
        self.feature_extractor = feature_extractor
        self.feature_dim = self.feature_extractor.dimension

        # build Q network
        # we use a multilayer perceptron
        dims = [self.feature_dim] + hidden_dims + [len(self.action_set)]
        self.model = MLP(dims)

        if test_model_path is None:
            self.test_mode = False
            self.learning_rate = learning_rate
            self.buffer_size = buffer_size
            self.batch_size = batch_size
            self.num_batches = num_batches
            self.starts_learning = starts_learning
            self.epsilon = 1.0  # annealed as starts_learning / (starts_learning + t)
            self.final_epsilon = final_epsilon
            self.timestep = 0
            self.discount = discount
            
            self.buffer = Buffer(self.buffer_size)

            self.target_net = MLP(dims)
            self.target_net.load_state_dict(self.model.state_dict())
            self.target_net.eval()

            self.target_freq = target_freq # target nn updated every target_freq episodes
            self.num_episodes = 0

            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
            
            # for debugging purposes
            self.verbose = verbose
            self.running_loss = 1.
            self.print_every = print_every

        else:
            self.test_mode = True
            self.model.load_state_dict(torch.load(test_model_path))
            self.model.eval()
        

    def __str__(self):
        return "dqn"


    def update_buffer(self, observation_history, action_history):
        """
        update buffer with data collected from current episode
        """
        reward_history = self.get_episode_reward(observation_history, action_history)
        self.cummulative_reward += np.sum(reward_history)

        tau = len(action_history)
        feature_history = np.zeros((tau+1, self.feature_extractor.dimension))
        for t in range(tau+1):
            feature_history[t] = self.feature_extractor.get_feature(observation_history[:t+1])

        for t in range(tau-1):
            self.buffer.add((feature_history[t], action_history[t], 
                reward_history[t], feature_history[t+1]))
        done = observation_history[tau][1]
        if done:
            feat_next = None
        else:
            feat_next = feature_history[tau]
        self.buffer.add((feature_history[tau-1], action_history[tau-1], 
            reward_history[tau-1], feat_next))


    def learn_from_buffer(self):
        """
        update Q network by applying TD steps
        """
        if self.timestep < self.starts_learning:
            return

        for _ in range(self.num_batches):
            minibatch = self.buffer.sample(batch_size=self.batch_size)
            
            feature_batch = torch.zeros(self.batch_size, self.feature_dim)
            action_batch = torch.zeros(self.batch_size, 1, dtype=torch.long)
            reward_batch = torch.zeros(self.batch_size, 1)
            non_terminal_idxs = []
            next_feature_batch = []
            for i, d in enumerate(minibatch):
                x, a, r, x_next = d
                feature_batch[i] = torch.from_numpy(x)
                action_batch[i] = torch.tensor(a, dtype=torch.long)
                reward_batch[i] = r
                if x_next is not None:
                    non_terminal_idxs.append(i)
                    next_feature_batch.append(x_next)

            model_estimates = self.model(feature_batch).gather(1, action_batch)
            future_values = torch.zeros(self.batch_size)
            if next_feature_batch != []:
                next_feature_batch = torch.tensor(next_feature_batch, dtype=torch.float)
                future_values[non_terminal_idxs] = self.target_net(next_feature_batch).max(1)[0].detach()
            future_values = future_values.unsqueeze(1)
            target_values = reward_batch + self.discount * future_values

            loss = nn.functional.mse_loss(model_estimates, target_values)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.running_loss = 0.99 * self.running_loss + 0.01 * loss.item()

        self.epsilon = self.starts_learning / (self.starts_learning + self.timestep)
        self.epsilon = max(self.final_epsilon, self.epsilon)

        self.num_episodes += 1

        if self.verbose and (self.num_episodes % self.print_every == 0):
            print("dqn ep %d, running loss %.2f" % (self.num_episodes, self.running_loss))

        if self.num_episodes % self.target_freq == 0:
            self.target_net.load_state_dict(self.model.state_dict())
            if self.verbose:
                print("dqn ep %d update target network" % self.num_episodes)

    def act(self, observation_history, action_history):
        """ select action according to an epsilon greedy policy with respect to 
        the Q network """
        feature = self.feature_extractor.get_feature(observation_history)
        with torch.no_grad():
            action_values = self.model(torch.from_numpy(feature).float()).numpy()
        if not self.test_mode:
            action = self._epsilon_greedy_action(action_values, self.epsilon)
            self.timestep += 1
        else:
            action = self._random_argmax(action_values)
        return action

    def save(self, path=None):
        if path is None:
            path = './dqn.pt'
        torch.save(self.model.state_dict(), path)
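
The _epsilon_greedy_action and _random_argmax helpers used by act() come from the Agent base class and are not shown in these examples. The following is a plausible sketch consistent with how they are called above; the names are kept, but the tie-breaking and sampling details are assumptions.

import numpy as np


def _random_argmax(action_values):
    # greedy action with random tie-breaking
    best = np.flatnonzero(action_values == action_values.max())
    return np.random.choice(best)


def _epsilon_greedy_action(action_values, epsilon):
    # with probability epsilon take a uniformly random action,
    # otherwise act greedily with respect to the Q estimates
    if np.random.rand() < epsilon:
        return np.random.randint(len(action_values))
    return _random_argmax(action_values)


# annealing schedule applied in DQNAgent.learn_from_buffer above:
# epsilon_t = max(final_epsilon, starts_learning / (starts_learning + timestep))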