Code Example #1
def declare_memory(self):
    if self.fix_buffer:
        self.memory = ExperienceReplayMemory(
            self.experience_replay_size
        ) if not self.priority_replay else PrioritizedReplayMemory(
            self.experience_replay_size, self.priority_alpha,
            self.priority_beta_start, self.priority_beta_frames)
    else:
        self.memory = MutExperienceReplayMemory(
            self.experience_replay_size
        ) if not self.priority_replay else MutPrioritizedReplayMemory(
            self.priority_alpha, self.priority_beta_start,
            self.priority_beta_frames)
Code Example #2
File: GANDQN-celluar.py  Project: zhikunch/GAN-DDQN
def declare_memory(self):
    self.memory = ExperienceReplayMemory(self.experience_replay_size)
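The two snippets above only construct the buffer. For orientation, here is a minimal, hypothetical sketch of the uniform ExperienceReplayMemory interface they rely on (push, sample, __len__). It is not the implementation from any of the listed projects; the GAN-DDQN agents below also call helpers such as shuffle_memory, determine_sample and random_sample, which are omitted here.

import random
from collections import deque

class ExperienceReplayMemory:
    """Hypothetical sketch of the uniform replay buffer constructed above."""

    def __init__(self, capacity):
        # oldest transitions are dropped automatically once capacity is reached
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        # transition = (state, action, reward, next_state)
        self.buffer.append(transition)

    def sample(self, batch_size):
        # uniform sampling without replacement
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)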
Code Example #3
File: GANDQN-celluar.py  Project: zhikunch/GAN-DDQN
# Imports added for context; Generator, Discriminator and
# ExperienceReplayMemory are defined elsewhere in the GAN-DDQN project.
import os
import pickle

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad


class WGAN_GP_Agent(object):
    def __init__(self, static_policy, num_input, num_actions):
        super(WGAN_GP_Agent, self).__init__()
        # parameters
        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device('cuda')

        self.gamma = 0.75
        self.lr_G = 1e-4
        self.lr_D = 1e-4
        self.target_net_update_freq = 10
        self.experience_replay_size = 2000
        self.batch_size = 32
        self.update_freq = 200
        self.learn_start = 0
        self.tau = 0.1                        # default is 0.005

        self.static_policy = False  # note: the static_policy constructor argument is ignored here
        self.num_feats = num_input
        self.num_actions = num_actions
        self.z_dim = 32
        self.num_samples = 32

        self.lambda_ = 10
        self.n_critic = 5  # number of critic iterations per generator iteration
        self.n_gen = 1

        self.declare_networks()

        self.G_target_model.load_state_dict(self.G_model.state_dict())
        self.G_optimizer = optim.Adam(self.G_model.parameters(), lr=self.lr_G, betas=(0.5, 0.999))
        self.D_optimizer = optim.Adam(self.D_model.parameters(), lr=self.lr_D, betas=(0.5, 0.999))

        self.G_model = self.G_model.to(self.device)
        self.G_target_model = self.G_target_model.to(self.device)
        self.D_model = self.D_model.to(self.device)

        if self.static_policy:
            self.G_model.eval()
            self.D_model.eval()
        else:
            self.G_model.train()
            self.D_model.train()

        self.update_count = 0
        self.nsteps = 1
        self.nstep_buffer = []

        self.declare_memory()

        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        
        self.one = torch.tensor([1], device=self.device, dtype=torch.float)
        self.mone = self.one * -1

        self.batch_normalization = nn.BatchNorm1d(self.batch_size).to(self.device)

    def declare_networks(self):
        # Output the probability of each sample
        self.G_model = Generator(self.num_feats, self.num_actions, self.num_samples, self.z_dim) # output: batch_size x (num_actions*num_samples)
        self.G_target_model = Generator(self.num_feats, self.num_actions, self.num_samples, self.z_dim)
        self.D_model = Discriminator(self.num_samples, 1) # input: batch_size x num_samples output: batch_size

    def declare_memory(self):
        self.memory = ExperienceReplayMemory(self.experience_replay_size)

    def append_to_replay(self, s, a, r, s_):
        self.memory.push((s, a, r, s_))

    def save_w(self):
        if not os.path.exists('./saved_agents/GANDDQN'):
            os.makedirs('./saved_agents/GANDDQN')
        torch.save(self.G_model.state_dict(), './saved_agents/GANDDQN/G_model_10M_0.01.dump')
        torch.save(self.D_model.state_dict(), './saved_agents/GANDDQN/D_model_10M_0.01.dump')

    def save_replay(self):
        pickle.dump(self.memory, open('./saved_agents/exp_replay_agent.dump', 'wb'))

    def load_replay(self):
        fname = './saved_agents/exp_replay_agent.dump'
        if os.path.isfile(fname):
            self.memory = pickle.load(open(fname, 'rb'))

    def load_w(self):
        fname_G_model = './saved_agents/G_model_0.dump'
        fname_D_model = './saved_agents/D_model_0.dump'

        if os.path.isfile(fname_G_model):
            self.G_model.load_state_dict(torch.load(fname_G_model))
            self.G_target_model.load_state_dict(self.G_model.state_dict())
        
        if os.path.isfile(fname_D_model):
            self.D_model.load_state_dict(torch.load(fname_D_model))

    def plot_loss(self):
        plt.figure(2)
        plt.clf()
        plt.title('Training loss')
        plt.xlabel('Episode')
        plt.ylabel('Loss')
        plt.plot(self.train_hist['G_loss'], 'r')
        plt.plot(self.train_hist['D_loss'], 'b')
        plt.legend(['G_loss', 'D_loss'])
        plt.pause(0.001)

    def prep_minibatch(self, prev_t, t):
        transitions = self.memory.determine_sample(prev_t, t)

        batch_state, batch_action, batch_reward, batch_next_state = zip(*transitions)

        batch_state = torch.tensor(batch_state).to(torch.float).to(self.device)
        batch_action = torch.tensor(batch_action, device=self.device, dtype=torch.long).view(-1, 1)
        batch_reward = torch.tensor(batch_reward, device=self.device, dtype=torch.float).view(-1, 1)
        batch_next_state = torch.tensor(batch_next_state).to(torch.float).to(self.device)

        return batch_state, batch_action, batch_reward, batch_next_state

    def update_target_model(self):
        self.update_count += 1
        self.update_count = self.update_count % self.target_net_update_freq
        if self.update_count == 0:
            for target_param, param in zip(self.G_target_model.parameters(), self.G_model.parameters()):
                target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau)
    
    def get_max_next_state_action(self, next_states, noise):
        samples = self.G_target_model(next_states, noise)
        return samples.mean(2).max(1)[1].view(next_states.size(0), 1, 1).expand(-1, -1, self.num_samples)

    def calc_gradient_penalty(self, real_data, fake_data, noise):
        alpha = torch.rand(self.batch_size, 1)
        alpha = alpha.expand(real_data.size()).to(self.device)
        interpolates = alpha * real_data.data + (1 - alpha) * fake_data.data
        interpolates.requires_grad = True

        disc_interpolates = self.D_model(interpolates, noise)
        gradients = grad(outputs=disc_interpolates, inputs=interpolates, 
                        grad_outputs=torch.ones(disc_interpolates.size()).to(self.device),
                        create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_
        return gradient_penalty

    def adjust_G_lr(self, epoch):
        lr = self.lr_G * (0.1 ** (epoch // 3000))
        for param_group in self.G_optimizer.param_groups:
            param_group['lr'] = lr

    def adjust_D_lr(self, epoch):
        lr = self.lr_D * (0.1 ** (epoch // 3000))
        for param_group in self.D_optimizer.param_groups:
            param_group['lr'] = lr

    def update(self, frame=0):
        if self.static_policy:
            return None
        
        # self.append_to_replay(s, a, r, s_)

        if frame < self.learn_start:
            return None

        if frame % self.update_freq != 0:
            return None

        if self.memory.__len__() != self.experience_replay_size:
            return None
        
        print('Training.........')

        self.adjust_G_lr(frame)
        self.adjust_D_lr(frame)

        self.memory.shuffle_memory()
        len_memory = self.memory.__len__()
        memory_idx = range(len_memory)
        slicing_idx = [i for i in memory_idx[::self.batch_size]]
        slicing_idx.append(len_memory)
        # print(slicing_idx)

        self.G_model.eval()
        for t in range(len_memory // self.batch_size):
            for _ in range(self.n_critic):
                # update Discriminator
                batch_vars = self.prep_minibatch(slicing_idx[t], slicing_idx[t+1])
                batch_state, batch_action, batch_reward, batch_next_state = batch_vars
                G_noise = (torch.rand(self.batch_size, self.num_samples)).to(self.device)

                batch_action = batch_action.unsqueeze(dim=-1).expand(-1, -1, self.num_samples)

                # estimate
                current_q_values_samples = self.G_model(batch_state, G_noise) # batch_size x (num_actions*num_samples)
                current_q_values_samples = current_q_values_samples.gather(1, batch_action).squeeze(1)

                # target
                with torch.no_grad():
                    expected_q_values_samples = torch.zeros((self.batch_size, self.num_samples), device=self.device, dtype=torch.float) 
                    max_next_action = self.get_max_next_state_action(batch_next_state, G_noise)
                    expected_q_values_samples = self.G_model(batch_next_state, G_noise).gather(1, max_next_action).squeeze(1)
                    expected_q_values_samples = batch_reward + self.gamma * expected_q_values_samples

                D_noise = 0. * torch.randn(self.batch_size, self.num_samples).to(self.device)
                # WGAN-GP
                self.D_model.zero_grad()
                D_real = self.D_model(expected_q_values_samples, D_noise)
                D_real_loss = torch.mean(D_real)

                D_fake = self.D_model(current_q_values_samples, D_noise)
                D_fake_loss = torch.mean(D_fake)

                gradient_penalty = self.calc_gradient_penalty(expected_q_values_samples, current_q_values_samples, D_noise)
                
                D_loss = D_fake_loss - D_real_loss + gradient_penalty

                D_loss.backward()
                self.D_optimizer.step()

            # update G network
            self.G_model.train()
            self.G_model.zero_grad()

            # estimate
            current_q_values_samples = self.G_model(batch_state, G_noise) # batch_size x (num_actions*num_samples)
            current_q_values_samples = current_q_values_samples.gather(1, batch_action).squeeze(1)
            
            # WGAN-GP
            D_fake = self.D_model(current_q_values_samples, D_noise)
            G_loss = -torch.mean(D_fake)
            G_loss.backward()
            for param in self.G_model.parameters():
                param.grad.data.clamp_(-1, 1)
            self.G_optimizer.step()

            self.train_hist['G_loss'].append(G_loss.item())
            self.train_hist['D_loss'].append(D_loss.item())

            self.update_target_model()

        print('current q value', current_q_values_samples.mean(1))
        print('expected q value', expected_q_values_samples.mean(1))
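In update_target_model above, the target generator is moved towards the online generator by Polyak averaging with tau = 0.1, in contrast to the hard load_state_dict copy used in Example #5 below. A small self-contained restatement of that rule, assuming only PyTorch, is:

import torch
import torch.nn as nn

def soft_update(target_net: nn.Module, online_net: nn.Module, tau: float) -> None:
    """Polyak averaging: target <- (1 - tau) * target + tau * online."""
    with torch.no_grad():
        for t_param, param in zip(target_net.parameters(), online_net.parameters()):
            t_param.copy_((1.0 - tau) * t_param + tau * param)

# toy usage with throwaway linear layers
online, target = nn.Linear(4, 2), nn.Linear(4, 2)
target.load_state_dict(online.state_dict())
soft_update(target, online, tau=0.1)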
Code Example #4
def declare_memory(self):
    self.memory = ExperienceReplayMemory(
        self.experience_replay_size
    ) if not self.priority_replay else PrioritizedReplayMemory(
        self.experience_replay_size, self.priority_alpha,
        self.priority_beta_start, self.priority_beta_frames)
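This variant (and Example #5 below) builds a PrioritizedReplayMemory(size, alpha, beta_start, beta_frames). The following is a hypothetical sketch of proportional prioritization with the same constructor and with the sample/update_priorities interface consumed in Example #5; real implementations typically use a sum-tree, and presumably return the importance-sampling weights as a tensor on the training device rather than a NumPy array.

import numpy as np

class TinyPrioritizedReplay:
    """Hypothetical sketch of the PrioritizedReplayMemory interface used in
    Examples #4 and #5; not the projects' actual implementation."""

    def __init__(self, capacity, alpha, beta_start, beta_frames):
        self.capacity, self.alpha = capacity, alpha
        self.beta_start, self.beta_frames = beta_start, beta_frames
        self.buffer, self.priorities = [], []

    def beta(self, frame):
        # linearly anneal the importance-sampling exponent towards 1.0
        return min(1.0, self.beta_start + frame * (1.0 - self.beta_start) / self.beta_frames)

    def push(self, transition):
        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
            self.priorities.pop(0)
        self.buffer.append(transition)
        self.priorities.append(max(self.priorities, default=1.0))  # new samples get max priority

    def sample(self, batch_size, frame=0):
        probs = np.asarray(self.priorities) ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta(frame))
        weights /= weights.max()  # normalize so the largest weight is 1
        return [self.buffer[i] for i in indices], indices, weights

    def update_priorities(self, indices, td_errors):
        for i, err in zip(indices, td_errors):
            self.priorities[i] = abs(err) + 1e-5  # keep priorities strictly positive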
Code Example #5
class Model(BaseAgent):
    def __init__(self,
                 static_policy=False,
                 env=None,
                 config=None,
                 log_dir='/tmp/gym'):
        super(Model, self).__init__(config=config, env=env, log_dir=log_dir)
        self.device = config.device

        self.noisy = config.USE_NOISY_NETS
        self.priority_replay = config.USE_PRIORITY_REPLAY

        self.gamma = config.GAMMA
        self.lr = config.LR
        self.target_net_update_freq = config.TARGET_NET_UPDATE_FREQ
        self.experience_replay_size = config.EXP_REPLAY_SIZE
        self.batch_size = config.BATCH_SIZE
        self.learn_start = config.LEARN_START
        self.update_freq = config.UPDATE_FREQ
        self.sigma_init = config.SIGMA_INIT
        self.priority_beta_start = config.PRIORITY_BETA_START
        self.priority_beta_frames = config.PRIORITY_BETA_FRAMES
        self.priority_alpha = config.PRIORITY_ALPHA

        self.static_policy = static_policy
        self.num_feats = env.observation_space.shape
        self.num_actions = env.action_space.n
        self.env = env

        self.declare_networks()

        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        #move to correct device
        self.model = self.model.to(self.device)
        self.target_model.to(self.device)

        if self.static_policy:
            self.model.eval()
            self.target_model.eval()
        else:
            self.model.train()
            self.target_model.train()

        self.update_count = 0

        self.declare_memory()

        self.nsteps = config.N_STEPS
        self.nstep_buffer = []

    def declare_networks(self):
        self.model = DQN(self.num_feats,
                         self.num_actions,
                         noisy=self.noisy,
                         sigma_init=self.sigma_init,
                         body=AtariBody)
        self.target_model = DQN(self.num_feats,
                                self.num_actions,
                                noisy=self.noisy,
                                sigma_init=self.sigma_init,
                                body=AtariBody)

    def declare_memory(self):
        self.memory = ExperienceReplayMemory(
            self.experience_replay_size
        ) if not self.priority_replay else PrioritizedReplayMemory(
            self.experience_replay_size, self.priority_alpha,
            self.priority_beta_start, self.priority_beta_frames)

    def append_to_replay(self, s, a, r, s_):
        self.nstep_buffer.append((s, a, r, s_))

        if (len(self.nstep_buffer) < self.nsteps):
            return

        R = sum([
            self.nstep_buffer[i][2] * (self.gamma**i)
            for i in range(self.nsteps)
        ])
        state, action, _, _ = self.nstep_buffer.pop(0)

        self.memory.push((state, action, R, s_))

    def prep_minibatch(self):
        # random transition batch is taken from experience replay memory
        transitions, indices, weights = self.memory.sample(self.batch_size)

        batch_state, batch_action, batch_reward, batch_next_state = zip(
            *transitions)

        shape = (-1, ) + self.num_feats

        batch_state = torch.tensor(batch_state,
                                   device=self.device,
                                   dtype=torch.float).view(shape)
        batch_action = torch.tensor(batch_action,
                                    device=self.device,
                                    dtype=torch.long).squeeze().view(-1, 1)
        batch_reward = torch.tensor(batch_reward,
                                    device=self.device,
                                    dtype=torch.float).squeeze().view(-1, 1)

        non_final_mask = torch.tensor(tuple(
            map(lambda s: s is not None, batch_next_state)),
                                      device=self.device,
                                      dtype=torch.uint8)
        try:  # sometimes all next states are None (end of episode)
            non_final_next_states = torch.tensor(
                [s for s in batch_next_state if s is not None],
                device=self.device,
                dtype=torch.float).view(shape)
            empty_next_state_values = False
        except:
            non_final_next_states = None
            empty_next_state_values = True

        return batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights

    def compute_loss(self, batch_vars):  #faster
        batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights = batch_vars

        #estimate
        self.model.sample_noise()
        current_q_values = self.model(batch_state).gather(1, batch_action)

        #target
        with torch.no_grad():
            max_next_q_values = torch.zeros(self.batch_size,
                                            device=self.device,
                                            dtype=torch.float).unsqueeze(dim=1)
            if not empty_next_state_values:
                max_next_action = self.get_max_next_state_action(
                    non_final_next_states)
                self.target_model.sample_noise()
                max_next_q_values[non_final_mask] = self.target_model(
                    non_final_next_states).gather(1, max_next_action)
            expected_q_values = batch_reward + (
                (self.gamma**self.nsteps) * max_next_q_values)

        diff = (expected_q_values - current_q_values)
        if self.priority_replay:
            self.memory.update_priorities(
                indices,
                diff.detach().squeeze().abs().cpu().numpy().tolist())
            loss = self.MSE(diff).squeeze() * weights
        else:
            loss = self.MSE(diff)
        loss = loss.mean()

        return loss

    def update(self, s, a, r, s_, frame=0):
        if self.static_policy:
            return None

        self.append_to_replay(s, a, r, s_)

        if frame < self.learn_start or frame % self.update_freq != 0:
            return None

        batch_vars = self.prep_minibatch()

        loss = self.compute_loss(batch_vars)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()

        for group in self.optimizer.param_groups:
            for p in group['params']:
                state = self.optimizer.state[p]
                if ('step' in state and state['step'] >= 1024):
                    state['step'] = 1000

        for param in self.model.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        self.update_target_model()
        self.save_td(loss.item(), frame)
        self.save_sigma_param_magnitudes(frame)

    def get_action(self, s, eps=0.1):  #faster
        with torch.no_grad():
            if np.random.random() >= eps or self.static_policy or self.noisy:
                X = torch.tensor([s], device=self.device, dtype=torch.float)
                self.model.sample_noise()
                a = self.model(X).max(1)[1].view(1, 1)
                return a.item()
            else:
                return np.random.randint(0, self.num_actions)

    def update_target_model(self):
        self.update_count += 1
        self.update_count = self.update_count % self.target_net_update_freq
        if self.update_count == 0:
            self.target_model.load_state_dict(self.model.state_dict())

    def get_max_next_state_action(self, next_states):
        return self.target_model(next_states).max(dim=1)[1].view(-1, 1)

    def finish_nstep(self):
        while len(self.nstep_buffer) > 0:
            R = sum([
                self.nstep_buffer[i][2] * (self.gamma**i)
                for i in range(len(self.nstep_buffer))
            ])
            state, action, _, _ = self.nstep_buffer.pop(0)

            self.memory.push((state, action, R, None))

    def reset_hx(self):
        pass
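In append_to_replay and finish_nstep above, the last N buffered rewards are folded into a single discounted return, R = r_0 + gamma*r_1 + ... + gamma^(N-1)*r_{N-1}, before the transition is pushed. A tiny self-contained restatement of that accumulation:

def n_step_return(rewards, gamma):
    """Discounted sum of the buffered rewards, as computed in
    append_to_replay / finish_nstep above."""
    return sum(r * (gamma ** i) for i, r in enumerate(rewards))

# e.g. nsteps = 3, gamma = 0.99, buffered rewards [1.0, 0.0, 2.0]:
# R = 1.0 + 0.99 * 0.0 + 0.99**2 * 2.0
assert abs(n_step_return([1.0, 0.0, 2.0], 0.99) - (1.0 + 0.99 ** 2 * 2.0)) < 1e-12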
Code Example #6
class WGAN_GP_Agent(object):
    def __init__(self, device, state_size, noise_size, num_actions,
                 num_particles):
        super(WGAN_GP_Agent, self).__init__()
        self.device = torch.device(device)
        self.gamma = 0.8
        self.lr_G = 1e-3
        self.lr_D = 1e-3
        self.lr_mse = 1e-3
        self.target_net_update_freq = 1
        self.experience_replay_size = 50000
        self.batch_size = 32
        self.update_freq = 1
        self.learn_start = 200
        self.tau = 0.01  # default is 0.005
        self.state_size = state_size
        self.noise_size = noise_size
        self.num_actions = num_actions
        self.num_particles = num_particles

        self.lambda_ = 10
        self.n_critic = 5  # number of critic iterations per generator iteration
        self.n_gen = 1

        self.declare_networks()

        self.G_target_model.load_state_dict(self.G_model.state_dict())
        self.G_optimizer = optim.Adam(self.G_model.parameters(),
                                      lr=self.lr_G,
                                      betas=(0.5, 0.999))
        self.D_optimizer = optim.Adam(self.D_model.parameters(),
                                      lr=self.lr_D,
                                      betas=(0.5, 0.999))
        self.G_optimizer_MSE = optim.Adam(self.G_model.parameters(),
                                          lr=self.lr_mse)

        self.G_model = self.G_model.to(self.device)
        self.G_target_model = self.G_target_model.to(self.device)
        self.D_model = self.D_model.to(self.device)

        self.G_model.train()
        self.D_model.train()

        self.update_count = 0
        self.nsteps = 1
        self.nstep_buffer = []

        self.declare_memory()

        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []

        self.one = torch.tensor([1], device=self.device, dtype=torch.float)
        self.mone = self.one * -1

        self.batch_normalization = nn.BatchNorm1d(self.batch_size).to(
            self.device)

    def declare_networks(self):
        # Output the probability of each sample
        self.G_model = Generator(self.device, self.num_actions,
                                 self.num_particles, self.state_size,
                                 self.noise_size)
        self.G_target_model = Generator(self.device, self.num_actions,
                                        self.num_particles, self.state_size,
                                        self.noise_size)
        self.D_model = Discriminator(self.num_particles)

    def declare_memory(self):
        self.memory = ExperienceReplayMemory(self.experience_replay_size)

    def append_to_replay(self, s, a, r, s_):
        self.memory.push((s, a, r, s_))

    def save_w(self):
        # note: the directory created here ('...-v2') differs from the save paths below
        if not os.path.exists('./saved_agents/Dueling_GANDDQN-v2'):
            os.makedirs('./saved_agents/Dueling_GANDDQN-v2')
        torch.save(self.G_model.state_dict(),
                   './saved_agents/Dueling_GANDDQN/Dueling_GANDDQN_G-v2.dump')
        torch.save(self.D_model.state_dict(),
                   './saved_agents/Dueling_GANDDQN/Dueling_GANDDQN_D-v2.dump')

    def save_replay(self):
        pickle.dump(self.memory,
                    open('./saved_agents/exp_replay_agent.dump', 'wb'))

    def load_replay(self):
        fname = './saved_agents/exp_replay_agent.dump'
        if os.path.isfile(fname):
            self.memory = pickle.load(open(fname, 'rb'))

    def load_w(self):
        fname_G_model = './saved_agents/Dueling_GANDDQN/Dueling_GANDDQN_G-v2.dump'
        fname_D_model = './saved_agents/Dueling_GANDDQN/Dueling_GANDDQN_D-v2.dump'

        if os.path.isfile(fname_G_model):
            self.G_model.load_state_dict(torch.load(fname_G_model))
            self.G_target_model.load_state_dict(self.G_model.state_dict())

        if os.path.isfile(fname_D_model):
            self.D_model.load_state_dict(torch.load(fname_D_model))

    def plot_loss(self):
        plt.figure(2)
        plt.clf()
        plt.title('Training loss')
        plt.xlabel('Episode')
        plt.ylabel('Loss')
        plt.plot(self.train_hist['G_loss'], 'r')
        plt.plot(self.train_hist['D_loss'], 'b')
        plt.legend(['G_loss', 'D_loss'])
        plt.pause(0.001)

    def prep_minibatch(self):
        transitions = self.memory.random_sample(self.batch_size)

        batch_state, batch_action, batch_reward, batch_next_state = zip(
            *transitions)

        batch_state = torch.tensor(batch_state).to(torch.float).to(self.device)
        batch_action = torch.tensor(batch_action,
                                    device=self.device,
                                    dtype=torch.long).view(-1, 1)
        batch_reward = torch.tensor(batch_reward,
                                    device=self.device,
                                    dtype=torch.float).view(-1, 1)
        batch_next_state = torch.tensor(batch_next_state).to(torch.float).to(
            self.device)

        return batch_state, batch_action, batch_reward, batch_next_state

    def update_target_model(self):
        self.update_count += 1
        self.update_count = self.update_count % self.target_net_update_freq
        if self.update_count == 0:
            for target_param, param in zip(self.G_target_model.parameters(),
                                           self.G_model.parameters()):
                target_param.data.copy_(target_param.data * (1.0 - self.tau) +
                                        param.data * self.tau)

    def get_max_next_state_action(self, next_states, noise):
        value_particles, advantages = self.G_target_model(next_states, noise)
        value_means = value_particles.mean(1)  # 1 x batch_size
        action_values = value_means.view(self.batch_size, 1).expand(
            self.batch_size, self.num_actions) + advantages
        return action_values.max(1)[1]

    def calc_gradient_penalty(self, real_data, fake_data):
        alpha = torch.rand(self.batch_size, 1)
        alpha = alpha.expand(real_data.size()).to(self.device)
        interpolates = alpha * real_data.data + (1 - alpha) * fake_data.data
        interpolates.requires_grad = True

        disc_interpolates = self.D_model(interpolates)
        gradients = grad(outputs=disc_interpolates,
                         inputs=interpolates,
                         grad_outputs=torch.ones(disc_interpolates.size()).to(
                             self.device),
                         create_graph=True,
                         retain_graph=True,
                         only_inputs=True)[0]

        gradient_penalty = (
            (gradients.norm(2, dim=1) - 1)**2).mean() * self.lambda_
        return gradient_penalty

    def adjust_G_lr(self, epoch):
        lr = self.lr_G * (0.1**(epoch // 3000))
        for param_group in self.G_optimizer.param_groups:
            param_group['lr'] = lr

    def adjust_D_lr(self, epoch):
        lr = self.lr_D * (0.1**(epoch // 3000))
        for param_group in self.D_optimizer.param_groups:
            param_group['lr'] = lr

    def adjust_mse_lr(self, epoch):
        lr = self.lr_mse * (0.1**(epoch // 3000))
        for param_group in self.G_optimizer_MSE.param_groups:
            param_group['lr'] = lr

    def MSE(self, x):
        return torch.mean(0.5 * x.pow(2))

    def update(self, frame=0):
        # self.append_to_replay(s, a, r, s_)
        if frame < self.learn_start:
            return None
        if frame % self.update_freq != 0:
            return None
        # if self.memory.__len__() != self.experience_replay_size:
        #     return None
        print('Training.........')
        self.adjust_G_lr(frame)
        self.adjust_D_lr(frame)
        self.adjust_mse_lr(frame)

        self.G_model.eval()
        for _ in range(self.n_critic):
            # update Discriminator
            batch_vars = self.prep_minibatch()
            batch_state, batch_action, batch_reward, batch_next_state = batch_vars
            G_noise = (torch.rand(self.batch_size,
                                  self.noise_size)).to(self.device)

            # estimate
            # current V particles are used for D networks
            current_value_particles, current_advantages = self.G_model(
                batch_state, G_noise)
            # current_q_values = current_value_particles.mean(1) + current_advantages.gather(1, batch_action.unsqueeze(dim=-1)).squeeze(1)

            # target
            with torch.no_grad():
                expected_value_particles, expected_advantages = self.G_target_model(
                    batch_next_state, G_noise)
                # add reward (here the actual reward term is zeroed out and a constant 1 is used instead)
                expected_value_particles = batch_reward * 0 + 1 + self.gamma * expected_value_particles

                # get expected q value
                # value_means = expected_value_particles.mean(1)  # 1 x batch_size
                # action_values = value_means.view(self.batch_size, 1).expand(self.batch_size, self.num_actions) + expected_advantages
                # expected_q_value = action_values.max(1)[0]

            # WGAN
            self.D_model.zero_grad()
            D_real = self.D_model(expected_value_particles)
            D_real_loss = torch.mean(D_real)

            D_fake = self.D_model(current_value_particles)
            D_fake_loss = torch.mean(D_fake)

            gradient_penalty = self.calc_gradient_penalty(
                expected_value_particles, current_value_particles)

            D_loss = D_fake_loss - D_real_loss + gradient_penalty

            D_loss.backward()
            self.D_optimizer.step()

        # update G network
        self.G_model.train()
        self.G_model.zero_grad()

        batch_vars = self.prep_minibatch()
        batch_state, batch_action, batch_reward, batch_next_state = batch_vars
        # estimate particles
        current_value_particles, current_advantages = self.G_model(
            batch_state, G_noise)
        D_fake = self.D_model(current_value_particles)
        G_loss_GAN = -torch.mean(D_fake)
        G_loss_GAN.backward()
        self.G_optimizer.step()

        # if frame % 3 == 0:
        self.G_model.zero_grad()
        batch_vars = self.prep_minibatch()
        batch_state, batch_action, batch_reward, batch_next_state = batch_vars
        # estimate particles
        current_value_particles, current_advantages = self.G_model(
            batch_state, G_noise)
        # target particles
        expected_value_particles, expected_advantages = self.G_target_model(
            batch_next_state, G_noise)
        # get estimated q value
        current_value_particles_mean = current_value_particles.mean(1)
        current_advantages_specific = current_advantages.gather(
            1, batch_action).squeeze(1)
        current_q_value = current_value_particles_mean + current_advantages_specific
        # get expected q value
        value_means = expected_value_particles.mean(1)  # 1 x batch_size
        action_values = value_means.view(self.batch_size, 1).expand(
            self.batch_size, self.num_actions) + expected_advantages
        max_action_value = action_values.max(1)[0]
        expected_q_value = self.gamma * max_action_value + batch_reward.squeeze(
            1)

        MSE_loss = self.MSE(current_q_value - expected_q_value)
        MSE_loss.backward()
        self.G_optimizer_MSE.step()
        # G_loss = G_loss_GAN + MSE_loss

        # for param in self.G_model.parameters():
        #     param.grad.data.clamp_(-1, 1)

        # self.train_hist['G_loss'].append(G_loss.item())
        # self.train_hist['D_loss'].append(D_loss.item())

        self.update_target_model()

        print('current q value', current_q_value)
        print('expected q value', expected_q_value)
        print('MSE loss', MSE_loss)
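The MSE branch above assembles scalar Q-values from the generator's outputs as Q(s, a) = mean_over_particles(V(s)) + A(s, a). A small self-contained restatement of that dueling aggregation, with toy shapes matching the snippet (PyTorch only), is:

import torch

def q_from_particles(value_particles: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
    """value_particles: (batch, num_particles); advantages: (batch, num_actions).
    Returns Q-values of shape (batch, num_actions)."""
    value_means = value_particles.mean(dim=1, keepdim=True)  # (batch, 1)
    return value_means + advantages                          # broadcast over actions

# toy shapes: batch size 32, 8 value particles, 4 actions
q = q_from_particles(torch.randn(32, 8), torch.randn(32, 4))
assert q.shape == (32, 4)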