Пример #1
0
class TD3:
    """Twin Delayed DDPG (TD3) agent.

    Owns an actor, two critics, a target copy of each network, one Adam
    optimizer per online network and a replay buffer.  Tensors and networks
    are placed on the module-level ``device``.
    """

    def __init__(self, config: Config):
        """Create the networks, their targets, the optimizers and the buffer.

        Args:
            config: Hyper-parameter object.  The attributes read by this class
                are ``max_buff``, ``state_dim``, ``action_dim``,
                ``max_action``, ``learning_rate``, ``use_cuda``,
                ``batch_size``, ``policy_noise``, ``noise_clip``, ``gamma``,
                ``policy_delay`` and ``polyak``.
        """
        self.config = config
        self.is_training = True
        self.buffer = ReplayBuffer(self.config.max_buff)

        # Each target network starts as an exact copy of its online network.
        self.actor = Actor(self.config.state_dim, self.config.action_dim,
                           self.config.max_action)
        self.actor_target = Actor(self.config.state_dim,
                                  self.config.action_dim,
                                  self.config.max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = Adam(self.actor.parameters(),
                                    lr=self.config.learning_rate)

        self.critic_1 = Critic(self.config.state_dim, self.config.action_dim)
        self.critic_1_target = Critic(self.config.state_dim,
                                      self.config.action_dim)
        self.critic_1_target.load_state_dict(self.critic_1.state_dict())
        self.critic_1_optimizer = Adam(self.critic_1.parameters(),
                                       lr=self.config.learning_rate)

        self.critic_2 = Critic(self.config.state_dim, self.config.action_dim)
        self.critic_2_target = Critic(self.config.state_dim,
                                      self.config.action_dim)
        self.critic_2_target.load_state_dict(self.critic_2.state_dict())
        self.critic_2_optimizer = Adam(self.critic_2.parameters(),
                                       lr=self.config.learning_rate)

        # Kept for callers that may reference it; learning() uses F.mse_loss.
        self.MseLoss = nn.MSELoss()

        if self.config.use_cuda:
            self.cuda()

    def act(self, state):
        """Return the actor's action for one state as a flat NumPy array.

        Args:
            state: Array-like exposing ``reshape``; it is reshaped into a
                batch of one before being passed to the actor.

        Returns:
            The flattened actor output as a NumPy array.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            action = self.actor(state)
        return action.cpu().numpy().flatten()

    def learning(self, fr, t):
        """Run ``t`` update iterations on batches drawn from the buffer.

        Both critics are updated on every iteration; the actor and all target
        networks are updated every ``config.policy_delay`` iterations.

        Args:
            fr: Frame counter supplied by the caller; not used by the update.
            t: Number of update iterations to run.
        """
        cfg = self.config
        critics = (
            (self.critic_1, self.critic_1_optimizer),
            (self.critic_2, self.critic_2_optimizer),
        )

        for i in range(t):
            state, action, reward, next_state, done = self.buffer.sample(
                cfg.batch_size)

            state = torch.tensor(state, dtype=torch.float).to(device)
            next_state = torch.tensor(next_state, dtype=torch.float).to(device)
            action = torch.tensor(action, dtype=torch.float).to(device)
            # Reward and done are reshaped into column vectors.
            reward = torch.tensor(reward, dtype=torch.float).reshape(
                (-1, 1)).to(device)
            done = torch.tensor(done, dtype=torch.float).reshape(
                (-1, 1)).to(device)

            # The bootstrap target is a constant for the critic losses, so it
            # is computed without recording an autograd graph.
            with torch.no_grad():
                # Target-policy smoothing: clipped Gaussian noise on the
                # target action, then clip to the valid action range.
                noise = (torch.randn_like(action) * cfg.policy_noise).clamp(
                    -cfg.noise_clip, cfg.noise_clip)
                next_action = (self.actor_target(next_state) + noise).clamp(
                    -cfg.max_action, cfg.max_action)

                # Use the smaller of the two target critics' estimates.
                target_Q = torch.min(
                    self.critic_1_target(next_state, next_action),
                    self.critic_2_target(next_state, next_action))
                target_Q = reward + (1 - done) * cfg.gamma * target_Q

            # Regress both critics toward the shared target.
            for critic, optimizer in critics:
                critic_loss = F.mse_loss(critic(state, action), target_Q)
                optimizer.zero_grad()
                critic_loss.backward()
                optimizer.step()

            # Delayed policy and target updates.
            if i % cfg.policy_delay == 0:
                actor_loss = -self.critic_1(state, self.actor(state)).mean()

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                self._polyak_update(self.actor, self.actor_target)
                self._polyak_update(self.critic_1, self.critic_1_target)
                self._polyak_update(self.critic_2, self.critic_2_target)

    def _polyak_update(self, source, target):
        """Blend ``source`` parameters into ``target`` using ``config.polyak``."""
        polyak = self.config.polyak
        for param, target_param in zip(source.parameters(),
                                       target.parameters()):
            target_param.data.copy_(
                (polyak * target_param.data) + ((1 - polyak) * param.data))

    def cuda(self):
        """Move every online and target network to the module-level ``device``."""
        for net in (self.actor, self.actor_target,
                    self.critic_1, self.critic_1_target,
                    self.critic_2, self.critic_2_target):
            net.to(device)

    def load_weights(self, model_path):
        """Load actor weights from a bare state dict or a checkpoint dict.

        Args:
            model_path: File readable by ``torch.load``.  If the loaded object
                has an ``'actor'`` key, that entry is used as the state dict.
        """
        # map_location lets a file saved on another device load here.
        policy = torch.load(model_path, map_location=device)
        if 'actor' in policy:
            policy = policy['actor']
        self.actor.load_state_dict(policy)

    def save_model(self, output, name=''):
        """Save the actor's state dict to ``<output>/actor_<name>.pkl``."""
        torch.save(self.actor.state_dict(), '%s/actor_%s.pkl' % (output, name))

    def save_config(self, output):
        """Write the config as ``key = value`` lines to ``<output>/config.txt``."""
        with open(output + '/config.txt', 'w') as f:
            attr_val = get_class_attr_val(self.config)
            for k, v in attr_val.items():
                f.write(str(k) + " = " + str(v) + "\n")

    def save_checkpoint(self, fr, output):
        """Save the frame counter and the online networks' weights.

        The file is written to
        ``<output>/checkpoint_policy/checkpoint_fr_<fr>.tar``; the directory
        is created when missing.  Target networks are not stored.
        """
        checkpath = output + '/checkpoint_policy'
        os.makedirs(checkpath, exist_ok=True)
        torch.save(
            {
                'frames': fr,
                'actor': self.actor.state_dict(),
                'critic_1': self.critic_1.state_dict(),
                'critic_2': self.critic_2.state_dict(),
            }, '%s/checkpoint_fr_%d.tar' % (checkpath, fr))

    def load_checkpoint(self, model_path):
        """Restore a checkpoint written by :meth:`save_checkpoint`.

        Args:
            model_path: Path of the checkpoint file.

        Returns:
            The frame counter stored in the checkpoint.
        """
        checkpoint = torch.load(model_path, map_location=device)
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic_1.load_state_dict(checkpoint['critic_1'])
        self.critic_2.load_state_dict(checkpoint['critic_2'])

        # Checkpoints hold no target weights, so re-synchronise the targets
        # with the restored online networks instead of leaving them stale.
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_1_target.load_state_dict(self.critic_1.state_dict())
        self.critic_2_target.load_state_dict(self.critic_2.state_dict())
        return checkpoint['frames']
Пример #2
0
class DDPG(object):
    """DDPG agent base: actor/critic pairs in double precision, replay memory,
    Ornstein-Uhlenbeck exploration noise and single/multi-GPU helpers.

    ``update_policy`` is a no-op here and nothing in this class assigns
    ``a_t`` after construction, so both are presumably supplied by a subclass
    or by the caller -- TODO confirm.
    """

    def __init__(self, args, nb_states, nb_actions):
        """Build networks, optimizers, replay memory and exploration noise.

        Args:
            args: Options object.  Attributes read: ``seed``, ``gpu_nums``,
                ``hidden1``, ``hidden2``, ``init_w``, ``p_lr``, ``c_lr``,
                ``weight_decay``, ``rmsize``, ``window_length``,
                ``ou_theta``, ``ou_mu``, ``ou_sigma``, ``bsize``,
                ``tau_update``, ``gamma`` and ``epsilon``.
            nb_states: State dimension passed to the networks.
            nb_actions: Action dimension passed to the networks.
        """
        use_cuda = torch.cuda.is_available()

        self.nb_states = nb_states
        self.nb_actions = nb_actions
        # Device bookkeeping must exist before seed() runs, because seed()
        # reads it; [-1] marks "CPU only".
        if use_cuda and args.gpu_nums > 0:
            self.gpu_ids = list(range(args.gpu_nums))
        else:
            self.gpu_ids = [-1]
        self.gpu_used = self.gpu_ids[0] >= 0

        # Seed before any network is created.
        if args.seed > 0:
            self.seed(args.seed)

        net_cfg = {
            'hidden1': args.hidden1,
            'hidden2': args.hidden2,
            'init_w': args.init_w,
        }
        self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg).double()
        self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg).double()
        self.actor_optim = Adam(self.actor.parameters(), lr=args.p_lr,
                                weight_decay=args.weight_decay)

        self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg).double()
        self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg).double()
        self.critic_optim = Adam(self.critic.parameters(), lr=args.c_lr,
                                 weight_decay=args.weight_decay)

        # Make sure each target starts with the same weights as its source.
        hard_update(self.actor_target, self.actor)
        hard_update(self.critic_target, self.critic)

        # Replay buffer and exploration noise.
        self.memory = SequentialMemory(limit=args.rmsize,
                                       window_length=args.window_length)
        self.random_process = OrnsteinUhlenbeckProcess(
            size=self.nb_actions,
            theta=args.ou_theta, mu=args.ou_mu, sigma=args.ou_sigma)

        # Hyper-parameters.
        self.batch_size = args.bsize
        self.tau_update = args.tau_update
        self.gamma = args.gamma

        # Exploration scale starts at 1.0 and drops by depsilon per decay.
        self.depsilon = 1.0 / args.epsilon
        self.epsilon = 1.0
        self.s_t = None  # Most recent state
        self.a_t = None  # Most recent action
        self.is_training = True

        self.continious_action_space = False

    def update_policy(self):
        """No-op placeholder."""
        pass

    def cuda_convert(self):
        """Move the networks to the configured GPU(s).

        One non-negative GPU id: move the networks inside that device's
        context.  Several ids: wrap the networks in ``nn.DataParallel`` and
        move them to the first id.  With the CPU marker ``[-1]`` nothing
        happens.
        """
        if len(self.gpu_ids) == 1:
            if self.gpu_ids[0] >= 0:
                with torch.cuda.device(self.gpu_ids[0]):
                    print('model cuda converted')
                    self.cuda()
        if len(self.gpu_ids) > 1:
            self.data_parallel()
            self.cuda()
            self.to_device()
            print('model cuda converted and paralleled')

    def eval(self):
        """Switch all four networks to evaluation mode."""
        self.actor.eval()
        self.actor_target.eval()
        self.critic.eval()
        self.critic_target.eval()

    def cuda(self):
        """Call ``.cuda()`` on all four networks."""
        self.actor.cuda()
        self.actor_target.cuda()
        self.critic.cuda()
        self.critic_target.cuda()

    def data_parallel(self):
        """Wrap all four networks in ``nn.DataParallel`` over ``gpu_ids``."""
        self.actor = nn.DataParallel(self.actor, device_ids=self.gpu_ids)
        self.actor_target = nn.DataParallel(self.actor_target, device_ids=self.gpu_ids)
        self.critic = nn.DataParallel(self.critic, device_ids=self.gpu_ids)
        self.critic_target = nn.DataParallel(self.critic_target, device_ids=self.gpu_ids)

    def to_device(self):
        """Move all four networks to the first configured GPU."""
        first_gpu = torch.device('cuda:{}'.format(self.gpu_ids[0]))
        self.actor.to(first_gpu)
        self.actor_target.to(first_gpu)
        self.critic.to(first_gpu)
        self.critic_target.to(first_gpu)

    def observe(self, r_t, s_t1, done):
        """While training, store the latest transition and advance ``s_t``.

        Args:
            r_t: Reward for the stored ``(s_t, a_t)`` pair.
            s_t1: Next state; becomes the new ``s_t``.
            done: Episode-termination flag passed to the memory.
        """
        if self.is_training:
            self.memory.append(self.s_t, self.a_t, r_t, done)
            self.s_t = s_t1

    def random_action(self):
        """Return a uniform random action in ``[-1, 1)`` per action dimension.

        ``a_t`` is not recorded here.
        """
        action = np.random.uniform(-1., 1., self.nb_actions)
        return action

    def select_action(self, s_t, decay_epsilon=True):
        """Return the actor's action for ``s_t`` plus exploration noise.

        Noise is scaled by ``max(epsilon, 0)`` and is added only while
        ``is_training`` is truthy; the result is clipped to ``[-1, 1]``.
        ``a_t`` is not recorded here.

        Args:
            s_t: Current state.
            decay_epsilon: When true, reduce ``epsilon`` by ``depsilon``.
        """
        # Proto action straight from the actor.
        action = to_numpy(
            self.actor(to_tensor(np.array([s_t]), gpu_used=self.gpu_used, gpu_0=self.gpu_ids[0])),
            gpu_used=self.gpu_used
        ).squeeze(0)
        action += self.is_training * max(self.epsilon, 0) * self.random_process.sample()
        action = np.clip(action, -1., 1.)

        if decay_epsilon:
            self.epsilon -= self.depsilon

        return action

    def reset(self, s_t):
        """Start a new episode from ``s_t`` and reset the noise process."""
        self.s_t = s_t
        self.random_process.reset_states()

    @staticmethod
    def _weights_path(dir, name):
        """Return the weight file to read for ``name`` under ``output/<dir>``.

        The ``.pkl`` name is preferred; the ``.pt`` name written by
        ``save_model`` is used only when the ``.pkl`` file is absent and the
        ``.pt`` file exists.
        """
        import os  # local import: the file's import section is not visible here

        legacy = 'output/{}/{}.pkl'.format(dir, name)
        current = 'output/{}/{}.pt'.format(dir, name)
        if not os.path.exists(legacy) and os.path.exists(current):
            return current
        return legacy

    def load_weights(self, dir):
        """Load actor and critic weights from ``output/<dir>``.

        Args:
            dir: Sub-directory of ``output``; ``None`` makes this a no-op.
        """
        if dir is None:
            return

        # Map tensors onto the first configured GPU (a single device id, not
        # the whole id list), or onto the CPU.
        if self.gpu_used:
            location = 'cuda:{}'.format(self.gpu_ids[0])
        else:
            location = 'cpu'

        self.actor.load_state_dict(
            torch.load(self._weights_path(dir, 'actor'), map_location=location)
        )
        self.critic.load_state_dict(
            torch.load(self._weights_path(dir, 'critic'), map_location=location)
        )
        print('model weights loaded')

    @staticmethod
    def _unwrap(net):
        """Return the module inside an ``nn.DataParallel`` wrapper, else ``net``."""
        return net.module if isinstance(net, nn.DataParallel) else net

    def save_model(self, output):
        """Save actor and critic weights to ``<output>/actor.pt`` and
        ``<output>/critic.pt``.

        ``nn.DataParallel`` wrappers are unwrapped before saving, and each
        file receives the weights of its own network.
        """
        torch.save(self._unwrap(self.actor).state_dict(),
                   '{}/actor.pt'.format(output))
        torch.save(self._unwrap(self.critic).state_dict(),
                   '{}/critic.pt'.format(output))

    def seed(self, seed):
        """Seed torch, and every CUDA device when a GPU is in use."""
        torch.manual_seed(seed)
        if self.gpu_used:
            torch.cuda.manual_seed_all(seed)
Пример #3
0
class DDPG(object):
    """DDPG agent with an actor/critic pair, target copies of both, a replay
    buffer and Ornstein-Uhlenbeck exploration noise.

    When ``args.discrete`` is truthy, the action methods return the index of
    the largest action component instead of the raw action vector.
    """

    def __init__(self, nb_states, nb_actions, args):
        """Build networks, optimizers, replay buffer and noise process.

        Args:
            nb_states: State dimension passed to the networks.
            nb_actions: Action dimension passed to the networks.
            args: Options object.  Attributes read: ``discrete``,
                ``hidden1``, ``hidden2``, ``actor_lr``, ``critic_lr``,
                ``memory_size``, ``batch_size``, ``tau``, ``discount`` and
                ``cuda``.
        """
        self.nb_states = nb_states
        self.nb_actions = nb_actions
        self.discrete = args.discrete

        net_config = {
            'hidden1': args.hidden1,
            'hidden2': args.hidden2,
        }

        # Actor and critic, each with a target copy.
        self.actor = Actor(self.nb_states, self.nb_actions, **net_config)
        self.actor_target = Actor(self.nb_states, self.nb_actions, **net_config)
        self.actor_optim = Adam(self.actor.parameters(), lr=args.actor_lr)

        self.critic = Critic(self.nb_states, self.nb_actions, **net_config)
        self.critic_target = Critic(self.nb_states, self.nb_actions, **net_config)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.critic_lr)

        hard_update(self.critic_target, self.critic)
        hard_update(self.actor_target, self.actor)

        # Replay buffer and exploration noise.
        self.memory = ReplayBuffer(args.memory_size)
        self.noise = OrnsteinUhlenbeckProcess(
            mu=np.zeros(nb_actions), sigma=float(0.2) * np.ones(nb_actions))

        self.last_state = None
        self.last_action = None

        # Hyper-parameters.
        self.batch_size = args.batch_size
        self.tau = args.tau
        self.discount = args.discount

        self.use_cuda = args.cuda
        if self.use_cuda:
            self.cuda()

    def cuda(self):
        """Move all four networks to the module-level ``device``."""
        self.actor.to(device)
        self.actor_target.to(device)
        self.critic.to(device)
        self.critic_target.to(device)

    def eval(self):
        """Switch all four networks to evaluation mode."""
        self.actor.eval()
        self.actor_target.eval()
        self.critic.eval()
        self.critic_target.eval()

    def train(self):
        """Switch all four networks to training mode."""
        self.actor.train()
        self.actor_target.train()
        self.critic.train()
        self.critic_target.train()

    def reset(self, obs):
        """Start a new episode from ``obs`` and reset the noise process."""
        self.last_state = obs
        self.noise.reset()

    def observe(self, reward, state, done):
        """Store ``[last_state, last_action, reward, state, done]`` in memory
        and make ``state`` the new ``last_state``."""
        self.memory.append([self.last_state, self.last_action, reward, state, done])
        self.last_state = state

    def random_action(self):
        """Sample a uniform action in ``[-1, 1)`` and remember it.

        Returns:
            The action vector, or its argmax when ``self.discrete`` is truthy.
        """
        action = np.random.uniform(-1., 1., self.nb_actions)
        self.last_action = action
        return action.argmax() if self.discrete else action

    def select_action(self, state, apply_noise=False):
        """Run the actor on ``state`` and remember the clipped action.

        The networks are put in evaluation mode for the forward pass and
        returned to training mode afterwards.

        Args:
            state: Current observation.
            apply_noise: When true, add a sample from the noise process
                before clipping to ``[-1, 1]``.

        Returns:
            The action vector, or its argmax when ``self.discrete`` is truthy.
        """
        self.eval()
        action = to_numpy(self.actor(to_tensor(np.array([state]), device=device))).squeeze(0)
        self.train()
        if apply_noise:
            action = action + self.noise.sample()
        action = np.clip(action, -1., 1.)
        self.last_action = action
        return action.argmax() if self.discrete else action

    def update_policy(self):
        """Run one critic update, one actor update and a soft target update.

        Returns:
            A tuple ``(-policy_loss, critic_loss, mean_q)`` converted with
            ``to_numpy``.
        """
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = \
            self.memory.sample_batch(self.batch_size)
        state = to_tensor(np.array(state_batch), device=device)
        action = to_tensor(np.array(action_batch), device=device)
        next_state = to_tensor(np.array(next_state_batch), device=device)

        # The target networks only supply a constant bootstrap value.
        with torch.no_grad():
            next_q_value = self.critic_target([next_state, self.actor_target(next_state)])

        # `np.float` no longer exists in NumPy; it was an alias of the builtin
        # `float`, so the conversion below yields the same dtype.
        not_done = to_tensor(1 - terminal_batch.astype(float), device=device)
        target_q_value = to_tensor(reward_batch, device=device) \
            + self.discount * not_done * next_q_value

        # Critic update.
        self.critic.zero_grad()
        with torch.set_grad_enabled(True):
            q_values = self.critic([state, action])
            critic_loss = criterion(q_values, target_q_value.detach())
            critic_loss.backward()
            self.critic_optim.step()

        # Actor update: maximise the critic's value of the actor's actions.
        self.actor.zero_grad()
        with torch.set_grad_enabled(True):
            policy_loss = -self.critic([state.detach(), self.actor(state)]).mean()
            policy_loss.backward()
            self.actor_optim.step()

        # Target update.
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)

        return to_numpy(-policy_loss), to_numpy(critic_loss), to_numpy(q_values.mean())

    def save_model(self, output, num=1):
        """Save actor and critic weights as ``actor<num>.pkl`` and
        ``critic<num>.pkl`` under ``output``.

        With CUDA enabled the networks are moved to the CPU for saving and
        are moved back afterwards, even if saving raises.
        """
        if self.use_cuda:
            self.actor.to(torch.device("cpu"))
            self.critic.to(torch.device("cpu"))
        try:
            torch.save(self.actor.state_dict(), '{}/actor{}.pkl'.format(output, num))
            torch.save(self.critic.state_dict(), '{}/critic{}.pkl'.format(output, num))
        finally:
            if self.use_cuda:
                self.actor.to(device)
                self.critic.to(device)

    def load_model(self, output, num=1):
        """Load weights saved by :meth:`save_model` into the online networks
        and their targets."""
        # Each file is read once; load_state_dict copies the values into the
        # existing parameters, so loading onto the CPU first is safe.
        actor_state = torch.load('{}/actor{}.pkl'.format(output, num), map_location='cpu')
        critic_state = torch.load('{}/critic{}.pkl'.format(output, num), map_location='cpu')
        self.actor.load_state_dict(actor_state)
        self.actor_target.load_state_dict(actor_state)
        self.critic.load_state_dict(critic_state)
        self.critic_target.load_state_dict(critic_state)
        if self.use_cuda:
            self.cuda()
class DDPG:
    """One DDPG agent of a multi-agent setup.

    The actor sees this agent's own state; the critic input size is
    ``(state_size + action_size) * num_agents``.
    """

    def __init__(self,
                 state_size,
                 action_size,
                 tau,
                 lr_actor,
                 lr_critic,
                 num_agents,
                 agent_idx,
                 seed,
                 device,
                 gamma,
                 tensorboard_writer=None):
        """Store the hyper-parameters and build networks, optimizers and noise.

        Args:
            state_size: Per-agent state dimension.
            action_size: Per-agent action dimension.
            tau: Interpolation factor handed to ``soft_update``.
            lr_actor: Learning rate of the actor optimizer.
            lr_critic: Learning rate of the critic optimizer.
            num_agents: Number of agents sharing the critic input.
            agent_idx: Index of this agent among all agents.
            seed: Seed for ``random`` and for the networks and noise.
            device: Torch device used for tensors created by this agent.
            gamma: Discount factor.
            tensorboard_writer: Optional writer used to log losses.
        """
        self.state_size, self.action_size = state_size, action_size
        self.tau = tau
        self.lr_actor, self.lr_critic = lr_actor, lr_critic
        self.num_agents, self.agent_idx = num_agents, agent_idx
        self.seed = seed
        self.device = device
        self.gamma = gamma
        random.seed(seed)
        self.tensorboard_writer = tensorboard_writer

        # Local (trained) and target actor.
        self.actor_local = Actor(state_size, action_size, seed)
        self.actor_target = Actor(state_size, action_size, seed)

        # The critic consumes the states and actions of every agent.
        critic_input_size = num_agents * (state_size + action_size)
        self.critic_local = Critic(critic_input_size, seed)
        self.critic_target = Critic(critic_input_size, seed)

        # NOTE(review): arguments are (local, target) here while soft_update
        # below is called as (target, local, tau) -- confirm hard_update's
        # expected order.
        hard_update(self.actor_local, self.actor_target)
        hard_update(self.critic_local, self.critic_target)

        self.actor_optim = torch.optim.Adam(self.actor_local.parameters(), lr=lr_actor)
        self.critic_optim = torch.optim.Adam(self.critic_local.parameters(), lr=lr_critic)

        self.noise = OUNoise(action_size, seed)

        self.iteration = 0

    def to(self, device):
        """Move all four networks to ``device`` and return this agent."""
        for network in (self.actor_local, self.actor_target,
                        self.critic_local, self.critic_target):
            network.to(device)
        return self

    def act(self, state, noise_scale, use_noise=True):
        """Compute an action for ``state``, clipped to ``[-1, 1]``.

        Args:
            state: NumPy array holding the observation.
            noise_scale: Multiplier applied to the sampled noise.
            use_noise: When true, add scaled noise before clipping.
        """
        observation = torch.from_numpy(state).float().to(self.device)
        policy = self.actor_local
        # Evaluate without gradients, then return the actor to training mode.
        policy.eval()
        with torch.no_grad():
            action = policy(observation).cpu().data.numpy()
        policy.train()
        if use_noise:
            action += self.noise.sample() * noise_scale
        return np.clip(action, -1, 1)

    def learn(self, experiences, all_curr_pred_actions, all_next_pred_actions):
        """Update critic, actor and target networks from one batch.

        Args:
            experiences: Tuple ``(states, actions, rewards, next_states,
                dones)`` of tensors.
            all_curr_pred_actions: Per-agent predicted actions for the current
                states; all but this agent's entry are detached.
            all_next_pred_actions: Per-agent predicted actions for the next
                states.
        """
        own_index = torch.tensor(self.agent_idx).to(self.device)

        states, actions, rewards, next_states, dones = experiences

        # Keep only this agent's reward and done columns.
        rewards = rewards.index_select(1, own_index)
        dones = dones.index_select(1, own_index)

        # ---------------------------- critic update ---------------------------- #
        n_samples = next_states.shape[0]

        joint_next_actions = torch.cat(all_next_pred_actions, dim=1).to(self.device)
        next_states = next_states.reshape(n_samples, -1)

        with torch.no_grad():
            next_q = self.critic_target(next_states, joint_next_actions)

        # Bootstrap target for the current states.
        q_target = rewards + (self.gamma * next_q * (1 - dones))

        states = states.reshape(n_samples, -1)
        actions = actions.reshape(n_samples, -1)

        q_estimate = self.critic_local(states, actions)
        critic_loss = F.mse_loss(q_estimate, q_target.detach())

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        # ----------------------------- actor update ---------------------------- #
        self.actor_optim.zero_grad()

        # Only this agent's predicted action keeps its gradient.
        joint_parts = []
        for idx, predicted in enumerate(all_curr_pred_actions):
            if idx == self.agent_idx:
                joint_parts.append(predicted)
            else:
                joint_parts.append(predicted.detach())
        joint_actions = torch.cat(joint_parts, dim=1).to(self.device)

        actor_loss = -self.critic_local(states, joint_actions).mean()
        actor_loss.backward()
        self.actor_optim.step()

        actor_loss_value = actor_loss.cpu().detach().item()
        critic_loss_value = critic_loss.cpu().detach().item()

        writer = self.tensorboard_writer
        if writer is not None:
            tag_prefix = "agent{}".format(self.agent_idx)
            writer.add_scalar(tag_prefix + "/actor_loss", actor_loss_value, self.iteration)
            writer.add_scalar(tag_prefix + "/critic_loss", critic_loss_value, self.iteration)
            writer.file_writer.flush()

        self.iteration += 1

        # ------------------------ target network update ------------------------ #
        soft_update(self.critic_target, self.critic_local, self.tau)
        soft_update(self.actor_target, self.actor_local, self.tau)

    def reset(self):
        """Reset the exploration noise process."""
        self.noise.reset()
Пример #5
0
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the generator, then the critic, each moved to the selected device.
gen = Generator(
    im_ch=IMAGE_CHANNELS,
    latent_dim=NOISE_DIM,
    hidden_dim=HIDDEN_DIM_GEN,
    use_batchnorm=USE_BATCHNORM,
    upsample_mode=UPSAMPLE_MODE,
).to(device)
critic = Critic(
    im_ch=IMAGE_CHANNELS,
    hidden_dim=HIDDEN_DIM_DISC,
    use_batchnorm=USE_BATCHNORM,
    spectral_norm=SPECTRAL_NORM,
).to(device)

# Apply the weight initializer: critic first, then generator.
critic.apply(init_weights)
gen.apply(init_weights)

# Loss and optimizers; both optimizers share the same Adam settings.
criterion = nn.BCEWithLogitsLoss()
_adam_settings = {"lr": LR, "betas": (beta1, beta2)}
opt_gen = torch.optim.Adam(gen.parameters(), **_adam_settings)
opt_disc = torch.optim.Adam(critic.parameters(), **_adam_settings)

# TensorBoard writer; the log directory is named after the first six
# characters of the current git commit hash.
repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha[:6]
logdir = f"/home/bishwarup/GAN_experiments/dcgan/{sha}"
writer = SummaryWriter(log_dir=logdir)
Пример #6
0
class DDPG(object):
    """DDPG agent whose settings all come from the module-level ``config``.

    Holds an actor/critic pair with target copies, their Adam optimizers, a
    replay memory, an Ornstein-Uhlenbeck noise process and running lists of
    the policy and critic losses.
    """

    def __init__(self):
        """Build networks, optimizers, memory and noise from ``config``."""
        # Plain local name: a double-underscore local is name-mangled inside
        # a class body.
        seed_value = config.get(MODEL_SEED)
        self.policy_loss = []
        self.critic_loss = []
        if seed_value > 0:
            self.seed(seed_value)

        self.nb_states = config.get(MODEL_STATE_COUNT)
        self.nb_actions = config.get(MODEL_ACTION_COUNT)

        # Actor and critic networks, each with a target copy.
        actor_net_cfg = {
            'hidden1': config.get(MODEL_ACTOR_HIDDEN1),
            'hidden2': config.get(MODEL_ACTOR_HIDDEN2),
            'init_w': config.get(MODEL_INIT_WEIGHT)
        }
        critic_net_cfg = {
            'hidden1': config.get(MODEL_CRITIC_HIDDEN1),
            'hidden2': config.get(MODEL_CRITIC_HIDDEN2),
            'init_w': config.get(MODEL_INIT_WEIGHT)
        }
        self.actor = Actor(self.nb_states, self.nb_actions, **actor_net_cfg)
        self.actor_target = Actor(self.nb_states, self.nb_actions,
                                  **actor_net_cfg)
        self.actor_optim = Adam(
            self.actor.parameters(),
            lr=config.get(MODEL_ACTOR_LR),
            weight_decay=config.get(MODEL_ACTOR_WEIGHT_DECAY))

        self.critic = Critic(self.nb_states, self.nb_actions, **critic_net_cfg)
        self.critic_target = Critic(self.nb_states, self.nb_actions,
                                    **critic_net_cfg)
        self.critic_optim = Adam(
            self.critic.parameters(),
            lr=config.get(MODEL_CRITIC_LR),
            weight_decay=config.get(MODEL_CRITIC_WEIGHT_DECAY))

        hard_update(self.actor_target, self.actor)
        hard_update(self.critic_target, self.critic)

        # Replay buffer.
        self.memory = Memory()

        self.random_process = OrnsteinUhlenbeckProcess(
            size=self.nb_actions,
            theta=config.get(RANDOM_THETA),
            mu=config.get(RANDOM_MU),
            sigma=config.get(RANDOM_SIGMA))

        # Hyper-parameters.
        self.batch_size = config.get(MODEL_BATCH_SIZE)
        self.tau = config.get(MODEL_TARGET_TAU)
        self.discount = config.get(MODEL_DISCOUNT)
        self.depsilon = 1.0 / config.get(MODEL_EPSILON)

        self.model_path = config.get(MODEL_SAVE_PATH)

        # Exploration scale; reduced by depsilon in clean().
        self.epsilon = 1.0

        self.device_init()

    def update_policy(self, memory):
        """Run one critic update, one actor update and a soft target update.

        The scalar losses are appended to ``critic_loss`` and ``policy_loss``.

        Args:
            memory: Object exposing ``sample_and_split(batch_size)`` that
                returns state, action, reward, next-state and terminal
                batches.
        """
        state_batch, action_batch, reward_batch, \
            next_state_batch, terminal_batch = memory.sample_and_split(self.batch_size)

        # Target Q values come from the target networks, without gradients.
        with torch.no_grad():
            next_states = to_tensor(next_state_batch)
            next_q_values = self.critic_target(
                [next_states, self.actor_target(next_states)])

        # NOTE(review): the bootstrap term is multiplied by terminal_batch
        # as-is, which is only right if sample_and_split already returns a
        # "not done" mask (0 for terminal steps) -- confirm against Memory.
        # `np.float` no longer exists in NumPy; the builtin `float` is the
        # type it aliased.
        target_q_batch = to_tensor(reward_batch) + \
            self.discount * to_tensor(terminal_batch.astype(float)) * next_q_values

        # Critic update.
        self.critic.zero_grad()

        q_batch = self.critic(
            [to_tensor(state_batch),
             to_tensor(action_batch)])
        value_loss = F.mse_loss(q_batch, target_q_batch)

        value_loss.backward()
        self.critic_optim.step()
        # .item() extracts the scalar; indexing a 0-dim tensor is an error.
        self.critic_loss.append(value_loss.item())

        # Actor update: maximise the critic's value of the actor's actions.
        self.actor.zero_grad()

        policy_loss = -self.critic(
            [to_tensor(state_batch),
             self.actor(to_tensor(state_batch))])

        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.actor_optim.step()
        self.policy_loss.append(policy_loss.item())

        # Target update.
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)

    def get_loss(self):
        """Return the recorded ``(policy_loss, critic_loss)`` lists."""
        return self.policy_loss, self.critic_loss

    def eval(self):
        """Switch all four networks to evaluation mode."""
        self.actor.eval()
        self.actor_target.eval()
        self.critic.eval()
        self.critic_target.eval()

    def device_init(self):
        """Move all four networks to the module-level ``device``."""
        self.actor.to(device)
        self.actor_target.to(device)
        self.critic.to(device)
        self.critic_target.to(device)

    def random_action(self):
        """Return a uniform random action in ``[-1, 1)`` per action dimension."""
        action = np.random.uniform(-1., 1., self.nb_actions)
        return action

    def select_action(self, s_t):
        """Return the actor's action for ``s_t`` plus noise scaled by
        ``max(epsilon, 0)``, clipped to ``[-1, 1]``."""
        action = to_numpy(self.actor(to_tensor(np.array([s_t])))).squeeze(0)

        action += max(self.epsilon, 0) * self.random_process.sample()
        action = np.clip(action, -1., 1.)

        return action

    def clean(self, decay_epsilon):
        """Reduce ``epsilon`` by ``depsilon`` when ``decay_epsilon`` is truthy."""
        if decay_epsilon:
            self.epsilon -= self.depsilon

    def reset(self):
        """Reset the exploration noise process."""
        self.random_process.reset_states()

    def load_weights(self):
        """Load ``actor.pkl`` and ``critic.pkl`` from ``model_path`` when present.

        A missing directory or missing file is skipped silently.
        """
        if not os.path.exists(self.model_path):
            return

        # Keep the joined path itself, not the result of an existence check.
        actor_path = os.path.join(self.model_path, 'actor.pkl')
        if os.path.exists(actor_path):
            # load_state_dict copies values into the existing parameters, so
            # reading onto the CPU works wherever the networks live.
            self.actor.load_state_dict(
                torch.load(actor_path, map_location='cpu'))

        critic_path = os.path.join(self.model_path, 'critic.pkl')
        if os.path.exists(critic_path):
            self.critic.load_state_dict(
                torch.load(critic_path, map_location='cpu'))

    def save_model(self):
        """Save actor and critic weights to ``actor.pkl`` and ``critic.pkl``
        under ``model_path``, creating the directory when missing."""
        os.makedirs(self.model_path, exist_ok=True)

        actor_path = os.path.join(self.model_path, 'actor.pkl')
        torch.save(self.actor.state_dict(), actor_path)

        critic_path = os.path.join(self.model_path, 'critic.pkl')
        torch.save(self.critic.state_dict(), critic_path)

    def get_model(self):
        """Return ``(actor_state_dict, critic_state_dict)``."""
        return self.actor.state_dict(), self.critic.state_dict()

    def load_state_dict(self, actor_state, critic_state):
        """Load the given state dicts into the actor and the critic."""
        self.actor.load_state_dict(actor_state)
        self.critic.load_state_dict(critic_state)

    def seed(self, s):
        """Seed torch, and CUDA when the module-level ``USE_CUDA`` is truthy."""
        torch.manual_seed(s)
        if USE_CUDA:
            torch.cuda.manual_seed(s)