Example #1
class Agent():
    def __init__(self, state_size, action_size, fc_units, seed, lr,
                 buffer_size, batch_size, update_every):

        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)
        self.update_every = update_every

        self.qnetwork_local = QNetwork(state_size, action_size, seed,
                                       fc_units).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size, seed,
                                        fc_units).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=lr)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=100,
                                                   gamma=0.5)
        self.memory = ReplayBuffer(action_size, buffer_size, batch_size, seed)
        self.t_step = 0

    def step(self, state, action, reward, next_state, done, gamma, tau):
        self.memory.add(state, action, reward, next_state, done)
        self.t_step = (self.t_step + 1) % self.update_every

        if (self.t_step == 0) and (len(self.memory) > self.memory.batch_size):
            experiences = self.memory.sample()
            self.learn(experiences, gamma, tau)

    def act(self, state, eps):
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        if random.random() > eps:
            action = np.argmax(action_values.cpu().data.numpy())
        else:
            action = random.choice(np.arange(self.action_size))
        return action

    def learn(self, experiences, gamma, tau):
        states, actions, rewards, next_states, dones = experiences
        # Max predicted Q values for the next states, taken from the target network
        Q_targets_next = self.qnetwork_target(next_states).detach().max(
            1)[0].unsqueeze(1)
        # Q targets for the current states (terminal transitions contribute only the reward)
        Q_targets = rewards + gamma * Q_targets_next * (1 - dones)
        # Expected Q values from the local network for the actions actually taken
        Q_expected = self.qnetwork_local(states).gather(1, actions)
        loss = F.mse_loss(Q_expected, Q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Soft-update the target network towards the local network
        self.soft_update(self.qnetwork_local, self.qnetwork_target, tau)

    def soft_update(self, local_model, target_model, tau):
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
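
A minimal usage sketch for the agent above, assuming a Gym-style environment (`env`) and illustrative hyperparameters; the epsilon schedule and episode count are assumptions, not part of the original example.

# Hypothetical training loop; env, hyperparameters, and the epsilon schedule are assumptions.
agent = Agent(state_size=8, action_size=4, fc_units=64, seed=0, lr=5e-4,
              buffer_size=int(1e5), batch_size=64, update_every=4)

eps, eps_min, eps_decay = 1.0, 0.01, 0.995
for episode in range(500):
    state = env.reset()                                  # assumed Gym-style API
    done = False
    while not done:
        action = agent.act(state, eps)
        next_state, reward, done, _ = env.step(action)
        # gamma and tau are passed per step in this example's interface
        agent.step(state, action, reward, next_state, done, gamma=0.99, tau=1e-3)
        state = next_state
    eps = max(eps_min, eps * eps_decay)                  # decay exploration once per episode
    agent.scheduler.step()                               # advance the StepLR schedule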
Example #2
class Agent():
    def __init__(self, state_size, action_size, seed, training, pixels, lr=LR):
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)
        self.t_step = 0
        self.pixels = pixels
        if pixels is False:
            from q_network import QNetwork
        else:
            from q_network_cnn import QNetwork
            print('loaded cnn network')
            self.loader = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        self.QN_local = QNetwork(state_size, action_size, seed,
                                 training).to(device)
        self.QN_target = QNetwork(state_size, action_size, seed,
                                  training).to(device)
        self.optimizer = optim.Adam(self.QN_local.parameters(), lr=lr)
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed,
                                   device)  # TODO

    def act(self, state, eps):
        #         if self.pixels is True:

        #state = Variable(torch.from_numpy(state).float().to(device).view(state.shape[0],3,32,32))
        if not self.pixels:
            state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.QN_local.eval()

        with torch.no_grad():
            action_values = self.QN_local(state)
        self.QN_local.train()
        if random.random() > eps:
            return int(np.argmax(action_values.cpu().data.numpy()))
        else:
            return int(random.choice(np.arange(self.action_size)))

    def step(self, state, action, reward, next_state, done, stack_size):
        self.memory.add(state, action, reward, next_state, done)
        self.t_step = (self.t_step + 1) % UPDATE_RATE
        if self.t_step == 0 and len(self.memory) > BATCH_SIZE:
            samples = self.memory.sample()
            self.learn(samples, GAMMA, stack_size)

    def learn(self, experiences, gamma, stack_size):
        states, actions, rewards, next_states, dones = experiences

        if self.pixels:
            next_states = Variable(
                next_states
            )  #next_states.view(next_states.shape[0],stack_size,3, stack_size,32,32))
            states = Variable(states)  #states.view(states.shape[0],3,64,64))
#         else:
#todo bring back the old version stuff here

        _target = self.QN_target(next_states).detach().max(1)[0].unsqueeze(1)
        action_values_target = rewards + gamma * _target * (1 - dones)
        action_values_expected = self.QN_local(states).gather(1, actions)

        loss = F.mse_loss(action_values_expected, action_values_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # update target Qnetwork
        for target_param, local_param in zip(self.QN_target.parameters(),
                                             self.QN_local.parameters()):
            target_param.data.copy_(TAU * local_param.data +
                                    (1.0 - TAU) * target_param.data)
Example #3
class Agent:
    def __init__(self, env, state_space, action_space, device, learning_rate,
                 buffer_size, batch_size, gamma, in_channels, train_freq=4,
                 target_update_freq=1e4, is_ddqn=False):
        self.env = env
        self.state_space = state_space
        self.action_space = action_space
        
        self.QNetwork_local = QNetwork(in_channels, self.state_space, self.action_space.n, device=device).to(device)
        self.QNetwork_local.init_weights()
        self.QNetwork_target = QNetwork(in_channels, self.state_space, self.action_space.n, device=device).to(device)
        self.QNetwork_target.load_state_dict(self.QNetwork_local.state_dict())
        
        self.optimizer = torch.optim.RMSprop(self.QNetwork_local.parameters(), lr=learning_rate, alpha=0.95, eps=0.01, centered=True)
        self.criterion = torch.nn.MSELoss()
        self.memory = ReplayBuffer(capacity=int(buffer_size), batch_size=batch_size)
        self.step_count = 0
        self.batch_size = batch_size
        self.gamma = gamma
        
        self.device = device
        self.buffer_size = buffer_size
        self.num_train_updates = 0
        self.train_freq = train_freq
        self.target_update_freq = target_update_freq

        self.is_ddqn = is_ddqn
        
        print('Agent initialized with memory size:{}'.format(buffer_size))
        
    def act(self, state, epsilon):
        if random.random() > epsilon:
            #switch to evaluation mode, evaluate state, switch back to train mode
            self.QNetwork_local.eval()
            with torch.no_grad():
                actions = self.QNetwork_local(state)
            self.QNetwork_local.train()
            
            actions_np = actions.data.cpu().numpy()[0]
            best_action_idx = int(np.argmax(actions_np))
            return best_action_idx
        else:
            rand_action = self.action_space.sample()
            return rand_action
            
        
    def step(self, state, action, reward, next_state , done, add_memory=True):
        #TODO calculate priority here?
        if add_memory: 
            priority = 1.
            reward_clip = np.sign(reward)
            self.memory.add(state=state, action=action, next_state=next_state, reward=reward_clip, done=done, priority=priority)
            self.step_count = (self.step_count+1) % self.train_freq  #self.update_rate
            
            #if self.step_count == 0 and len(self.memory) >= self.batch_size:
            self.network_is_updated = False
            if self.step_count == 0 and len(self.memory) == self.buffer_size:
                samples = self.memory.random_sample(self.device)
                self.learn(samples)
                self.num_train_updates +=1
                self.network_is_updated = True
            
    def learn(self, samples):
        states, actions, rewards, next_states, dones = samples
        

        if self.is_ddqn is True:
            # DDQN: find max action using local network & gather the values of actions from target network
            next_actions = torch.argmax(self.QNetwork_local(next_states).detach(), dim=1).unsqueeze(1)
            q_target_next = self.QNetwork_target(next_states).gather(1,next_actions)
        else:
            # DQN: find the max action from target network
            q_target_next = self.QNetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
    
        # expected actions
        q_local_current = self.QNetwork_local(states).gather(1,actions)

        self.optimizer.zero_grad() #cleans up previous values

        # TD Error 
        TD_target = rewards + (self.gamma*q_target_next * (1-dones))
        TD_error = self.criterion(q_local_current, TD_target)
        TD_error.backward()
        torch.nn.utils.clip_grad_norm_(self.QNetwork_local.parameters(), 5.)
        self.optimizer.step()
        
        if (self.num_train_updates/self.train_freq) % self.target_update_freq == 0:
            self.QNetwork_target.load_state_dict(self.QNetwork_local.state_dict())
Example #4
def deep_qlearning(env, nframes, discount_factor, N, C, mini_batch_size,
                   replay_start_size, sgd_update_frequency,
                   initial_exploration, final_exploration,
                   final_exploration_frame, lr, alpha, m):
    """
    Input:
    - env: environment
    - nframes: # of frames to train on
    - discount_factor (gamma): how much to discount future rewards
    - N: replay memory size
    - C: number of steps before updating Q target network
    - mini_batch_size: mini batch size
    - replay_start_size: minimum size of replay memory before learning starts
    - sgd_update_frequency: number of action selections in between consecutive
      mini batch SGD updates
    - initial_exploration: initial epsilon value
    - final_exploration: final epsilon value
    - final_exploration_frame: number of frames over which the epsilon is
      annealed to its final value
    - lr: learning rate used by RMSprop
    - alpha: alpha value used by RMSprop
    - m: number of consecutive frames to stack for input to Q network

    Output:
    - Q: trained Q-network
    """
    n_actions = env.action_space.n
    Q = QNetwork(n_actions)
    Q_target = deepcopy(Q)
    Q_target.eval()

    transform = T.Compose([T.ToTensor()])
    optimizer = optim.RMSprop(Q.parameters(), lr=lr, alpha=alpha)
    criterion = nn.MSELoss()

    D = deque(maxlen=N)  # replay memory

    last_Q_target_update = 0
    frames_count = 0
    last_sgd_update = 0
    episodes_count = 0
    episode_rewards = []

    while True:
        frame_sequence = initialize_frame_sequence(env, m)
        state = transform(np.stack(frame_sequence, axis=2))

        episode_reward = 0
        done = False

        while not done:
            epsilon = annealed_epsilon(initial_exploration, final_exploration,
                                       final_exploration_frame, frames_count)

            action = get_epsilon_greedy_action(Q, state.unsqueeze(0), epsilon,
                                               n_actions)

            frame, reward, done, _ = env.step(action.item())
            reward = torch.tensor([reward])

            episode_reward += reward.item()
            if done:
                next_state = None
                episode_rewards.append(episode_reward)
            else:
                frame_sequence.append(preprocess_frame(frame))
                next_state = transform(np.stack(frame_sequence, axis=2))

            D.append((state, action, reward, next_state))

            state = next_state

            if len(D) < replay_start_size:
                continue

            last_sgd_update += 1
            if last_sgd_update < sgd_update_frequency:
                continue
            last_sgd_update = 0

            sgd_update(Q, Q_target, D, mini_batch_size, discount_factor,
                       optimizer, criterion)

            last_Q_target_update += 1
            frames_count += mini_batch_size

            if last_Q_target_update % C == 0:
                Q_target = deepcopy(Q)
                Q_target.eval()

            if frames_count >= nframes:
                return Q, episode_rewards

        episodes_count += 1
        if episodes_count % 100 == 0:
            save_stuff(Q, episode_rewards)
            print(f'episodes completed = {episodes_count},',
                  f'frames processed = {frames_count}')
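
The `annealed_epsilon` helper is called above but not shown; a minimal sketch consistent with the docstring (linear annealing from `initial_exploration` to `final_exploration` over `final_exploration_frame` frames) could look like this. The exact signature is inferred from the call site and is an assumption.

def annealed_epsilon(initial_exploration, final_exploration,
                     final_exploration_frame, frames_count):
    # Linearly interpolate epsilon towards its final value, then hold it constant.
    # (Sketch inferred from the call site; not the original helper.)
    fraction = min(frames_count / final_exploration_frame, 1.0)
    return initial_exploration + fraction * (final_exploration - initial_exploration)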
Example #5
class Agent:
    def __init__(self,
                 device,
                 state_size,
                 action_size,
                 buffer_size=10,
                 batch_size=10,
                 learning_rate=0.1,
                 discount_rate=0.99,
                 eps_decay=0.9,
                 tau=0.1,
                 steps_per_update=4):
        self.device = device
        self.state_size = state_size
        self.action_size = action_size

        self.q_network_control = QNetwork(state_size, action_size).to(device)
        self.q_network_target = QNetwork(state_size, action_size).to(device)
        self.optimizer = torch.optim.Adam(self.q_network_control.parameters(),
                                          lr=learning_rate)

        self.batch_size = batch_size
        self.replay_buffer = ReplayBuffer(device, state_size, action_size,
                                          buffer_size)

        self.discount_rate = discount_rate

        self.eps = 1.0
        self.eps_decay = eps_decay

        self.tau = tau

        self.step_count = 0
        self.steps_per_update = steps_per_update

    def policy(self, state):
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        return self.epsilon_greedy_policy(self.eps, state)

    def epsilon_greedy_policy(self, eps, state):
        self.q_network_control.eval()
        with torch.no_grad():
            action_values = self.q_network_control(state)
        self.q_network_control.train()

        if random.random() > eps:
            greedy_choice = np.argmax(action_values.cpu().data.numpy())
            return greedy_choice
        else:
            return random.choice(np.arange(self.action_size))

    def step(self, state, action, reward, next_state, done):
        p = self.calculate_p(state, action, reward, next_state, done)
        self.replay_buffer.add(state, action, reward, next_state, done, p)
        if self.step_count % self.steps_per_update == 0:
            self.learn()
        self.step_count += 1

    def learn(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        states, actions, rewards, next_states, dones, p = \
            self.replay_buffer.sample(self.batch_size)

        error = self.bellman_eqn_error(states, actions, rewards, next_states,
                                       dones)
        importance_scaling = (self.replay_buffer.buffer_size * p)**-1
        loss = (importance_scaling * (error**2)).sum() / self.batch_size
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_target()

    def bellman_eqn_error(self, states, actions, rewards, next_states, dones):
        """Double DQN error - use the control network to get the best action
        and apply the target network to it to get the target reward which is
        used for the bellman eqn error.
        """
        self.q_network_control.eval()
        with torch.no_grad():
            a_max = self.q_network_control(next_states).argmax(1).unsqueeze(1)

        target_action_values = self.q_network_target(next_states).gather(
            1, a_max)
        target_rewards = rewards + self.discount_rate * (1 - dones) \
                         * target_action_values

        self.q_network_control.train()
        current_rewards = self.q_network_control(states).gather(1, actions)
        error = current_rewards - target_rewards
        return error

    def calculate_p(self, state, action, reward, next_state, done):
        next_state = torch.from_numpy(next_state[np.newaxis, :]).float().to(
            self.device)
        state = torch.from_numpy(state[np.newaxis, :]).float().to(self.device)
        action = torch.from_numpy(np.array([[action]])).long().to(self.device)
        reward = torch.from_numpy(np.array([reward])).float().to(self.device)
        done = torch.from_numpy(np.array([[done]], dtype=np.uint8)).float().to(
            self.device)

        return abs(
            self.bellman_eqn_error(state, action, reward, next_state,
                                   done)) + 1e-3

    def update_target(self):
        for target_param, control_param in zip(
                self.q_network_target.parameters(),
                self.q_network_control.parameters()):
            target_param.data.copy_(self.tau * control_param.data +
                                    (1.0 - self.tau) * target_param.data)

    def end_of_episode(self):
        self.eps *= self.eps_decay
        self.step_count = 0

    def save(self, path):
        torch.save(self.q_network_control.state_dict(), path)

    def restore(self, path):
        self.q_network_control.load_state_dict(torch.load(path))
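
A short usage sketch for this agent, assuming a Gym-style environment and illustrative constructor arguments; only the `policy`, `step`, `end_of_episode`, and `save` calls come from the class above.

# Hypothetical usage; env, hyperparameters, and the checkpoint path are assumptions.
agent = Agent(device, state_size=8, action_size=4,
              buffer_size=int(1e5), batch_size=64, learning_rate=5e-4,
              discount_rate=0.99, eps_decay=0.995, tau=1e-3, steps_per_update=4)

for episode in range(300):
    state = env.reset()                      # assumed Gym-style API
    done = False
    while not done:
        action = agent.policy(state)
        next_state, reward, done, _ = env.step(action)
        agent.step(state, action, reward, next_state, done)
        state = next_state
    agent.end_of_episode()                   # decays epsilon and resets the step counter

agent.save('checkpoint.pth')                 # illustrative path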
Example #6
        ys.append(y)
    if num_batches > batch_size:
        remainder = num_batches % batch_size
        while remainder != 0:
            xs.append([0 for _ in range(seq_length)])
            ys.append([0 for _ in range(seq_length)])
            remainder -= 1
    xs = torch.Tensor(xs).unsqueeze(2).to(device)
    ys = torch.Tensor(ys).unsqueeze(2).to(device)
    print('x: {}'.format(xs[:, :, 0]))
    print('y: {}'.format(ys[:, :, 0]))
    return xs, ys

q = QNetwork(input_size, hidden_size, seq_length).to(device)
p = PNetwork(input_size, hidden_size, seq_length).to(device)
params = list(q.parameters()) + list(p.parameters())
optimizer = torch.optim.Adagrad(params, lr=lr)


def train(xs, ys):
    q.train()
    p.train()
    train_loss = 0
    for epoch in range(num_epochs):
        for b in range(0, num_batches, batch_size):
            
            z = q.sample(xs[b:b+batch_size], ys[b:b+batch_size], 0)
            z = z.detach()
            # print(z.size())
            # print("main: {}".format(z.size()))
            """ z : [batch_size x seq_length x (num_gates * hidden_size)] """
Example #7
class Agent():
    def __init__(self, params):
        action_size = params['action_size']
        state_size = params['state_size']
        buf_params = params['buf_params']
        nn_params = params['nn_params']
        nn_params['l1'][0] = state_size
        nn_params['l5'][1] = action_size

        self.__learning_mode = params['learning_mode']

        if self.__learning_mode['DuelingDDQN']:
            self.__qnetwork_local = DuelingQNetwork(nn_params).to(device)
            self.__qnetwork_target = DuelingQNetwork(nn_params).to(device)
        else:
            self.__qnetwork_local = QNetwork(nn_params).to(device)
            self.__qnetwork_target = QNetwork(nn_params).to(device)

        self.__action_size = action_size
        self.__state_size = state_size
        self.__memory = ReplayBuffer(buf_params)
        self.__t = 0

        self.eps = params['eps_initial']
        self.gamma = params['gamma']
        self.learning_rate = params['learning_rate']
        self.update_period = params['update_period']
        self.a = params['a']
        self.b = params['b']
        self.e = params['e']
        self.tau = params['tau']

        self.__optimiser = optim.Adam(self.__qnetwork_local.parameters(),
                                      self.learning_rate)

        # other parameters
        self.agent_loss = 0.0

    # Set methods
    def set_learning_rate(self, lr):
        self.learning_rate = lr
        for param_group in self.__optimiser.param_groups:
            param_group['lr'] = lr

    # Get methods
    def get_qlocal(self):
        return self.__qnetwork_local

    # Other methods
    def step(self, state, action, reward, next_state, done):
        # add experience to memory
        self.__memory.add(state, action, reward, next_state, done)

        self.__t = (self.__t + 1) % self.update_period
        if not self.__t:
            if self.__memory.is_ready():
                experiences = self.__memory.sample()
                self.__update(experiences)

    def choose_action(self, state, mode='train'):
        # state should be transformed to a tensor
        if mode == 'train':
            if random.random() > self.eps:
                state = torch.from_numpy(state).float().unsqueeze(0).to(device)
                self.__qnetwork_local.eval()
                with torch.no_grad():
                    actions = self.__qnetwork_local(state)
                self.__qnetwork_local.train()
                return np.argmax(actions.cpu().data.numpy())
            else:
                return np.random.choice(np.arange(self.__action_size))
        elif mode == 'test':
            state = torch.from_numpy(state).float().unsqueeze(0).to(device)
            self.__qnetwork_local.eval()
            with torch.no_grad():
                actions = self.__qnetwork_local(state)
            self.__qnetwork_local.train()
            return np.argmax(actions.cpu().data.numpy())
        else:
            print("Invalid mode value")

    def __update(self, experiences):
        states, actions, rewards, next_states, dones, indices, probs = experiences
        # Compute and minimise the loss
        self.__optimiser.zero_grad()

        loss_fn = nn.MSELoss(reduction='none')

        if self.__learning_mode['DQN']:
            Q_target_next = self.__qnetwork_target.forward(next_states).max(
                1)[0].unsqueeze(1).detach()
        else:
            Q_target_next = self.__qnetwork_target.forward(next_states). \
                gather(1, self.__qnetwork_local.forward(next_states).max(1)[1].unsqueeze(1)).detach()

        targets = rewards + self.gamma * Q_target_next * (1 - dones)
        outputs = self.__qnetwork_local.forward(states).gather(1, actions)
        loss = loss_fn(outputs, targets)

        # Calculate weights and normalise
        if probs:
            weights = [(prob * len(self.__memory))**(-self.b)
                       for prob in probs]
            weights = np.array([w / max(weights) for w in weights]).reshape(
                (-1, 1))
        else:
            weights = np.ones(loss.shape, dtype=np.float64)

        # Calculate weighted loss
        weighted_loss = torch.mean(torch.from_numpy(weights).float() * loss)
        weighted_loss.backward()

        self.__optimiser.step()

        if indices:
            self.__memory.update(
                indices,
                list(loss.detach().numpy().squeeze()**self.a + self.e))

        self.__soft_update(self.__qnetwork_local, self.__qnetwork_target,
                           self.tau)

        self.agent_loss = weighted_loss.detach().numpy().squeeze()

    def __soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
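
The constructor above reads its entire configuration from a single `params` dict; the sketch below shows its expected shape, inferred from the keys accessed in `__init__` and `__update`. All concrete values, and the inner structure of `buf_params` and `nn_params` beyond the indices that `__init__` overwrites, are illustrative assumptions.

# Illustrative configuration; keys mirror those read by the Agent, values are assumptions.
params = {
    'state_size': 37,
    'action_size': 4,
    'learning_mode': {'DQN': False, 'DuelingDDQN': True},  # selects the architecture and target rule
    'nn_params': {'l1': [None, 128], 'l5': [128, None]},   # l1[0] and l5[1] are overwritten in __init__
    'buf_params': {},                                      # forwarded unchanged to ReplayBuffer
    'eps_initial': 1.0,
    'gamma': 0.99,
    'learning_rate': 5e-4,
    'update_period': 4,
    'a': 0.6,        # priority exponent applied to the per-sample loss
    'b': 0.4,        # importance-sampling exponent for the weights
    'e': 1e-3,       # small constant added to updated priorities
    'tau': 1e-3,     # soft-update rate for the target network
}

agent = Agent(params)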
Example #8
from q_network import QNetwork
import h5py
import argparse
import torch

if __name__ == "__main__":

    qNetwork = QNetwork(100, 64, 5, -1)
    optimizer = torch.optim.Adam(params=qNetwork.parameters(), lr=0.0005)
    rate = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                           step_size=250,
                                           gamma=0.9999)
    print(qNetwork)

    hiddenWidth1 = 100
    hiddenWidth2 = 64
    outputWidth = 5
    weightInit = -1
    batchSize = 4
    gamma = 0.7

    dataOut = h5py.File('skillWeightsQ.h5', 'w')

    print('Loading data...')
    # data = h5py.File(FLAGS.file, 'r')
    # numSkills = data.get('numberSkills')
    numSkills = 4
    print('Number of skills is ' + str(numSkills))

    dataOut.create_dataset('hiddenWidth', data=hiddenWidth1)
    dataOut.create_dataset('numberSkills', data=numSkills)
Example #9
class Agent():
    def __init__(self, state_size, action_size, seed):
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size,
                                       seed).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size,
                                        seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > BATCH_SIZE:
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA)

    def act(self, state, eps=0.0):
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences, gamma):
        states, actions, rewards, next_states, dones = experiences

        # Get argmax(QLocal) for action.
        # Get values from QTarget(action) for local indexes
        # Should be 64x1
        q_local_idx = self.qnetwork_local(next_states).detach().argmax(
            1).unsqueeze(1)
        Q_targets_next = self.qnetwork_target(next_states).detach().gather(
            1, q_local_idx)
        # Compute Q targets for current states
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)