Example #1
 def _decode_loss(self, o_dec, o, a_dec, a, r_dec, r):
     # WARNING: the paper does not take the mean here.
     # o : Image + velocity
     o_loss = F.mse_loss(o_dec, o)
     a_loss = self._bernoulli_softmax_crossentropy(a_dec, a)
     r_loss = F.mse_loss(r_dec, r) / 2
     return ALPHA_OBS*o_loss + ALPHA_ACTION*a_loss + ALPHA_REWARD*r_loss
Example #2
def loss(anchors, data, pred, threshold):
    iou = pred['iou']
    device_id = iou.get_device() if torch.cuda.is_available() else None
    rows, cols = pred['feature'].size()[-2:]
    iou_matrix, _iou, _, _data = iou_match(pred['yx_min'].data, pred['yx_max'].data, data)
    anchors = utils.ensure_device(anchors, device_id)
    positive = fit_positive(rows, cols, *(data[key] for key in 'yx_min, yx_max'.split(', ')), anchors)
    negative = ~positive & (_iou < threshold)
    _center_offset, _size_norm = fill_norm(*(_data[key] for key in 'yx_min, yx_max'.split(', ')), anchors)
    positive, negative, _iou, _center_offset, _size_norm, _cls = (torch.autograd.Variable(t) for t in (positive, negative, _iou, _center_offset, _size_norm, _data['cls']))
    _positive = torch.unsqueeze(positive, -1)
    loss = {}
    # iou
    loss['foreground'] = F.mse_loss(iou[positive], _iou[positive], size_average=False)
    loss['background'] = torch.sum(square(iou[negative]))
    # bbox
    loss['center'] = F.mse_loss(pred['center_offset'][_positive], _center_offset[_positive], size_average=False)
    loss['size'] = F.mse_loss(pred['size_norm'][_positive], _size_norm[_positive], size_average=False)
    # cls
    if 'logits' in pred:
        logits = pred['logits']
        if len(_cls.size()) > 3:
            loss['cls'] = F.mse_loss(F.softmax(logits, -1)[_positive], _cls[_positive], size_average=False)
        else:
            loss['cls'] = F.cross_entropy(logits[_positive].view(-1, logits.size(-1)), _cls[positive].view(-1))
    # normalize
    cnt = float(np.multiply.reduce(positive.size()))
    for key in loss:
        loss[key] /= cnt
    return loss, dict(iou=_iou, data=_data, positive=positive, negative=negative)
Example #3
def my_loss_function(reconstructed_x, z_patches, prototypes, padding_idx, x, lambda_=0.01):
    """
        reconstructed_x : batch_size, channels, height, width
        z_patches       : batch_size, num_patches, embedding_dim
        prototypes      : batch_size, num_prototypes, embedding_dim
        padding_idx     : batch_size
        x               : batch_size, channels, height, width
    """
    assert not x.requires_grad

    batch_size = x.size(0)

    loss = F.mse_loss(reconstructed_x, x, size_average=False)
    for i in range(batch_size):
        image_patches = z_patches[i]
        image_prototypes = prototypes[i][:padding_idx[i]]

        dists = pairwise_squared_euclidean_distances(
            image_prototypes, image_patches)
        min_dists = torch.min(dists, dim=1)[0]

        prototype_loss = torch.sum(min_dists)
        loss = loss + (lambda_ * prototype_loss)

    loss = loss / batch_size
    return loss
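The helper pairwise_squared_euclidean_distances is not shown above; a minimal sketch of what it presumably computes (squared Euclidean distances between every prototype/patch pair), assuming 2-D inputs, could be:

import torch

def pairwise_squared_euclidean_distances(a, b):
    # a: (num_prototypes, embedding_dim), b: (num_patches, embedding_dim)
    # returns a (num_prototypes, num_patches) matrix of squared distances
    return torch.cdist(a, b, p=2) ** 2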
Example #4
    def update_parameters(self, batch):
        state_batch = Variable(torch.cat(batch.state))
        action_batch = Variable(torch.cat(batch.action))
        reward_batch = Variable(torch.cat(batch.reward))
        mask_batch = Variable(torch.cat(batch.mask))
        next_state_batch = Variable(torch.cat(batch.next_state))
        
        next_action_batch = self.actor_target(next_state_batch)
        next_state_action_values = self.critic_target(next_state_batch, next_action_batch)

        reward_batch = reward_batch.unsqueeze(1)
        mask_batch = mask_batch.unsqueeze(1)
        expected_state_action_batch = reward_batch + (self.gamma * mask_batch * next_state_action_values)

        self.critic_optim.zero_grad()

        state_action_batch = self.critic((state_batch), (action_batch))

        value_loss = F.mse_loss(state_action_batch, expected_state_action_batch)
        value_loss.backward()
        self.critic_optim.step()

        self.actor_optim.zero_grad()

        policy_loss = -self.critic((state_batch),self.actor((state_batch)))

        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.actor_optim.step()

        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)

        return value_loss.item(), policy_loss.item()
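The soft_update calls above rely on a Polyak-averaging helper that is not shown; a plausible sketch matching the call order soft_update(target, source, tau) would be:

def soft_update(target, source, tau):
    # Move each target parameter a fraction tau towards the corresponding source parameter.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)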
Example #5
def get_loss(latent_obs, action, reward, terminal, latent_next_obs):
    """ Compute losses.

    The loss that is computed is:
    (GMMLoss(latent_next_obs, GMMPredicted) + MSE(reward, predicted_reward) +
         BCE(terminal, logit_terminal)) / (LSIZE + 2)
    The LSIZE + 2 factor is here to counteract the fact that the GMMLoss scales
    approximately linearly with LSIZE. All losses are averaged both on the
    batch and the sequence dimensions (the first two dimensions).

    :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
    :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
    :args reward: (BSIZE, SEQ_LEN) torch tensor
    :args terminal: (BSIZE, SEQ_LEN) torch tensor
    :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

    :returns: dictionary of losses, containing the gmm, the mse, the bce and
        the averaged loss.
    """
    latent_obs, action,\
        reward, terminal,\
        latent_next_obs = [arr.transpose(1, 0)
                           for arr in [latent_obs, action,
                                       reward, terminal,
                                       latent_next_obs]]
    mus, sigmas, logpi, rs, ds = mdrnn(action, latent_obs)
    gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
    bce = f.binary_cross_entropy_with_logits(ds, terminal)
    mse = f.mse_loss(rs, reward)
    loss = (gmm + bce + mse) / (LSIZE + 2)
    return dict(gmm=gmm, bce=bce, mse=mse, loss=loss)
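gmm_loss is an external function here; a minimal sketch of a diagonal-Gaussian-mixture negative log-likelihood with the same signature (the implementation actually used by this example may differ in detail):

import torch

def gmm_loss(batch, mus, sigmas, logpi, reduce=True):
    # batch: (..., LSIZE); mus, sigmas: (..., n_gauss, LSIZE); logpi: (..., n_gauss) log mixture weights
    batch = batch.unsqueeze(-2)                            # broadcast against the mixture dimension
    normal = torch.distributions.Normal(mus, sigmas)
    log_probs = normal.log_prob(batch).sum(dim=-1)         # per-component log-likelihood
    log_prob = torch.logsumexp(logpi + log_probs, dim=-1)  # log-sum-exp over components
    return -log_prob.mean() if reduce else -log_prob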
Example #6
def train_a2c(net, mb_obs, mb_rewards, mb_actions, mb_values, optimizer, tb_tracker, step_idx, device="cpu"):
    optimizer.zero_grad()
    mb_adv = mb_rewards - mb_values
    adv_v = torch.FloatTensor(mb_adv).to(device)
    obs_v = torch.FloatTensor(mb_obs).to(device)
    rewards_v = torch.FloatTensor(mb_rewards).to(device)
    actions_t = torch.LongTensor(mb_actions).to(device)
    logits_v, values_v = net(obs_v)
    log_prob_v = F.log_softmax(logits_v, dim=1)
    log_prob_actions_v = adv_v * log_prob_v[range(len(mb_actions)), actions_t]

    loss_policy_v = -log_prob_actions_v.mean()
    loss_value_v = F.mse_loss(values_v.squeeze(-1), rewards_v)

    prob_v = F.softmax(logits_v, dim=1)
    entropy_loss_v = (prob_v * log_prob_v).sum(dim=1).mean()
    loss_v = ENTROPY_BETA * entropy_loss_v + VALUE_LOSS_COEF * loss_value_v + loss_policy_v
    loss_v.backward()
    nn_utils.clip_grad_norm_(net.parameters(), CLIP_GRAD)
    optimizer.step()

    tb_tracker.track("advantage", mb_adv, step_idx)
    tb_tracker.track("values", values_v, step_idx)
    tb_tracker.track("batch_rewards", rewards_v, step_idx)
    tb_tracker.track("loss_entropy", entropy_loss_v, step_idx)
    tb_tracker.track("loss_policy", loss_policy_v, step_idx)
    tb_tracker.track("loss_value", loss_value_v, step_idx)
    tb_tracker.track("loss_total", loss_v, step_idx)
    return obs_v
Example #7
def grad_fun(net, queue):
    iter_idx = 0
    while True:
        sum_loss = 0.0
        iter_idx += 1
        for v in TRAIN_DATA:
            x_v = Variable(torch.from_numpy(np.array([v], dtype=np.float32)))
            y_v = Variable(torch.from_numpy(np.array([get_y(v)], dtype=np.float32)))
            if CUDA:
                x_v = x_v.cuda()
                y_v = y_v.cuda()

            net.zero_grad()
            out_v = net(x_v)
            loss_v = F.mse_loss(out_v, y_v)
            loss_v.backward()

            grads = [param.grad.clone() if param.grad is not None else None
                     for param in net.parameters()]

            queue.put(grads)
            sum_loss += loss_v.data.cpu().numpy()
        print("%d: %.2f" % (iter_idx, sum_loss))
        if sum_loss < 0.1:
            queue.put(None)
            break
Example #8
def train(epoch):
    global lr
    model.train()
    batch_idx = 1
    total_loss = 0
    for i in range(0, X_train.size()[0], batch_size):
        if i + batch_size > X_train.size()[0]:
            x, y = X_train[i:], Y_train[i:]
        else:
            x, y = X_train[i:(i+batch_size)], Y_train[i:(i+batch_size)]
        optimizer.zero_grad()
        output = model(x)
        loss = F.mse_loss(output, y)
        loss.backward()
        if args.clip > 0:
            torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        optimizer.step()
        batch_idx += 1
        total_loss += loss.data[0]

        if batch_idx % args.log_interval == 0:
            cur_loss = total_loss / args.log_interval
            processed = min(i+batch_size, X_train.size()[0])
            print('Train Epoch: {:2d} [{:6d}/{:6d} ({:.0f}%)]\tLearning rate: {:.4f}\tLoss: {:.6f}'.format(
                epoch, processed, X_train.size()[0], 100.*processed/X_train.size()[0], lr, cur_loss))
            total_loss = 0
Example #9
    def train_step(self, state_batch, mcts_probs, winner_batch, lr):
        """perform a training step"""
        # wrap in Variable
        if self.use_gpu:
            state_batch = Variable(torch.FloatTensor(state_batch).cuda())
            mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda())
            winner_batch = Variable(torch.FloatTensor(winner_batch).cuda())
        else:
            state_batch = Variable(torch.FloatTensor(state_batch))
            mcts_probs = Variable(torch.FloatTensor(mcts_probs))
            winner_batch = Variable(torch.FloatTensor(winner_batch))

        # zero the parameter gradients
        self.optimizer.zero_grad()
        # set learning rate
        set_learning_rate(self.optimizer, lr)

        # forward
        log_act_probs, value = self.policy_value_net(state_batch)
        # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
        # Note: the L2 penalty is incorporated in optimizer
        value_loss = F.mse_loss(value.view(-1), winner_batch)
        policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1))
        loss = value_loss + policy_loss
        # backward and optimize
        loss.backward()
        self.optimizer.step()
        # calc policy entropy, for monitoring only
        entropy = -torch.mean(
                torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
                )
        return loss.data[0], entropy.data[0]
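set_learning_rate is an external helper here; a hedged sketch using the usual PyTorch idiom:

def set_learning_rate(optimizer, lr):
    # Update the learning rate of every parameter group in the optimizer.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr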
Example #10
    def learn(self, experiences, gamma):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
        # Compute Q targets for current states
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
Example #11
 def forward(self, input, target):
     output = F.mse_loss(input, target, reduction=self.reduction)
     l2_loss = sum(param.norm(2)**2 for param in self.model.parameters())
     output += self.l2 / 2 * l2_loss
     l1_loss = sum(param.norm(1) for param in self.model.parameters())
     output += self.l1 * l1_loss
     return output
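This forward references self.model, self.l1, self.l2 and self.reduction, which suggests a small wrapper module; a hedged sketch of the constructor it might belong to (class name and defaults are assumptions):

import torch.nn as nn

class RegularizedMSELoss(nn.Module):
    def __init__(self, model, l1=0.0, l2=0.0, reduction='mean'):
        super().__init__()
        self.model = model          # model whose parameters are penalized
        self.l1 = l1                # L1 penalty coefficient
        self.l2 = l2                # L2 penalty coefficient
        self.reduction = reduction  # passed through to F.mse_loss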
Example #12
    def loss_function(self, recon_x, x, mu, logvar):
        # BCE = F.binary_cross_entropy(recon_x, x.view(-1, self.input_size), size_average=False)
        BCE = F.mse_loss(recon_x, x.view(-1, self.input_size), size_average=False)

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD
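A hedged usage sketch: this loss is typically fed a reconstruction produced from a reparameterized latent sample (the encoder, decoder and model names below are hypothetical):

mu, logvar = encoder(x)                  # hypothetical encoder returning mean and log-variance
std = torch.exp(0.5 * logvar)
z = mu + std * torch.randn_like(std)     # reparameterization trick
recon_x = decoder(z)                     # hypothetical decoder
loss = model.loss_function(recon_x, x, mu, logvar)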
Example #13
def softmax_mse_loss(input_logits, target_logits):
    """Takes softmax on both sides and returns MSE loss

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to inputs but not the targets.
    """
    assert input_logits.size() == target_logits.size()
    input_softmax = F.softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    num_classes = input_logits.size()[1]
    return F.mse_loss(input_softmax, target_softmax, size_average=False) / num_classes
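Following the note in the docstring, the caller usually divides by the batch size to obtain a per-example mean; a short usage sketch with hypothetical student_logits / teacher_logits tensors:

consistency_loss = softmax_mse_loss(student_logits, teacher_logits) / student_logits.size(0)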
Example #14
    def forward(self, task=None, input1=None, input2=None, label=None):
        '''
        Predict through model and task-specific prediction layer

        Args:
            - inputs (tuple(TODO))
            - pred_layer (nn.Module)
            - pair_input (int)

        Returns:
            - logits (TODO)
        '''
        pair_input = task.pair_input
        pred_layer = getattr(self, '%s_pred_layer' % task.name)
        if pair_input:
            if self.pair_enc_type == 'bow':
                sent1 = self.sent_encoder(input1)
                sent2 = self.sent_encoder(input2) # causes a bug with BiDAF
                logits = pred_layer(torch.cat([sent1, sent2, torch.abs(sent1 - sent2),
                                               sent1 * sent2], 1))
            else:
                pair_emb = self.pair_encoder(input1, input2)
                logits = pred_layer(pair_emb)

        else:
            sent_emb = self.sent_encoder(input1)
            logits = pred_layer(sent_emb)
        out = {'logits': logits}
        if label is not None:
            if isinstance(task, (STS14Task, STSBTask)):
                loss = F.mse_loss(logits, label)
                label = label.squeeze(-1).data.cpu().numpy()
                logits = logits.squeeze(-1).data.cpu().numpy()
                task.scorer1(pearsonr(logits, label)[0])
                task.scorer2(spearmanr(logits, label)[0])
            elif isinstance(task, CoLATask):
                label = label.squeeze(-1)
                loss = F.cross_entropy(logits, label)
                task.scorer2(logits, label)
                label = label.data.cpu().numpy()
                _, preds = logits.max(dim=1)
                task.scorer1(matthews_corrcoef(label, preds.data.cpu().numpy()))
            else:
                label = label.squeeze(-1)
                loss = F.cross_entropy(logits, label)
                task.scorer1(logits, label)
                if task.scorer2 is not None:
                    task.scorer2(logits, label)
            out['loss'] = loss
        return out
Example #15
 def forward(self, observations, actions, advantages, value_targets):
     logits, _, values, _ = self.policy_model({"obs": observations}, [])
     log_probs = F.log_softmax(logits, dim=1)
     probs = F.softmax(logits, dim=1)
     action_log_probs = log_probs.gather(1, actions.view(-1, 1))
     entropy = -(log_probs * probs).sum(-1).sum()
     pi_err = -advantages.dot(action_log_probs.reshape(-1))
     value_err = F.mse_loss(values.reshape(-1), value_targets)
     overall_err = sum([
         pi_err,
         self.vf_loss_coeff * value_err,
         self.entropy_coeff * entropy,
     ])
     return overall_err
Example #16
 def closure():
   global step, final_loss
   optimizer.zero_grad()
   output = model(data)
   loss = F.mse_loss(output, data)
   if verbose:
     loss0 = loss.data[0]
     times.append(u.last_time())
     print("Step %3d loss %6.5f msec %6.3f"%(step, loss0, u.last_time()))
   step+=1
   if step == iters:
     final_loss = loss.data[0]
   loss.backward()
   u.record_time()
   return loss
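A closure like this is what optimizers that re-evaluate the objective several times per step expect; a hedged usage sketch with L-BFGS (the choice of optimizer is an assumption):

optimizer = torch.optim.LBFGS(model.parameters(), max_iter=iters)
optimizer.step(closure)  # L-BFGS calls closure() repeatedly, up to max_iter times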
Example #17
 def learn(self, optimizer, history_of_rewards, gamma):
     total_weighted_reward=0
     gradient=Variable(torch.zeros(1,1))
     loss=0
     history_of_total_weighted_reward=[]
     for i in reversed(range(len(history_of_rewards))):
         total_weighted_reward=gamma*total_weighted_reward+history_of_rewards[i]
         history_of_total_weighted_reward.insert(0,total_weighted_reward)
     history_of_total_weighted_reward=torch.tensor(history_of_total_weighted_reward)
     #rescale the reward value(do not want to compute raw Q value)
     reward_u=history_of_total_weighted_reward.mean()
     reward_std=history_of_total_weighted_reward.std()+1e-8
     history_of_total_weighted_reward=(history_of_total_weighted_reward-reward_u)/reward_std
     for i in range(len(self.history_of_values)):
         loss+=F.mse_loss(self.history_of_values[i], history_of_total_weighted_reward[i])
     optimizer.zero_grad()
     loss.backward()
     optimizer.step()
     self.history_of_values=[]
Example #18
    def learn(self, experiences, gamma):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)                     
Example #19
 def forward(self, input):
     G = gram_matrix(input)
     self.loss = F.mse_loss(G, self.target)
     return input
Example #20
def mse(output, target, squeeze=False):
    if squeeze:
        target = target.squeeze(1)
    return F.mse_loss(output, target)
Example #21
 def step_model(model, input, target):
     model.train()
     output = model(input)
     loss = F.mse_loss(output, target)
     loss.backward()
Example #22
def ddpg(env_fn,
         actor_critic=core.ActorCritic,
         ac_kwargs=dict(),
         seed=0,
         steps_per_epoch=5000,
         epochs=100,
         replay_size=int(1e6),
         gamma=0.99,
         polyak=0.995,
         pi_lr=1e-3,
         q_lr=1e-3,
         batch_size=100,
         start_steps=10000,
         act_noise=0.1,
         max_ep_len=1000,
         logger_kwargs=dict(),
         save_freq=1):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.
            
        actor_critic: The agent's main model which takes some states ``x``
            and actions ``a`` and returns a tuple of:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Deterministically computes actions
                                           | from policy given states.
            ``q``        (batch,)          | Gives the current estimate of Q* for
                                           | states ``x`` and actions in
                                           | ``a``.
            ``q_pi``     (batch,)          | Gives the composition of ``q`` and
                                           | ``pi`` for states in ``x``:
                                           | q(x, pi(x)).
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic
            class you provided to DDPG.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.)

        pi_lr (float): Learning rate for policy.

        q_lr (float): Learning rate for Q-networks.

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        act_noise (float): Stddev for Gaussian exploration noise added to
            policy at training time. (At test time, no noise is added.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    torch.manual_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    act_limit = env.action_space.high[0]

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space

    # Main outputs from computation graph
    main = actor_critic(in_features=obs_dim, **ac_kwargs)

    # Target networks
    target = actor_critic(in_features=obs_dim, **ac_kwargs)
    target.eval()

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                 act_dim=act_dim,
                                 size=replay_size)

    # Count variables
    var_counts = tuple(
        core.count_vars(module) for module in [main.policy, main.q, main])
    print('\nNumber of parameters: \t pi: %d, \t q: %d, \t total: %d\n' %
          var_counts)

    # Separate train ops for pi, q
    pi_optimizer = torch.optim.Adam(main.policy.parameters(), lr=pi_lr)
    q_optimizer = torch.optim.Adam(main.q.parameters(), lr=q_lr)

    # Initializing targets to match main variables
    target.load_state_dict(main.state_dict())

    def get_action(o, noise_scale):
        pi = main.policy(torch.Tensor(o.reshape(1, -1)))
        a = pi.data.numpy()[0] + noise_scale * np.random.randn(act_dim)
        return np.clip(a, -act_limit, act_limit)

    def test_agent(n=10):
        for _ in range(n):
            o, r, d, ep_ret, ep_len = test_env.reset(), 0, False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                o, r, d, _ = test_env.step(get_action(o, 0))
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_steps = steps_per_epoch * epochs

    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):
        main.eval()
        """
        Until start_steps have elapsed, randomly sample actions
        from a uniform distribution for better exploration. Afterwards,
        use the learned policy (with some noise, via act_noise).
        """
        if t > start_steps:
            a = get_action(o, act_noise)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, _ = env.step(a)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        if d or (ep_len == max_ep_len):
            main.train()
            """
            Perform all DDPG updates at the end of the trajectory,
            in accordance with tuning done by TD3 paper authors.
            """
            for _ in range(ep_len):
                batch = replay_buffer.sample_batch(batch_size)
                (obs1, obs2, acts, rews, done) = (torch.Tensor(batch['obs1']),
                                                  torch.Tensor(batch['obs2']),
                                                  torch.Tensor(batch['acts']),
                                                  torch.Tensor(batch['rews']),
                                                  torch.Tensor(batch['done']))
                _, q, q_pi = main(obs1, acts)
                _, _, q_pi_targ = target(obs2, acts)

                # Bellman backup for Q function
                backup = (rews + gamma * (1 - done) * q_pi_targ).detach()

                # DDPG losses
                pi_loss = -q_pi.mean()
                q_loss = F.mse_loss(q, backup)

                # Q-learning update
                q_optimizer.zero_grad()
                q_loss.backward()
                q_optimizer.step()
                logger.store(LossQ=q_loss, QVals=q.data.numpy())

                # Policy update
                pi_optimizer.zero_grad()
                pi_loss.backward()
                pi_optimizer.step()
                logger.store(LossPi=pi_loss)

                # Polyak averaging for target parameters
                for p_main, p_target in zip(main.parameters(),
                                            target.parameters()):
                    p_target.data.copy_(polyak * p_target.data +
                                        (1 - polyak) * p_main.data)

            logger.store(EpRet=ep_ret, EpLen=ep_len)
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # End of epoch wrap-up
        if t > 0 and t % steps_per_epoch == 0:
            epoch = t // steps_per_epoch

            # Save model
            if (epoch % save_freq == 0) or (epoch == epochs - 1):
                logger.save_state({'env': env}, main, None)

            # Test the performance of the deterministic version of the agent.
            test_agent()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('QVals', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('Time', time.time() - start_time)
            logger.dump_tabular()
Example #23
def test_loop(cfg, model, criterion, test_loader, epoch):
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    psnrs = np.zeros((len(test_loader), cfg.batch_size))
    use_record = False
    record_test = pd.DataFrame({'psnr': []})

    with torch.no_grad():
        for batch_idx, (burst, res_imgs, raw_img,
                        directions) in enumerate(test_loader):
            B = raw_img.shape[0]

            # -- selecting input frames --
            input_order = np.arange(cfg.N)
            # print("pre",input_order)
            middle_img_idx = -1
            if not cfg.input_with_middle_frame:
                middle = cfg.N // 2
                # print(middle)
                middle_img_idx = input_order[middle]
                # input_order = np.r_[input_order[:middle],input_order[middle+1:]]
            else:
                middle = len(input_order) // 2
                input_order = np.arange(cfg.N)
                middle_img_idx = input_order[middle]
                # input_order = np.arange(cfg.N)

            # -- reshaping of data --
            raw_img = raw_img.cuda(non_blocking=True)
            burst = burst.cuda(non_blocking=True)
            stacked_burst = torch.stack(
                [burst[input_order[x]] for x in range(cfg.input_N)], dim=1)
            cat_burst = torch.cat(
                [burst[input_order[x]] for x in range(cfg.input_N)], dim=1)

            # -- denoising --
            aligned, aligned_ave, denoised, denoised_ave, a_filters, d_filters = model(
                burst)
            denoised_ave = denoised_ave.detach()

            # if not cfg.input_with_middle_frame:
            #     denoised_ave = model(cat_burst,stacked_burst)[1]
            # else:
            #     denoised_ave = model(cat_burst,stacked_burst)[0][middle_img_idx]

            # denoised_ave = burst[middle_img_idx] - rec_res

            # -- compare with stacked targets --
            denoised_ave = rescale_noisy_image(denoised_ave)

            # -- compute psnr --
            loss = F.mse_loss(raw_img, denoised_ave,
                              reduction='none').reshape(B, -1)
            # loss = F.mse_loss(raw_img,burst[cfg.input_N//2]+0.5,reduction='none').reshape(B,-1)
            loss = torch.mean(loss, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)
            psnrs[batch_idx, :] = psnr

            if use_record:
                record_test = record_test.append({'psnr': psnr},
                                                 ignore_index=True)
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            # if (batch_idx % cfg.test_log_interval) == 0:
            #     root = Path(f"{settings.ROOT_PATH}/output/n2n/offset_out_noise/denoised_aves/N{cfg.N}/e{epoch}")
            #     if not root.exists(): root.mkdir(parents=True)
            #     fn = root / Path(f"b{batch_idx}.png")
            #     nrow = int(np.sqrt(cfg.batch_size))
            #     denoised_ave = denoised_ave.detach().cpu()
            #     grid_imgs = tv_utils.make_grid(denoised_ave, padding=2, normalize=True, nrow=nrow)
            #     plt.imshow(grid_imgs.permute(1,2,0))
            #     plt.savefig(fn)
            #     plt.close('all')
            if batch_idx % 100 == 0:
                print("[%d/%d] Test PSNR: %2.2f" %
                      (batch_idx, len(test_loader), total_psnr /
                       (batch_idx + 1)))

    psnr_ave = np.mean(psnrs)
    psnr_std = np.std(psnrs)
    ave_loss = total_loss / len(test_loader)
    print("[N: %d] Testing: [psnr: %2.2f +/- %2.2f] [ave loss %2.3e]" %
          (cfg.N, psnr_ave, psnr_std, ave_loss))
    return psnr_ave, record_test
Example #24
            sum_loss = 0.0
            sum_value_loss = 0.0
            sum_policy_loss = 0.0

            for _ in range(TRAIN_ROUNDS):
                batch = random.sample(replay_buffer, BATCH_SIZE)
                batch_states, batch_who_moves, batch_probs, batch_values = zip(*batch)
                batch_states_lists = [game.decode_binary(state) for state in batch_states]
                states_v = model.state_lists_to_batch(batch_states_lists, batch_who_moves, device)

                optimizer.zero_grad()
                probs_v = torch.FloatTensor(batch_probs).to(device)
                values_v = torch.FloatTensor(batch_values).to(device)
                out_logits_v, out_values_v = net(states_v)

                loss_value_v = F.mse_loss(out_values_v.squeeze(-1), values_v)
                loss_policy_v = -F.log_softmax(out_logits_v, dim=1) * probs_v
                loss_policy_v = loss_policy_v.sum(dim=1).mean()

                loss_v = loss_policy_v + loss_value_v
                loss_v.backward()
                optimizer.step()
                sum_loss += loss_v.item()
                sum_value_loss += loss_value_v.item()
                sum_policy_loss += loss_policy_v.item()

            tb_tracker.track("loss_total", sum_loss / TRAIN_ROUNDS, step_idx)
            tb_tracker.track("loss_value", sum_value_loss / TRAIN_ROUNDS, step_idx)
            tb_tracker.track("loss_policy", sum_policy_loss / TRAIN_ROUNDS, step_idx)

            # evaluate net
Example #25
                # handle new rewards
                new_rewards = exp_source.pop_total_rewards()
                if new_rewards:
                    if tracker.reward(new_rewards[0], step_idx):
                        break

                if len(batch) < BATCH_SIZE:
                    continue

                states_v, actions_t, vals_ref_v = unpack_batch(batch, net, device=device)
                batch.clear()

                optimizer.zero_grad()
                logits_v, value_v = net(states_v)
                loss_value_v = F.mse_loss(value_v.squeeze(-1), vals_ref_v)

                log_prob_v = F.log_softmax(logits_v, dim=1)
                adv_v = vals_ref_v - value_v.detach()
                log_prob_actions_v = adv_v * log_prob_v[range(BATCH_SIZE), actions_t]
                loss_policy_v = -log_prob_actions_v.mean()

                prob_v = F.softmax(logits_v, dim=1)
                entropy_loss_v = ENTROPY_BETA * (prob_v * log_prob_v).sum(dim=1).mean()

                # calculate policy gradients only
                loss_policy_v.backward(retain_graph=True)
                grads = np.concatenate([p.grad.data.cpu().numpy().flatten()
                                        for p in net.parameters()
                                        if p.grad is not None])
Example #26
        num_layers=1,
        dropout=0.0
    )
    
    # Training loop:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    fx_encoder.to(device)
    fx_decoder.to(device)
    params = list(fx_encoder.parameters()) + list(fx_decoder.parameters())
    optimizer = optim.Adam(params, lr=1e-3)
    for epoch in range(3):
        for i, (x, y) in enumerate(trainloader):
            x, y = x.to(device), y.to(device)
            _, h = fx_encoder(x)
            o, _ = fx_decoder(h, 32, y, True)
            loss = F.mse_loss(o, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print("[%3d - %3d] train_loss: %.4f" % (epoch, i, loss.item()))
    
    # Save model:
    if os.path.isdir("./checkpoint"):
        shutil.rmtree("./checkpoint/")
    os.mkdir("./checkpoint/")
    joblib.dump(fx_dm.scaler, "./checkpoint/scaler.save")
    torch.save(fx_encoder.state_dict(), "./checkpoint/fx_encoder.pth")
    torch.save(fx_decoder.state_dict(), "./checkpoint/fx_decoder.pth")

    
Example #27
 def training_step(self, batch, batch_nb):
     x, y = batch
     y_hat = self(x)
     loss = F.mse_loss(y_hat, y)
     return {'loss': loss}
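A hedged sketch of the minimal PyTorch Lightning module such a step could live in (the layer sizes and optimizer below are assumptions, not taken from the original project):

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class Regressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 1)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        return {'loss': F.mse_loss(y_hat, y)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)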
Example #28
    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, List[float]]:
        actor_losses, vf_losses, kls = [], [], []
        for step in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                # optimize actor
                # direction: calculate vanilla gradient
                dist = self(b).dist
                log_prob = dist.log_prob(b.act)
                log_prob = log_prob.reshape(log_prob.size(0),
                                            -1).transpose(0, 1)
                actor_loss = -(log_prob * b.adv).mean()
                flat_grads = self._get_flat_grad(actor_loss,
                                                 self.actor,
                                                 retain_graph=True).detach()

                # direction: calculate natural gradient
                with torch.no_grad():
                    old_dist = self(b).dist

                kl = kl_divergence(old_dist, dist).mean()
                # calculate first order gradient of kl with respect to theta
                flat_kl_grad = self._get_flat_grad(kl,
                                                   self.actor,
                                                   create_graph=True)
                search_direction = -self._conjugate_gradients(
                    flat_grads, flat_kl_grad, nsteps=10)

                # step
                with torch.no_grad():
                    flat_params = torch.cat([
                        param.data.view(-1)
                        for param in self.actor.parameters()
                    ])
                    new_flat_params = flat_params + self._step_size * search_direction
                    self._set_from_flat_params(self.actor, new_flat_params)
                    new_dist = self(b).dist
                    kl = kl_divergence(old_dist, new_dist).mean()

                # optimize critic
                for _ in range(self._optim_critic_iters):
                    value = self.critic(b.obs).flatten()
                    vf_loss = F.mse_loss(b.returns, value)
                    self.optim.zero_grad()
                    vf_loss.backward()
                    self.optim.step()

                actor_losses.append(actor_loss.item())
                vf_losses.append(vf_loss.item())
                kls.append(kl.item())

        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "kl": kls,
        }
Example #29
def rmse_loss(pred, targ):
    denom = targ**2
    denom = torch.sqrt(denom.sum() / len(denom))
    return torch.sqrt(F.mse_loss(pred, targ)) / denom
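A quick sanity check of the normalized RMSE above (the values are chosen purely for illustration):

import torch
pred = torch.tensor([1.0, 2.0, 3.0])
targ = torch.tensor([1.0, 2.0, 5.0])
print(rmse_loss(pred, targ))  # sqrt(4/3) / sqrt(30/3) ≈ 0.3651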
Example #30
# pytorch_learning_12_perceptron

import torch
from torch.nn import functional as F

# multiple inputs, single output
x = torch.randn(1, 10)
w = torch.randn(1, 10, requires_grad=True)

o = torch.sigmoid(x @ w.t())
print(o)  # tensor([[0.4561]], grad_fn=<SigmoidBackward>)
print(o.shape)  # torch.Size([1, 1])

loss = F.mse_loss(torch.ones(1, 1), o)
print(loss.shape)  # torch.Size([])

loss.backward()
print(w.grad)
# tensor([[ 0.3117, -0.2556,  0.0821,  0.3212, -0.1595,  0.1250,  0.0498,  0.0400,
#          -0.1263,  0.0531]])
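# Hedged check (added for illustration): the autograd result matches the
# hand-derived gradient of this single-output case, d loss/dw = 2*(o-1)*o*(1-o)*x.
with torch.no_grad():
    manual_grad = 2 * (o - 1) * o * (1 - o) * x
print(torch.allclose(w.grad, manual_grad))  # expected: True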
print("----------------------------")

# multiple inputs, multiple outputs
x = torch.randn(1, 10)
w = torch.randn(2, 10, requires_grad=True)
o = torch.sigmoid(x @ w.t())

print(o.shape)
# torch.Size([1, 2])
loss = F.mse_loss(torch.ones(1, 2), o)  # a (1, 1) target also works; it is broadcast automatically
print(loss)
Example #31
def train(trainset):

    # Split dataset
    train_size = int(args.train_percentage * len(trainset))
    test_size = len(trainset) - train_size
    train_dataset, test_dataset \
        = torch.utils.data.random_split(trainset, [train_size, test_size])

    # Dataset information
    print('train dataset : {} elements'.format(len(train_dataset)))

    # Create dataset loader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    # Show sample of images
    if args.plot:
        # get some random training images
        dataiter = iter(train_loader)
        images, _ = dataiter.next()

        grid = torchvision.utils.make_grid(images)
        imshow(grid)
        args.writer.add_image('sample-train', grid)

    # Define optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(args.net.parameters(), lr=1e-3)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(args.net.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(args.net.parameters(), lr=0.01)

    # Loss function
    criterion = elbo_loss_function

    # Set best for minimization
    best = float('inf')

    print('Started Training')
    # loop over the dataset multiple times
    for epoch in range(args.epochs):
        # reset running loss statistics
        train_loss = mse_loss = running_loss = 0.0

        for batch_idx, data in enumerate(train_loader, 1):
            # get the inputs; data is a list of [inputs, labels]
            inputs, _ = data
            inputs = inputs.to(args.device)

            with autograd.detect_anomaly():
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs, mu, logvar = args.net(inputs)
                loss = criterion(outputs, inputs, mu, logvar)
                loss.backward()
                optimizer.step()

            # update running loss statistics
            train_loss += loss.item()
            running_loss += loss.item()
            mse_loss += F.mse_loss(outputs, inputs)

            # Global step
            global_step = batch_idx + len(train_loader) * epoch

            # Write tensorboard statistics
            args.writer.add_scalar('Train/loss', loss.item(), global_step)
            args.writer.add_scalar('Train/mse', F.mse_loss(outputs, inputs),
                                   global_step)

            # check if current batch had best fitness
            if loss.item() < best:
                best = loss.item()
                update_best(inputs, outputs, loss, global_step)

            # print every args.log_interval of batches
            if batch_idx % args.log_interval == 0:
                print("Train Epoch : {} Batches : {} "
                      "[{}/{} ({:.0f}%)]\tLoss : {:.6f}"
                      "\tError : {:.6f}"
                      .format(epoch, batch_idx,
                              args.batch_size * batch_idx,
                              len(train_loader.dataset),
                              100. * batch_idx / len(train_loader),
                              running_loss / args.log_interval,
                              mse_loss / args.log_interval))

                mse_loss = running_loss = 0.0

                # Add images to tensorboard
                write_images_to_tensorboard(inputs, outputs,
                                            global_step, step=True)

        print('====> Epoch: {} Average loss: {:.4f}'.format(
              epoch, train_loss / len(train_loader)))

    # Add trained model
    args.writer.close()
    print('Finished Training')
Example #32
File: loss.py  Project: Tahlor/QR
def MSELoss(y_input, y_target):
    return F.mse_loss(y_input, y_target.float())
Example #33
 def forward(self, input):
     G = gram_matrix(input)
     self.loss = F.mse_loss(G, self.target)
     return input
Example #34
    def train(self):
        for epoch in range(self.args.epochs):

            print("[*] Epoch {} starts".format(epoch))
            # self.actor.eval()
            # self.critic.eval()

            # # initial random exploration
            # if(step < self.args.start_steps):
            #     action = self.env.action_space.sample()
            # else:
            #     action = self.generate_action_with_noise(o, self.args.noise_scale)

            episode_reward = 0
            episode = 0
            done = False

            obs_reset = self.env.reset()
            o = obs_reset['observation']

            # action = self.env.action_space.sample()
            # print("Initial action : ",action)
            # # take one step
            # o_next, r, d, _ = self.env.step(action)

            # print(color.BOLD + color.BLUE + "done : " + color.END, d)
            # print(color.BOLD + color.BLUE + "ach goal : " + color.END, o_next['achieved_goal'])
            # print(color.BOLD + color.BLUE + "des goal : " + color.END, o_next['desired_goal'])

            #o_next = o_next['observation']
            #print(color.BOLD + color.BLUE + "obs : " + color.END, o_next)

            # store experience in buffer
            # self.buffer.store(o, o_next, action, r, d)
            # print("Uniques in buffer : ",len(np.unique(self.buffer.obs1_buffer)))
            #
            # episode_reward += r
            # episode_len +=1

            #d = False if episode_len == self.args.max_ep_len else d

            # update observation
            #o=o_next

            while not done:  #and (episode < 300):
                #self.env.render()
                #print ("episode length : {} d : {}".format(episode_len, d))
                #print("---------------------------------------------------------------------")
                #print("[*] Episode {} starts".format(episode))
                self.actor.train()
                self.critic.train()
                self.actor_target.train()
                self.critic_target.train()

                action = self.generate_action_with_noise(
                    o, self.args.noise_scale)
                #print("[*] Action : {} ".format(action))

                obs_nextt, reward, done, _ = self.env.step(action)

                o_next = obs_nextt['observation']
                episode_reward += reward
                #print(color.BOLD + color.BLUE + "obs : " + color.END, o_next)

                # store experience in buffer
                self.buffer.store(o, o_next, action, reward, done)
                #print("Size of buffer : ",self.buffer.size)

                #for _ in range(episode_len):
                # batch size 32 or 100?
                batch = self.buffer.sample_batch()
                (obs, obs_next, actions, rewards,
                 dones) = (torch.Tensor(batch['obs1']),
                           torch.Tensor(batch['obs2']),
                           torch.Tensor(batch['action']),
                           torch.Tensor(batch['reward']),
                           torch.Tensor(batch['done']))

                if self.args.cuda:
                    obs = obs.cuda()
                    obs_next = obs_next.cuda()
                    actions = actions.cuda()
                    rewards = rewards.cuda()

                # deactivating autograd engine to save memory
                #with torch.no_grad():

                action_next = self.actor_target(obs_next)
                #print("[*] Action : {} Action_next : {}".format(action, action_next))

                q_next = self.critic_target(obs_next, action_next).detach()

                bellman_backup = (rewards + self.args.gamma *
                                  (1 - dones) * q_next).detach()
                q_predicted = self.critic(obs, actions)

                #print("[*] q_pred : {} q_targ : {}".format(q_predicted, bellman_backup))
                # calculating critic losses and updating it
                critic_loss = F.mse_loss(q_predicted, bellman_backup)

                # print(color.BLUE + "Critic loss: {}".format(critic_loss) + color.END)
                # print(color.BLUE + "Actor loss: {}".format(actor_loss) + color.END)

                # updating critic (Q) network
                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                self.critic_optimizer.step()

                action = self.actor(obs)
                actor_loss = -self.critic(obs, action).mean()

                # updating actor (policy) network
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # updating target networks with polyak averaging
                for main_params, target_params in zip(
                        self.actor.parameters(),
                        self.actor_target.parameters()):
                    target_params.data.copy_(
                        self.args.polyak * target_params.data +
                        (1 - self.args.polyak) * main_params.data)

                for main_params, target_params in zip(
                        self.critic.parameters(),
                        self.critic_target.parameters()):
                    target_params.data.copy_(
                        self.args.polyak * target_params.data +
                        (1 - self.args.polyak) * main_params.data)

                #obs_reset, r, d, ep_ret, ep_len = self.env.reset(), 0, False, 0, 0
                o = o_next
                episode += 1

            print("[*] Number of episodes : {} Reward : {}".format(
                episode, episode_reward))
            if done:
                print(color.BOLD + color.BLUE + "ach goal : " + color.END,
                      o_next['achieved_goal'])
                print(color.BOLD + color.BLUE + "des goal : " + color.END,
                      o_next['desired_goal'])
            original = sys.stdout
            with open('helloworld.txt', 'a') as filehandle:
                sys.stdout = filehandle
                print(color.BOLD + color.BLUE + "Critic loss : " + color.END,
                      critic_loss)
                print(color.BOLD + color.BLUE + "Actor loss : " + color.END,
                      actor_loss)
                print(color.BOLD + color.BLUE + "Achieved goal : " + color.END,
                      obs_nextt['achieved_goal'])
                print(color.BOLD + color.BLUE + "Desired goal : " + color.END,
                      obs_nextt['desired_goal'])
                print("[*] Number of episodes : {} Reward : {}".format(
                    episode, episode_reward))
                print("[*] End of epoch ", epoch)
                print(
                    "---------------------------------------------------------------------"
                )
                print()
            sys.stdout = original
            # # Save model
            torch.save(
                self.actor.state_dict(),
                os.path.join(self.args.model_dir,
                             os.path.join(self.args.env_name, "actor.pth")))
            torch.save(
                self.critic.state_dict(),
                os.path.join(self.args.model_dir,
                             os.path.join(self.args.env_name, "critic.pth")))
            torch.save(
                self.actor_target.state_dict(),
                os.path.join(
                    self.args.model_dir,
                    os.path.join(self.args.env_name, "actor_target.pth")))
            torch.save(
                self.critic_target.state_dict(),
                os.path.join(
                    self.args.model_dir,
                    os.path.join(self.args.env_name, "critic_target.pth")))
            # save buffer
            torch.save(
                self.buffer,
                os.path.join(self.args.model_dir,
                             os.path.join(self.args.env_name, "buffer.pth")))
Example #35
def run_dqn(env, save=False):
    """
    Runs the DQN algorithm.
    :param env: Gym environment
    :param save: Set True to save pytorch model of learned weights.
    :return: The learned pytorch model
    """
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor

    EPISODES = 1
    #EPISODES = 2000
    BATCH_SIZE = 1000
    GAMMA = 0.9
    HIDDEN_LAYER_NEURONS = 300
    LEARNING_RATE = 0.0001
    ACTION_SPACE = 10  # 49
    EPS_START = 1
    EPS_END = 0.01
    #EPS_END = 0.05
    EXPLORATION_STEPS = 1e5

    INITIAL_REPLAY = 100
    REPLAY_SIZE = 1e6
    TARGET_UPDATE = 7000
    global EPSILON
    EPSILON = EPS_START
    EPSILON_STEP = (EPS_START - EPS_END) / EXPLORATION_STEPS

    # define a new discrete action space
    env.action_space = ActionDisc(env.action_space.high, env.action_space.low,
                                  ACTION_SPACE)

    # create the replay buffer and the neural networks
    memory = MemoryDQN(REPLAY_SIZE)
    model = DQN(HIDDEN_LAYER_NEURONS, ACTION_SPACE,
                env.observation_space.shape[0])
    target = DQN(HIDDEN_LAYER_NEURONS, ACTION_SPACE,
                 env.observation_space.shape[0])

    target.load_state_dict(model.state_dict())

    #target.l1.weight = model.l1.weight
    #target.l2.weight = model.l2.weight
    #target.l3.weight = model.l3.weight

    optimizer = optim.Adam(model.parameters(), LEARNING_RATE)

    cum_reward = []

    def select_action(state_pred):
        """
        Epsilon greedy policy
        :param state_pred: Current state
        :return: Action
        """
        sample = random.random()
        global EPSILON
        epsilon_old = EPSILON
        if EPSILON > EPS_END and memory.size_mem() > INITIAL_REPLAY:
            EPSILON -= EPSILON_STEP
        #epsilon_old = 0.05
        if sample > epsilon_old and memory.size_mem() > INITIAL_REPLAY:
            with torch.no_grad():
                # predict the actions to the given states
                pred_actions = model(state_pred)
                # find the action with the best q-value
                max_action = pred_actions.max(1)[1]
                # return the best action as tensor
                return torch.tensor(
                    np.array([[env.action_space.space[max_action]]]))
        # exploration
        else:
            # return a random action of the action space
            return torch.tensor(np.array([[env.action_space.sample()]]))

    total_steps = 1
    # Start time
    start = datetime.datetime.now()
    for epi in range(EPISODES):
        cum_reward.append(0)
        state = env.reset()
        step = 0
        total_loss = 0

        while True:
            action = select_action(state)

            state_follows, reward, done, info = env.step(action.numpy()[0])

            cum_reward[epi] += reward

            memory.add_observation(state, action, reward, state_follows)

            # if epi == EPISODES - 1:
            #    env.render()

            # training
            if memory.size_mem() > BATCH_SIZE:
                #if memory.size_mem() > INITIAL_REPLAY:

                states, actions, rewards, next_states = \
                    memory.random_batch(BATCH_SIZE)

                # find the index to the given action
                actions = env.action_space.contains(actions)
                # repeat it for the gather method of torch
                actions = np.array(actions).repeat(ACTION_SPACE) \
                    .reshape(BATCH_SIZE, ACTION_SPACE)
                # change it to a long tensor (instead of a float tensor)
                actions = LongTensor(actions)

                # for each q-value(for each state in the batch and for each action)
                # take the one from the chosen action

                current_q_values = model(states)[0].gather(dim=1,
                                                           index=actions)[:, 0]

                # neural net estimates the q-values for the next states
                # take the ones with the highest values
                #max_next_q_values = model(next_states)[0].detach().max(1)[0]
                max_next_q_values = target(next_states)[0].detach().max(1)[0]

                expected_q_values = rewards + (GAMMA * max_next_q_values)

                #loss = F.smooth_l1_loss(current_q_values, expected_q_values.type(FloatTensor))
                loss = F.mse_loss(current_q_values,
                                  expected_q_values.type(FloatTensor))
                total_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_steps += 1
                # copy the current model parameters into the target network
                if total_steps % TARGET_UPDATE == 0:
                    total_steps = 1
                    target.load_state_dict(model.state_dict())
                    #target.l1.weight = model.l1.weight
                    #target.l2.weight = model.l2.weight
                    # target.l1.weight = 0.001*model.l1.weight+(1-0.001)*target.l1.weight
                    # target.l2.weight = 0.001*model.l2.weight+(1-0.001)*target.l2.weight
                    #target.l3.weight = model.l3.weight

            state = state_follows
            step += 1

            if done:
                cum_reward[-1] = cum_reward[-1] / step
                break
            '''if step == 500:
                cum_reward[-1]=cum_reward[-1]/500.
                break'''
        # print("Episode:{} Steps:{} Cum.Reward:{} Loss/Step:{} Epsilon:{}"
        #      .format(epi, step, cum_reward[-1], total_loss/step, EPSILON))
    # End time
    end = datetime.datetime.now()
    # print("Learning took", (end-start))
    if save:
        torch.save(model, "model.pt")
    # plt.plot(cum_reward)
    # plt.show()
    return model
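The commented-out lines in the target-update branch above hint at a soft (Polyak) target update as an alternative to the periodic hard copy. A minimal sketch of that variant, assuming the same `model`/`target` pair and a hypothetical `TAU` rate mirroring the 0.001 factor in those comments:

import torch

TAU = 0.001  # hypothetical soft-update rate, taken from the 0.001 factor in the commented-out lines

def soft_update(target_net, online_net, tau=TAU):
    # theta_target <- tau * theta_online + (1 - tau) * theta_target, applied on every training step
    with torch.no_grad():
        for t_param, o_param in zip(target_net.parameters(), online_net.parameters()):
            t_param.mul_(1.0 - tau).add_(o_param, alpha=tau)

With a soft update of this kind, the `TARGET_UPDATE` counter would no longer be needed, since the blend happens at every optimizer step.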
Пример #36
0
	def forward(self, fm_s, fm_t):
		loss = F.mse_loss(fm_s, fm_t)

		return loss
Пример #37
0
 def ae_loss(self, reconstruction, target):
   mse = func.mse_loss(reconstruction, target)
   return mse
Пример #38
0
def L2_soft(outputs, targets):
    softmax_outputs = F.softmax(outputs, dim=1)
    softmax_targets = F.softmax(targets, dim=1)
    return F.mse_loss(softmax_outputs, softmax_targets)
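A minimal usage sketch for `L2_soft`, assuming hypothetical teacher and student logits of the same shape and the `F` import from the snippet above:

import torch

# hypothetical raw logits for a batch of 4 samples over 10 classes
student_logits = torch.randn(4, 10, requires_grad=True)
teacher_logits = torch.randn(4, 10)

# match the student's softened class distribution to the teacher's
distill_loss = L2_soft(student_logits, teacher_logits.detach())
distill_loss.backward()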
Пример #39
0
    def train(self, obs, next_obs, returns, rewards, masks, actions, values):
        """
        :param obs: [batch_size x height x width x channels] observations in NHWC
        :param next_obs: [batch_size x height x width x channels] one-step next states
        :param returns: [batch_size] n-step discounted returns with bootstrapped value
        :param rewards: [batch_size] 1-step rewards
        :param masks: [batch_size] boolean episode termination mask
        :param actions: [batch_size] actions taken
        :param values: [batch_size] predicted state values
        """

        # compute the sequences we need to get back reward predictions
        action_sequences, reward_sequences, sequence_mask = build_sequences(
            [torch.from_numpy(actions), torch.from_numpy(rewards)], self.nenvs, self.nsteps, self.tree_depth, return_mask=True)
        action_sequences = cudify(action_sequences.long().squeeze(-1))
        reward_sequences = make_variable(reward_sequences.squeeze(-1))
        sequence_mask = make_variable(sequence_mask.squeeze(-1))

        Q, V, tree_result = self.model(obs)

        actions = make_variable(torch.from_numpy(actions).long(), requires_grad=False)
        returns = make_variable(torch.from_numpy(returns), requires_grad=False)

        policy_loss, value_loss, reward_loss, state_loss, subtree_loss_np, policy_entropy = 0, 0, 0, 0, 0, 0
        if self.use_actor_critic:
            values = make_variable(torch.from_numpy(values), requires_grad=False)
            advantages = returns - values
            probs = F.softmax(Q, dim=-1)
            log_probs = F.log_softmax(Q, dim=-1)
            log_probs_taken = log_probs.gather(1, actions.unsqueeze(1)).squeeze()
            pg_loss = -torch.mean(log_probs_taken * advantages.squeeze())
            vf_loss = F.mse_loss(V, returns)
            entropy = -torch.mean(torch.sum(probs * log_probs, 1))
            loss = pg_loss + self.vf_coef * vf_loss - self.ent_coef * entropy

            policy_loss = pg_loss.data.cpu().numpy()
            value_loss = vf_loss.data.cpu().numpy()
            policy_entropy = entropy.data.cpu().numpy()
        else:
            Q_taken = Q.gather(1, actions.unsqueeze(1)).squeeze()
            bellman_loss = F.mse_loss(Q_taken, returns)
            loss = bellman_loss
            value_loss = bellman_loss.data.cpu().numpy()

        if self.use_reward_loss:
            r_taken = get_paths(tree_result["rewards"], action_sequences, self.batch_size, self.num_actions)
            rew_loss = F.mse_loss(torch.cat(r_taken, 1), reward_sequences, reduce=False)
            rew_loss = torch.sum(rew_loss * sequence_mask) / sequence_mask.sum()
            loss = loss + rew_loss * self.rew_loss_coef
            reward_loss = rew_loss.data.cpu().numpy()

        if self.use_st_loss:
            st_embeddings = tree_result["embeddings"][0]
            st_targets, st_mask = build_sequences([st_embeddings.data], self.nenvs, self.nsteps, self.tree_depth, return_mask=True, offset=1)
            st_targets = make_variable(st_targets.view(self.batch_size, -1))
            st_mask = make_variable(st_mask.view(self.batch_size, -1))

            st_taken = get_paths(tree_result["embeddings"][1:], action_sequences, self.batch_size, self.num_actions)

            st_taken_cat = torch.cat(st_taken, 1)

            st_loss = F.mse_loss(st_taken_cat, st_targets, reduce=False)
            st_loss = torch.sum(st_loss * st_mask) / st_mask.sum()

            state_loss = st_loss.data.cpu().numpy()
            loss = loss + st_loss * self.st_loss_coef

        if self.use_subtree_loss:
            subtree_taken = get_subtree(tree_result["values"], action_sequences, self.batch_size, self.num_actions)
            target_subtrees = tree_result["values"][0:-1]
            subtree_taken_clip = time_shift_tree(subtree_taken, self.nenvs, self.nsteps, -1)
            target_subtrees_clip = time_shift_tree(target_subtrees, self.nenvs, self.nsteps, 1)

            subtree_loss = [(s_taken - s_target).pow(2).mean() for (s_taken, s_target) in zip(subtree_taken_clip, target_subtrees_clip)]
            subtree_loss = sum(subtree_loss)
            subtree_loss_np = subtree_loss.data.cpu().numpy()

            loss = loss + subtree_loss * self.subtree_loss_coef

        self.scheduler.step()
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        return policy_loss, value_loss, reward_loss, state_loss, subtree_loss_np, policy_entropy, grad_norm
Пример #40
0
  def ae_loss(self, predictions, target):
    loss = func.mse_loss(predictions, target)

    self.writer.add_scalar("reconstruction loss", float(loss), self.step_id)

    return loss
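This `ae_loss` presumes the surrounding object exposes a `self.writer` and a `self.step_id`; a sketch of that context with `torch.utils.tensorboard` (the class name and `log_dir` are assumptions, not part of the original code):

import torch.nn.functional as func
from torch.utils.tensorboard import SummaryWriter

class AutoencoderTrainer:
    # minimal container providing the writer and step counter the method above relies on
    def __init__(self, log_dir="runs/ae"):
        self.writer = SummaryWriter(log_dir)
        self.step_id = 0

    def ae_loss(self, predictions, target):
        loss = func.mse_loss(predictions, target)
        self.writer.add_scalar("reconstruction loss", float(loss), self.step_id)
        return loss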
Пример #41
0
    def train(self, training_samples: TrainingDataPage) -> None:
        if self.minibatch == 0:
            # Assume that the tensors are the right shape after the first minibatch
            max_num_actions = training_samples.possible_next_actions_mask.shape[1]

            assert (
                training_samples.states.shape[0] == self.minibatch_size
            ), "Invalid shape: " + str(training_samples.states.shape)
            assert (
                training_samples.next_states.shape == training_samples.states.shape
            ), "Invalid shape: " + str(training_samples.next_states.shape)
            assert (
                training_samples.not_terminal.shape == training_samples.rewards.shape
            ), "Invalid shape: " + str(training_samples.not_terminal.shape)

            assert (
                training_samples.actions.shape[0] == self.minibatch_size
            ), "Invalid shape: " + str(training_samples.actions.shape)
            assert (
                training_samples.possible_next_actions_mask.shape[0]
                == self.minibatch_size
            ), "Invalid shape: " + str(
                training_samples.possible_next_actions_mask.shape
            )
            assert training_samples.actions.shape[1] == self.num_action_features, (
                "Invalid shape: "
                + str(training_samples.actions.shape[1])
                + " != "
                + str(self.num_action_features)
            )

            assert (
                training_samples.possible_next_actions_state_concat.shape[0]
                == self.minibatch_size * max_num_actions
            ), (
                "Invalid shape: "
                + str(training_samples.possible_next_actions_state_concat.shape)
                + " != "
                + str(self.minibatch_size * max_num_actions)
            )

        self.minibatch += 1

        states = training_samples.states.detach().requires_grad_(True)
        actions = training_samples.actions
        state_action_pairs = torch.cat((states, actions), dim=1)

        rewards = training_samples.rewards
        discount_tensor = torch.full(
            training_samples.time_diffs.shape, self.gamma
        ).type(self.dtype)
        not_done_mask = training_samples.not_terminal

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(training_samples.time_diffs)

        if self.maxq_learning:
            all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
                training_samples.possible_next_actions_state_concat
            )
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, _ = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                training_samples.possible_next_actions_mask,
            )
        else:
            # SARSA
            next_q_values, _ = self.get_detached_q_values(
                torch.cat(
                    (training_samples.next_states, training_samples.next_actions), dim=1
                )
            )

        assert next_q_values.shape == not_done_mask.shape, (
            "Invalid shapes: "
            + str(next_q_values.shape)
            + " != "
            + str(not_done_mask.shape)
        )
        filtered_max_q_vals = next_q_values * not_done_mask

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            assert discount_tensor.shape == filtered_max_q_vals.shape, (
                "Invalid shapes: "
                + str(discount_tensor.shape)
                + " != "
                + str(filtered_max_q_vals.shape)
            )
            target_q_values = rewards + (discount_tensor * filtered_max_q_vals)

        # Get Q-value of action taken
        q_values = self.q_network(state_action_pairs)
        all_action_scores = q_values.detach()
        self.model_values_on_logged_actions = q_values.detach()

        value_loss = self.q_network_loss(q_values, target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(state_action_pairs)
        reward_loss = F.mse_loss(reward_estimates, rewards)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            model_values_on_logged_actions=all_action_scores,
        )
Пример #42
0
def patience_loss(teacher_patience, student_patience, normalized_patience=False):
    if normalized_patience:
        teacher_patience = F.normalize(teacher_patience, p=2, dim=2)
        student_patience = F.normalize(student_patience, p=2, dim=2)
    return F.mse_loss(teacher_patience.float(), student_patience.float()).half()
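A minimal usage sketch for `patience_loss`, assuming hypothetical stacked intermediate states of shape `[batch, layers, hidden]`; the trailing `.half()` in the function presumes mixed-precision training:

import torch

# hypothetical intermediate representations: batch of 8, 4 kept layers, hidden size 16
teacher_hidden = torch.randn(8, 4, 16)
student_hidden = torch.randn(8, 4, 16)

pkd_loss = patience_loss(teacher_hidden, student_hidden, normalized_patience=True)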
Пример #43
0
 def compute_loss(self, input, target):
     loss = F.mse_loss(input, target, size_average=False)
     return loss / (2 * self.batch_size)
Пример #44
0
 def forward(self, input, target):
     assert not target.requires_grad

     input = clampalpha(input, self.lim1, self.lim2)
     target = clampalpha(target, self.lim1, self.lim2)
     return F.mse_loss(input, target, size_average=self.size_average, reduce=self.reduce)
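The `size_average`/`reduce` pair used here is PyTorch's legacy reduction interface; current releases express the same three behaviours with a single `reduction` string. A sketch of the mapping (independent of `clampalpha`):

import torch.nn.functional as F

def legacy_mse(input, target, size_average=True, reduce=True):
    # reduce=False                    -> per-element losses -> reduction='none'
    # reduce=True, size_average=True  -> mean over elements -> reduction='mean'
    # reduce=True, size_average=False -> sum over elements  -> reduction='sum'
    if not reduce:
        reduction = 'none'
    else:
        reduction = 'mean' if size_average else 'sum'
    return F.mse_loss(input, target, reduction=reduction)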
Пример #45
0
def test_abp_search_exhaustive_global_dynamics(noisy,
                                               clean,
                                               burst_indices,
                                               PS,
                                               NH,
                                               K=-1,
                                               nh_grids=None):

    # -- init vars --
    R, B, N = noisy.shape[:3]
    FMAX = np.finfo(np.float64).max
    REF_NH = get_ref_nh(NH)
    print(f"REF_NH: {REF_NH}")
    BI = burst_indices.shape[0]
    ref_patch = noisy[:, :, [N // 2], [REF_NH], :, :, :]

    # -- create clean testing image --
    H = int(np.sqrt(clean.shape[0]))
    clean_img = rearrange(clean[..., N // 2, REF_NH, :, PS // 2, PS // 2],
                          '(h w) b c -> b c h w',
                          h=H)
    clean_img = repeat(clean_img, 'b c h w -> tile b c h w', tile=N)

    # -- create search grids --
    if nh_grids is None: nh_grids = create_nh_grids(BI, NH)
    n_grids = create_n_grids(BI)
    print(f"NH_GRIDS {len(nh_grids)} | N_GRIDS {len(n_grids)}")

    # -- randomly initialize grids --
    # np.random.shuffle(nh_grids)
    # np.random.shuffle(n_grids)

    # -- init loop vars --
    psnrs = np.zeros((len(nh_grids), BI))
    scores = np.zeros(len(nh_grids))
    scores_old = np.zeros(len(nh_grids))
    best_score, best_select = FMAX, None

    # -- remove boundary --
    aug_burst_indices = insert_n_middle(burst_indices, N)
    aug_burst_indices = torch.LongTensor(aug_burst_indices)
    subR = torch.arange(H * H // 3 * 2) + NH * H
    search = noisy[subR]
    ref_patch = ref_patch[subR]

    # -- coordinate descent --
    for nh_index, nh_grid in enumerate(nh_grids):
        # -- compute score --
        grid_patches = search[:, :, burst_indices, nh_grid, :, :, :]
        grid_patches = torch.cat([ref_patch, grid_patches], dim=2)
        score, score_old, count = 0, 0, 0
        for (nset0, nset1) in n_grids[:100]:
            denoised0 = torch.mean(grid_patches[:, :, nset0], dim=2)
            denoised1 = torch.mean(grid_patches[:, :, nset1], dim=2)
            score_old += F.mse_loss(denoised0, denoised1).item()

            # -- neurips 2019 --
            rep0 = repeat(denoised0,
                          'r b c p1 p2 -> r b tile c p1 p2',
                          tile=len(nset0))
            rep01 = repeat(denoised0,
                           'r b c p1 p2 -> r b tile c p1 p2',
                           tile=len(nset1))
            res0 = grid_patches[:, :, nset0] - rep0

            rep1 = repeat(denoised1,
                          'r b c p1 p2 -> r b tile c p1 p2',
                          tile=len(nset1))
            rep10 = repeat(denoised1,
                           'r b c p1 p2 -> r b tile c p1 p2',
                           tile=len(nset0))
            res1 = grid_patches[:, :, nset1] - rep1

            n0, n1 = len(nset0), len(nset1)
            xterms0, xterms1 = np.mgrid[:n0, :n1]
            xterms0, xterms1 = xterms0.ravel(), xterms1.ravel()
            # print(xterms0.shape,xterms1.shape,res0.shape,xterms0.max(),xterms1.max())
            score += F.mse_loss(res0[:, :, xterms0], res1[:, :,
                                                          xterms1]).item()

            # xterms01 = res0 + rep10
            # xterms10 = res1 + rep01

            # score += F.mse_loss(xterms01,xterms10).item()
            # score += F.mse_loss(xterms01,grid_patches[:,:,nset0]).item()
            # score += F.mse_loss(xterms10,grid_patches[:,:,nset1]).item()

            count += 1
        score /= count

        # -- store best score --
        if score < best_score:
            best_score = score
            best_select = nh_grid

        # -- add score to results --
        scores[nh_index] = score
        scores_old[nh_index] = score_old

        # -- compute and store psnrs --
        pgrid = insert_nh_middle(nh_grid, NH, BI)[None, ]
        bgrid = aug_burst_indices
        nh_grid = nh_grid[None, ]
        rec_img = aligned_burst_image_from_indices_global_dynamics(
            clean, burst_indices, nh_grid)  #bgrid,pgrid)
        nh_psnrs = images_to_psnrs(rec_img, clean_img[burst_indices])
        psnrs[nh_index, :] = nh_psnrs

    score_idx = np.argmin(scores)
    print(f"Best Score [{scores[score_idx]}] PSNRS @ [{score_idx}]:",
          psnrs[score_idx])

    psnr_idx = np.argmax(np.mean(psnrs, 1))
    print(f"Best PSNR @ [{psnr_idx}]", psnrs[psnr_idx])
    # print(scores[score_idx] - scores[psnr_idx])

    old_score_idx = np.argmin(scores_old)
    print(
        f"Best OLD Score [{scores_old[old_score_idx]}] PSNRS @ [{old_score_idx}]:",
        psnrs[old_score_idx])
    print(f"Current Score @ OLD Score [{scores[old_score_idx]}]")
    print(
        f"[Old score idx v.s. Current score idx v.s. Best PSNR] {old_score_idx} v.s. {score_idx} v.s. {psnr_idx}"
    )

    #
    #  Recording Score Info
    #

    # -- save score info --
    scores /= np.sum(scores)
    score_fn = f"scores_{NH}_{N}_{len(nh_grids)}_{len(n_grids)}"
    txt_fn = Path(f"output/abps/{score_fn}.txt")
    np.savetxt(txt_fn, scores)

    # -- plot score --
    plot_fn = Path(f"output/abps/{score_fn}.png")
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.plot(np.arange(scores.shape[0]), scores, '-+')
    ax.axvline(x=psnr_idx, color="r")
    ax.axvline(x=score_idx, color="k")
    plt.savefig(plot_fn, dpi=300)
    plt.close("all")

    #
    #  Recording PSNR Info
    #

    # -- save score info --
    psnr_fn = f"psnrs_{NH}_{N}_{len(nh_grids)}_{len(n_grids)}"
    txt_fn = Path(f"output/abps/{psnr_fn}.txt")
    np.savetxt(txt_fn, psnrs)

    # -- plot psnr --
    plot_fn = Path(f"output/abps/{psnr_fn}.png")
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.plot(np.arange(psnrs.shape[0]), psnrs, '-+')
    ax.axvline(x=psnr_idx, color="r")
    ax.axvline(x=score_idx, color="k")
    plt.savefig(plot_fn, dpi=300)
    plt.close("all")

    print(f"Wrote {score_fn} and {psnr_fn}")

    if K == -1:
        return best_score, best_select
    else:
        search_indices_topK = np.argsort(scores)[:K]
        scores_topK = scores[search_indices_topK]
        nh_grids_topK = nh_grids[search_indices_topK]
        return scores_topK, nh_grids_topK
Пример #46
0
def test(
    model_G: nn.Module,
    model_D: nn.Module,
    encoder: nn.Module,
    test_loader: DataLoader,
    use_cuda: bool = False,
):
    device = "cuda" if th.cuda.is_available() and use_cuda else "cpu"

    model_G.to(device)
    model_D.to(device)
    model_G.eval()
    model_D.eval()
    encoder.eval()

    batches = len(test_loader)
    D_loss_sum = 0
    G_loss_sum = 0
    loss_sum = 0
    D_real_prob = 0
    D_fake_prob = 0
    reconstruction_loss = 0

    with th.set_grad_enabled(False):
        for data, egg_data in test_loader:
            data, egg_data = data.to(device), egg_data.to(device)
            data.requires_grad, egg_data.requires_grad = False, False

            batch_size = data.shape[0]
            ones_label = th.ones(batch_size, 1).to(device)
            zeros_label = th.zeros(batch_size, 1).to(device)

            # Test model_D
            true = encoder(egg_data)
            _, embeddings_ = model_G(data)
            D_real = model_D(true)
            D_fake = model_D(embeddings_)

            D_loss_real = F.binary_cross_entropy(D_real, ones_label)
            D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
            D_loss = D_loss_real + D_loss_fake

            D_real_prob += D_real.mean().item()
            D_fake_prob += D_fake.mean().item()

            # Test model_G
            reconstructions, embeddings_ = model_G(data)
            D_fake = model_D(embeddings_)

            # loss_reconstruction =             (egg_data * reconstructions).sum(dim=1) / (egg_data.norm(dim=1) * reconstructions.norm(dim=1))
            # loss_reconstruction = th.acos(
            #     loss_reconstruction
            # ) * 180 / np.pi +
            # F.mse_loss(
            #     egg_data, reconstructions, reduction="none"
            # ).sum(dim=1)
            # loss_reconstruction = loss_reconstruction.mean()

            loss_reconstruction = F.mse_loss(egg_data, reconstructions)

            G_loss = F.binary_cross_entropy(D_fake, ones_label)

            net_loss = loss_reconstruction + G_loss

            loss_sum += net_loss.item()
            reconstruction_loss += loss_reconstruction.item()

            D_loss_sum += D_loss
            G_loss_sum += G_loss

    del D_loss, G_loss
    th.cuda.empty_cache()

    print(
        "\nTest set: Test loss {:4.4} D_loss {:4.4} G_loss {:4.4} reconstruction loss {:4.4} Real D prob. {:4.4} Fake D prob. {:4.4}\n"
        .format(
            loss_sum / batches,
            D_loss_sum / batches,
            G_loss_sum / batches,
            reconstruction_loss / batches,
            D_real_prob / batches,
            D_fake_prob / batches,
        ))
Пример #47
0
 def forward(self, input):
     self.loss = F.mse_loss(input, self.target)
     return input
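This `forward` is the "transparent loss layer" pattern from neural style transfer: the module records an MSE against a fixed target and passes its input through unchanged so it can sit inside an `nn.Sequential`. A minimal sketch of the surrounding module, assuming the target feature map is available at construction time:

import torch.nn as nn
import torch.nn.functional as F

class ContentLoss(nn.Module):
    def __init__(self, target):
        super().__init__()
        # detach so the target is treated as a constant rather than part of the graph
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input  # pass the activation through unchanged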
Пример #48
0
def train(
    model_G: nn.Module,
    model_D: nn.Module,
    encoder: nn.Module,
    optimizer_G: optim.Optimizer,
    optimizer_R: optim.Optimizer,
    optimizer_D: optim.Optimizer,
    train_data: DataLoader,
    use_cuda: bool = True,
    scheduler_G=None,
    scheduler_R=None,
    scheduler_D=None,
):
    device = "cuda" if th.cuda.is_available() and use_cuda else "cpu"

    model_G.train()
    model_D.train()
    encoder.eval()

    batches = len(train_data)
    D_loss_sum = 0
    D_real_prob = 0
    D_fake_prob = 0
    G_loss_sum = 0
    loss_sum = 0
    reconstruction_loss = 0

    model_G.to(device)
    model_D.to(device)
    encoder.to(device)

    for data, egg_data in train_data:
        if scheduler_G is not None:
            scheduler_G.step()
            scheduler_R.step()
            scheduler_D.step()

        data, egg_data = data.to(device), egg_data.to(device)

        optimizer_G.zero_grad()
        optimizer_D.zero_grad()
        optimizer_R.zero_grad()

        batch_size = data.shape[0]
        ones_label = th.ones(batch_size, 1).to(device)
        zeros_label = th.zeros(batch_size, 1).to(device)

        # Optimize model_D
        true = encoder(egg_data)
        _, embeddings_ = model_G(data)
        D_real = model_D(true)
        D_fake = model_D(embeddings_)

        D_loss_real = F.binary_cross_entropy(D_real, ones_label)
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
        D_loss = D_loss_real + D_loss_fake

        D_real_prob += D_real.mean().item()
        D_fake_prob += D_fake.mean().item()

        # for i in model_G.parameters():
        #     i.requires_grad = False
        # for i in model_D.parameters():
        #     i.requires_grad = True
        D_loss.backward()
        optimizer_D.step()
        optimizer_G.zero_grad()
        optimizer_D.zero_grad()
        optimizer_R.zero_grad()

        # Optimize model_G
        # for i in model_G.parameters():
        #     i.requires_grad = True
        # for i in model_D.parameters():
        #     i.requires_grad = False
        _, embeddings_ = model_G(data)
        D_fake = model_D(embeddings_)

        G_loss = F.binary_cross_entropy(D_fake, ones_label)
        G_loss.backward()
        optimizer_G.step()
        optimizer_G.zero_grad()
        optimizer_D.zero_grad()
        optimizer_R.zero_grad()

        reconstructions, _ = model_G(data)
        # loss_reconstruction = (egg_data * reconstructions).sum(dim=1) / (
        #     egg_data.norm(dim=1) * reconstructions.norm(dim=1)
        # )
        # loss_reconstruction = th.acos(loss_reconstruction) * 180 / np.pi
        # loss_reconstruction = loss_reconstruction.mean()

        loss_reconstruction = F.mse_loss(egg_data, reconstructions)
        loss_reconstruction.backward()
        optimizer_R.step()
        optimizer_G.zero_grad()
        optimizer_D.zero_grad()
        optimizer_R.zero_grad()

        net_loss = loss_reconstruction + G_loss

        loss_sum += net_loss.item()
        reconstruction_loss += loss_reconstruction.item()

        D_loss_sum += D_loss
        G_loss_sum += G_loss

    del D_loss, G_loss
    th.cuda.empty_cache()

    return (
        loss_sum / batches,
        D_loss_sum / batches,
        G_loss_sum / batches,
        reconstruction_loss / batches,
        D_real_prob / batches,
        D_fake_prob / batches,
    )
Пример #49
0
 def forward(self, input):
     self.loss = F.mse_loss(input, self.target)
     return input
Пример #50
0
def soc_adaptation_iter(modnet,
                        backup_modnet,
                        optimizer,
                        image,
                        soc_semantic_scale=100.0,
                        soc_detail_scale=1.0):
    """ Self-Supervised sub-objective consistency (SOC) adaptation iteration of MODNet
    This function fine-tunes MODNet for one iteration on an unlabeled dataset.
    Note that SOC can only fine-tune a converged MODNet, i.e., a MODNet that has
    already been trained on a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        backup_modnet (torch.nn.Module): backup of the trained MODNet
        optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC 
        image (torch.autograd.Variable): input RGB image
                                         its pixel values should be normalized
        soc_semantic_scale (float): scale of the SOC semantic loss 
                                    NOTE: please adjust according to your dataset
        soc_detail_scale (float): scale of the SOC detail loss
                                  NOTE: please adjust according to your dataset
    
    Returns:
        soc_semantic_loss (torch.Tensor): loss of the semantic SOC
        soc_detail_loss (torch.Tensor): loss of the detail SOC

    Example:
        import copy
        import torch
        from src.models.modnet import MODNet
        from src.trainer import soc_adaptation_iter

        bs = 1          # batch size
        lr = 0.00001    # learn rate
        epochs = 10     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        modnet = LOAD_TRAINED_CKPT()    # NOTE: please finish this function

        optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99))
        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            backup_modnet = copy.deepcopy(modnet)
            for idx, (image) in enumerate(dataloader):
                soc_semantic_loss, soc_detail_loss = \
                    soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
    """

    global blurer

    # set the backup model to eval mode
    backup_modnet.eval()

    # set the main model to train mode and freeze its norm layers
    modnet.train()
    modnet.module.freeze_norm()

    # clear the optimizer
    optimizer.zero_grad()

    # forward the main model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # forward the backup model
    with torch.no_grad():
        _, pred_backup_detail, pred_backup_matte = backup_modnet(image, False)

    # calculate the boundary mask from `pred_matte` and `pred_semantic`
    pred_matte_fg = (pred_matte.detach() > 0.1).float()
    pred_semantic_fg = (pred_semantic.detach() > 0.1).float()
    pred_semantic_fg = F.interpolate(pred_semantic_fg,
                                     scale_factor=16,
                                     mode='bilinear')
    pred_fg = pred_matte_fg * pred_semantic_fg

    n, c, h, w = pred_matte.shape
    np_pred_fg = pred_fg.data.cpu().numpy()
    np_boundaries = np.zeros([n, c, h, w])
    for sdx in range(0, n):
        sample_np_boundaries = np_boundaries[sdx, 0, ...]
        sample_np_pred_fg = np_pred_fg[sdx, 0, ...]

        side = int((h + w) / 2 * 0.05)
        dilated = grey_dilation(sample_np_pred_fg, size=(side, side))
        eroded = grey_erosion(sample_np_pred_fg, size=(side, side))

        sample_np_boundaries[np.where(dilated - eroded != 0)] = 1
        np_boundaries[sdx, 0, ...] = sample_np_boundaries

    boundaries = torch.tensor(np_boundaries).float().cuda()

    # sub-objectives consistency between `pred_semantic` and `pred_matte`
    # generate pseudo ground truth for `pred_semantic`
    downsampled_pred_matte = blurer(
        F.interpolate(pred_matte, scale_factor=1 / 16, mode='bilinear'))
    pseudo_gt_semantic = downsampled_pred_matte.detach()
    pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic >
                                               0.01).float()

    # generate pseudo ground truth for `pred_matte`
    pseudo_gt_matte = pred_semantic.detach()
    pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float()

    # calculate the SOC semantic loss
    soc_semantic_loss = F.mse_loss(pred_semantic,
                                   pseudo_gt_semantic) + F.mse_loss(
                                       downsampled_pred_matte, pseudo_gt_matte)
    soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss)

    # NOTE: using the formulas in our paper to calculate the following losses has similar results
    # sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
    backup_detail_loss = boundaries * F.l1_loss(
        pred_detail, pred_backup_detail, reduction='none')
    backup_detail_loss = torch.sum(backup_detail_loss, dim=(
        1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
    backup_detail_loss = torch.mean(backup_detail_loss)

    # sub-objectives consistency between `pred_matte` and `pred_backup_matte` (on boundaries only)
    backup_matte_loss = boundaries * F.l1_loss(
        pred_matte, pred_backup_matte, reduction='none')
    backup_matte_loss = torch.sum(backup_matte_loss, dim=(
        1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
    backup_matte_loss = torch.mean(backup_matte_loss)

    soc_detail_loss = soc_detail_scale * (backup_detail_loss +
                                          backup_matte_loss)

    # calculate the final loss, backward the loss, and update the model
    loss = soc_semantic_loss + soc_detail_loss

    loss.backward()
    optimizer.step()

    return soc_semantic_loss, soc_detail_loss
Пример #51
0
def train(**kwargs):
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    vis = utils.Visualizer(opt.env)

    # data loading
    transforms = tv.transforms.Compose([
        tv.transforms.Scale(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # image transformation network
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(t.load(opt.model_path, map_location=lambda _s, _: _s))

    # loss network (VGG16)
    vgg = Vgg16().eval()

    # optimizer
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # load the style image
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style[0] * 0.225 + 0.45).clamp(min=0, max=1))

    if opt.use_gpu:
        transformer.cuda()
        style = style.cuda()
        vgg.cuda()

    # Gram matrices of the style image
    style_v = Variable(style, volatile=True)
    features_style = vgg(style_v)
    gram_style = [Variable(utils.gram_matrix(y.data)) for y in features_style]

    # loss meters
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # training
            optimizer.zero_grad()
            if opt.use_gpu:
                x = x.cuda()
            x = Variable(x)
            y = transformer(x)
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(features_y.relu2_2, features_x.relu2_2)

            # style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # smoothed loss statistics
            content_meter.add(content_loss.data[0])
            style_meter.add(style_loss.data[0])

            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # visualization
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # x and y were normalized by utils.normalize_batch, so de-normalize them before display
                vis.img('output', (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))

        # save the visdom state and the model checkpoint
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
Пример #52
0
def supervised_training_iter(modnet,
                             optimizer,
                             image,
                             trimap,
                             gt_matte,
                             semantic_scale=10.0,
                             detail_scale=10.0,
                             matte_scale=1.0):
    """ Supervised training iteration of MODNet
    This function trains MODNet for one iteration on a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        optimizer (torch.optim.Optimizer): optimizer for supervised training 
        image (torch.autograd.Variable): input RGB image
                                         its pixel values should be normalized
        trimap (torch.autograd.Variable): trimap used to calculate the losses
                                          its pixel values can be 0, 0.5, or 1
                                          (foreground=1, background=0, unknown=0.5)
        gt_matte (torch.autograd.Variable): ground truth alpha matte
                                            its pixel values are between [0, 1]
        semantic_scale (float): scale of the semantic loss
                                NOTE: please adjust according to your dataset
        detail_scale (float): scale of the detail loss
                              NOTE: please adjust according to your dataset
        matte_scale (float): scale of the matte loss
                             NOTE: please adjust according to your dataset
    
    Returns:
        semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
        detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
        matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]

    Example:
        import torch
        from src.models.modnet import MODNet
        from src.trainer import supervised_training_iter

        bs = 16         # batch size
        lr = 0.01       # learn rate
        epochs = 40     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            for idx, (image, trimap, gt_matte) in enumerate(dataloader):
                semantic_loss, detail_loss, matte_loss = \
                    supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
            lr_scheduler.step()
    """

    global blurer

    # set the model to train mode and clear the optimizer
    modnet.train()
    optimizer.zero_grad()

    # forward the model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # calculate the boundary mask from the trimap
    boundaries = (trimap < 0.5) + (trimap > 0.5)

    # calculate the semantic loss
    gt_semantic = F.interpolate(gt_matte, scale_factor=1 / 16, mode='bilinear')
    gt_semantic = blurer(gt_semantic)
    semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
    semantic_loss = semantic_scale * semantic_loss

    # calculate the detail loss
    pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
    gt_detail = torch.where(boundaries, trimap, gt_matte)
    detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
    detail_loss = detail_scale * detail_loss

    # calculate the matte loss
    pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
    matte_l1_loss = F.l1_loss(
        pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
    matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
        + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
    matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
    matte_loss = matte_scale * matte_loss

    # calculate the final loss, backward the loss, and update the model
    loss = semantic_loss + detail_loss + matte_loss
    loss.backward()
    optimizer.step()

    # for test
    return semantic_loss, detail_loss, matte_loss
Пример #53
0
            rest = int(np.prod(sequence_batch.size()[2:]))
            flat_sequence_batch = sequence_batch.view(batch_size * sequence_len, rest)
            # break up this potentially large batch into nicer small ones for gsn
            for batch_idx in range(int(flat_sequence_batch.size()[0] / 32)):
                x = flat_sequence_batch[batch_idx * 32:(batch_idx + 1) * 32]
                # train the gsn!
                gsn_optimizer.zero_grad()
                regression_optimizer.zero_grad()
                recons, _, _ = model.gsn(x)
                losses = [F.binary_cross_entropy(input=recon, target=x) for recon in recons]
                loss = sum(losses)
                loss.backward()
                torch.nn.utils.clip_grad_norm(model.parameters(), .25)
                gsn_optimizer.step()
                gsn_train_losses.append(losses[-1].data.cpu().numpy()[0])
                accuracies = [F.mse_loss(input=recon, target=x) for recon in recons]
                gsn_train_accuracies.append(np.mean([acc.data.cpu().numpy() for acc in accuracies]))

        print("GSN Train Loss", np.mean(gsn_train_losses))
        print("GSN Train Accuracy", np.mean(gsn_train_accuracies))
        print("GSN Train time", make_time_units_string(time.time() - gsn_start_time))

        ####
        # train the regression step
        ####
        regression_train_losses = []
        regression_train_accuracies = []
        regression_train_accuracies2 = []
        regression_start = time.time()
        for batch_idx, sequence_batch in enumerate(train_loader):
            sequence_batch = Variable(sequence_batch, requires_grad=False)
Пример #54
0
    def train_step_gen(self, training_batch: rlt.PolicyNetworkInput,
                       batch_idx: int):
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        assert isinstance(training_batch, rlt.PolicyNetworkInput)

        state = training_batch.state
        action = training_batch.action
        next_state = training_batch.next_state
        reward = training_batch.reward
        not_terminal = training_batch.not_terminal

        # Generate target = r + y * min (Q1(s',pi(s')), Q2(s',pi(s')))
        with torch.no_grad():
            next_actor = self.actor_network_target(next_state).action
            noise = torch.randn_like(next_actor) * self.noise_variance
            next_actor = (next_actor +
                          noise.clamp(*self.noise_clip_range)).clamp(
                              *CONTINUOUS_TRAINING_ACTION_RANGE)
            next_state_actor = (next_state, rlt.FeatureData(next_actor))
            next_q_value = self.q1_network_target(*next_state_actor)

            if self.q2_network is not None:
                next_q_value = torch.min(
                    next_q_value, self.q2_network_target(*next_state_actor))

            target_q_value = reward + self.gamma * next_q_value * not_terminal.float()

        # Optimize Q1 and Q2
        q1_value = self.q1_network(state, action)
        q1_loss = F.mse_loss(q1_value, target_q_value)
        if batch_idx % self.trainer.log_every_n_steps == 0:
            self.reporter.log(
                q1_loss=q1_loss,
                q1_value=q1_value,
                next_q_value=next_q_value,
                target_q_value=target_q_value,
            )
        self.log("td_loss", q1_loss, prog_bar=True)
        yield q1_loss

        if self.q2_network:
            q2_value = self.q2_network(state, action)
            q2_loss = F.mse_loss(q2_value, target_q_value)
            if batch_idx % self.trainer.log_every_n_steps == 0:
                self.reporter.log(
                    q2_loss=q2_loss,
                    q2_value=q2_value,
                )
            yield q2_loss

        # Only update actor and target networks after a fixed number of Q updates
        if batch_idx % self.delayed_policy_update == 0:
            actor_action = self.actor_network(state).action
            actor_q1_value = self.q1_network(state,
                                             rlt.FeatureData(actor_action))
            actor_loss = -(actor_q1_value.mean())
            if batch_idx % self.trainer.log_every_n_steps == 0:
                self.reporter.log(
                    actor_loss=actor_loss,
                    actor_q1_value=actor_q1_value,
                )
            yield actor_loss

            # Use the soft update rule to update the target networks
            result = self.soft_update_result()
            yield result

        else:
            # Yielding None prevents the actor and target networks from updating
            yield None
            yield None
Пример #55
0
        p.data.zero_()

    optimizer = optim.Adam(net.parameters(), lr=1e-2)
    iter_idx = 0

    while True:
        iter_idx += 1
        sum_loss = 0.0
        for v in TRAIN_DATA:
            tgt_net.sync()
            x_v = Variable(torch.from_numpy(np.array([v], dtype=np.float32)))
            y_v = Variable(torch.from_numpy(np.array([get_y(v)], dtype=np.float32)))

            tgt_net.target_model.zero_grad()
            out_v = tgt_net.target_model(x_v)
            loss_v = F.mse_loss(out_v, y_v)
            loss_v.backward()
            grads = [param.grad.data.cpu().numpy() if param.grad is not None else None
                     for param in tgt_net.target_model.parameters()]

            # apply gradients
            for grad, param in zip(grads, net.parameters()):
                param.grad = Variable(torch.from_numpy(grad))

            optimizer.step()
            sum_loss += loss_v.data.cpu().numpy()
        print("%d: %.2f" % (iter_idx, sum_loss))
        if sum_loss < 0.1:
            break

    pass
Пример #56
0
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch,
               record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0

    write_examples = True
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Add hooks for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-

    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.kpn.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer: clock = Timer()
    train_iter = iter(train_loader)
    steps_per_epoch = len(train_loader)
    write_examples_iter = steps_per_epoch // 3

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer: clock.tic()

        # -- zero gradients; ready 2 go --
        optimizer.zero_grad()
        model.zero_grad()
        model.denoiser_info.optim.zero_grad()

        # -- grab data batch --
        burst, res_imgs, raw_img, directions = next(train_iter)

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Formatting Images for FP
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- creating some transforms --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

        # -- extract target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised: gt_img = raw_zm_img
        else: gt_img = mid_img

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #           Foward Pass
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        outputs = model(burst)
        aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
        aligned_filters, denoised_filters = outputs[4:]

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Require Approx Equal Filter Norms
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        denoised_filters = rearrange(denoised_filters.detach(),
                                     'b n k2 c h w -> n (b k2 c h w)')
        norms = denoised_filters.norm(dim=1)
        norm_loss_denoiser = torch.mean((norms - norms[N // 2])**2)
        norm_loss_coeff = 1000.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Decrease Entropy within a Kernel
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_entropy = 0
        filters_entropy_coeff = 10.
        all_filters = []
        L = len(align_hook.filters)
        iter_filters = align_hook.filters if L > 0 else [aligned_filters]
        for filters in iter_filters:
            filters_shaped = rearrange(filters,
                                       'b n k2 c h w -> (b n c h w) k2',
                                       n=N)
            filters_entropy += entropyLoss(filters_shaped)
            all_filters.append(filters)
        if L > 0: filters_entropy /= L
        all_filters = torch.stack(all_filters, dim=1)
        align_hook.clear()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Increase Entropy across each Kernel
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_dist_entropy = 0

        # -- across each frame --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b l) (n c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each batch --
        filters_shaped = rearrange(all_filters,
                                   'b l n k2 c h w -> (n l) (b c h w) k2')
        filters_shaped = torch.mean(filters_shaped, dim=1)
        filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each kpn cascade --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b n) (l c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        filters_dist_coeff = 0

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Alignment Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = alignmentLossMSE(aligned, aligned_ave, gt_img,
                                  cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = np.sum(losses)
        align_mse_coeff = 1.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Reconstruction Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        denoised_ave_d = denoised_ave.detach()
        losses = denoiseLossMSE(denoised, denoised_ave, gt_img,
                                cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = np.sum(losses)
        rec_mse_coeff = 0.95**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Reconstruction Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- regularization scheduler --
        if cfg.global_step < 100: reg = 0.5
        elif cfg.global_step < 200: reg = 0.25
        elif cfg.global_step < 5000: reg = 0.15
        elif cfg.global_step < 10000: reg = 0.1
        else: reg = 0.05

        # -- computation --
        residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        # residuals = rearrange(residuals,'b n c h w -> b n (h w) c')
        # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level)
        # rec_ot_pair_loss_v1 = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16)
        rec_ot_pair_loss_v1 = kl_gaussian_bp(residuals, noise_level, flip=True)
        # rec_ot_pair_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg)
        # rec_ot_pair_loss_v2 = ot_pairwise_bp(residuals,K=3)
        rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        rec_ot_pair = (rec_ot_pair_loss_v1 + rec_ot_pair_loss_v2) / 2.
        rec_ot_pair_coeff = 100.  # - .997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Final Losses
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        norm_loss = norm_loss_coeff * norm_loss_denoiser
        align_loss = align_mse_coeff * align_mse
        rec_loss = rec_ot_pair_coeff * rec_ot_pair + rec_mse_coeff * rec_mse
        entropy_loss = filters_entropy_coeff * filters_entropy + filters_dist_coeff * filters_dist_entropy
        final_loss = rec_loss * align_loss + rec_loss + entropy_loss + norm_loss

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Record Keeping
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- alignment MSE --
        align_mse_losses += align_mse.item()
        align_mse_count += 1

        # -- reconstruction MSE --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1

        # -- reconstruction Dist. --
        rec_ot_losses += rec_ot_pair.item()
        rec_ot_count += 1

        # -- total loss --
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #        Gradients & Backpropagation
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        final_loss.backward()

        # -- apply the optimizer updates --
        model.denoiser_info.optim.step()
        optimizer.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            M = 10 if B > 10 else B
            for b in range(B):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN,
                                  denoised + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img,
                                  denoised_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- alignment MSE --
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                # DataFrame.append was removed in pandas 2.x; pd.concat is the supported path
                record_losses = pd.concat([record_losses, pd.DataFrame([info])],
                                          ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, directions)

        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
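
# mse_to_psnr is not shown in this excerpt; a minimal sketch of the assumed behaviour
# (hypothetical, pixel intensities in [0, 1], so PSNR = 10 * log10(1 / MSE)):
def mse_to_psnr(mse):
    # mse: NumPy array of per-image MSE values; returns per-image PSNR in dB
    return 10 * np.log10(1.0 / mse)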
Пример #57
0
def evaluate(forward_fn, val_loader, device, opt):
    """
    Evaluates the model on a given number of batches from a validation dataset.

    Parameters
    ----------
    forward_fn : function
        Forward method of the model, which must be in evaluation mode.
    val_loader : torch.utils.data.DataLoader
        Randomized dataloader for a data.base.VideoDataset dataset.
    device : torch.device
        Device on which operations are performed.
    opt : helper.DotDict
        Contains the training configuration.

    Returns
    -------
    float
        Average negative prediction PSNR.
    """
    inf_len = opt.nt_cond
    assert val_loader is not None and opt.n_iter_test <= len(val_loader)

    n = 0  # Total number of evaluation videos, updated in the validation loop
    global_psnr = 0  # Sum of all computed prediction PSNR
    with torch.no_grad():
        for j, batch in enumerate(val_loader):
            # Stop when the given number of iterations is reached
            if j >= opt.n_iter_test:
                break

            # Data
            x = batch.to(device)
            x_inf = x[:inf_len]
            nt = x.shape[0]
            n_b = x.shape[1]
            n += n_b

            # Perform a given number of predictions per video
            all_x = []
            for _ in range(opt.n_samples_test):
                all_x.append(
                    forward_fn(x_inf, nt, dt=1 / opt.n_euler_steps)[0].cpu())
            all_x = torch.stack(all_x)
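            # all_x is presumably (n_samples, nt, batch, C, H, W): the MSE below is
            # averaged over the spatial dims (4, 5) and the PSNR over time and
            # channels (1, 3), leaving one score per sample and per video.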

            # For each video, keep the sampled prediction with the highest PSNR w.r.t. the ground truth
            x_cpu = x.cpu()
            all_mse = torch.mean(F.mse_loss(all_x,
                                            x_cpu.expand_as(all_x),
                                            reduction='none'),
                                 dim=[4, 5])
            all_psnr = torch.mean(10 * torch.log10(1 / all_mse), dim=[1, 3])
            _, idx_best = all_psnr.max(0)
            x_ = all_x[idx_best, :, torch.arange(n_b).to(device)].transpose(
                0, 1).contiguous().to(device)

            # Compute the final PSNR score
            mse = torch.mean(F.mse_loss(x_, x, reduction='none'), dim=[3, 4])
            psnr = 10 * torch.log10(1 / mse)
            global_psnr += psnr[inf_len:].mean().item() * n_b

    # Average by batch
    return -global_psnr / n
Пример #58
0
    def vae_loss(self, mean, logvar, reconstruction, target, beta=20, c=0.5):
        mse = func.mse_loss(reconstruction, target)
        kld = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
        return mse + beta * torch.norm(kld - c, 1)
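    # Presumably a capacity-constrained beta-VAE objective: the reconstruction MSE is
    # combined with beta * |KL - c|, which pulls the KL term towards the capacity c.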
Пример #59
0
                batch_size = 1
                seq_len = seq.size()[0]
                seq = seq.view(seq_len, -1).contiguous()
                seq = seq.unsqueeze(dim=1)
                targets = seq[1:]
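                # Next-step prediction: the output at step t is scored against frame
                # t + 1, so the final prediction (predictions[-1]) has no target.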

                optimizer.zero_grad()
                predictions = model(seq)
                losses = [F.binary_cross_entropy(input=pred, target=targets[step]) for step, pred in enumerate(predictions[:-1])]
                loss = sum(losses)
                loss.backward()
                # clip_grad_norm is deprecated; the in-place clip_grad_norm_ is the current API
                torch.nn.utils.clip_grad_norm_(model.parameters(), .25)
                optimizer.step()
                train_losses.append(np.mean([l.data.cpu().numpy() for l in losses]))

                accuracies = [F.mse_loss(input=pred, target=targets[step]) for step, pred in enumerate(predictions[:-1])]
                train_accuracies.append(np.mean([acc.data.cpu().numpy() for acc in accuracies]))

                acc = []
                p = torch.cat(predictions[:-1]).view(batch_size, seq_len - 1, rest).contiguous()
                t = targets.view(batch_size, seq_len - 1, rest).contiguous()
                for i, px in enumerate(p):
                    tx = t[i]
                    acc.append(torch.sum((tx - px) ** 2) / len(px))
                train_accuracies2.append(np.mean([a.data.cpu().numpy() for a in acc]))

        print("Train Loss", np.mean(train_losses))
        print("Train Accuracy", np.mean(train_accuracies))
        print("Train Accuracy2", np.mean(train_accuracies2))
        print("Train time", make_time_units_string(time.time()-_start))
Пример #60
0
def learn_kinematics(model_ensemble: ModelEnsemble,
                     forward_kinematics: Kinematics,
                     configurations: np.ndarray,
                     lr: float = 1e-2,
                     l2_reg: float = .1,
                     n_epochs: int = 10,
                     train_batch_size: int = 100,
                     valid_batch_size: int = 10,
                     log_period: int = 20) -> Tuple[ModelEnsemble, np.ndarray]:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_ensemble.to(device)

    # Additional Info when using cuda
    if device.type == 'cuda':
        print(f'Using device: {device}')
        print(torch.cuda.get_device_name(0))
        print('Memory Usage:')
        print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024**2, 1),
              'MB')
        print('Cached:   ', round(torch.cuda.memory_reserved(0) / 1024**2, 1),
              'MB')

    train_data, valid_data = shuffle_split(configurations)

    train_set = KinematicsDataset(train_data, forward_kinematics)
    train_generator = DataLoader(train_set,
                                 batch_size=train_batch_size,
                                 shuffle=True)

    valid_set = KinematicsDataset(valid_data, forward_kinematics)
    valid_generator = DataLoader(valid_set,
                                 batch_size=valid_batch_size,
                                 shuffle=True)

    optimizers = []
    for model in model_ensemble:
        optimizers.append(
            torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_reg))

    loss_history = []

    n_total_updates = 0
    for epoch in range(n_epochs):
        n_epoch_updates = 0
        for configurations, true_states in train_generator:
            configurations = configurations.to(device)
            true_states = true_states.to(device)
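            # Each ensemble member is trained independently on the same batch with its
            # own Adam optimizer; weight_decay=l2_reg supplies the L2 regularisation.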

            for model, optimizer in zip(model_ensemble, optimizers):
                optimizer.zero_grad()

                predicted_states = model(configurations)

                loss = F.mse_loss(predicted_states, true_states).mean()
                loss.backward()

                optimizer.step()

            if not n_epoch_updates % log_period:
                with torch.no_grad():
                    valid_count = 0
                    running_loss = 0
                    for valid_configs, valid_true_states in valid_generator:
                        valid_configs = valid_configs.to(device)
                        valid_true_states = valid_true_states.to(device)

                        batch_loss = 0
                        for model in model_ensemble.models:
                            predicted_states = model(valid_configs)
                            batch_loss += F.mse_loss(
                                predicted_states,
                                valid_true_states).mean().detach().to('cpu')

                        running_loss += batch_loss / model_ensemble.n_models

                        valid_count += 1

                    mean_loss = running_loss / valid_count
                    print(
                        f'Epoch = {epoch + 1}, Total Updates = {n_total_updates}, Mean loss = {mean_loss}'
                    )
                    loss_history.append([n_total_updates, mean_loss])

            n_epoch_updates += 1
            n_total_updates += 1

    loss_history = np.asarray(loss_history)

    return model_ensemble, loss_history