Example #1
def dp_sgd_backward(dis, real_imgs, fake_imgs, device, clip_norm, noise_factor,
                    dis_opt):
    """
  since only one part of the loss depends on the data, we compute its gradients separately and noise them up
  in order to maintain similar gradient sizes, gradients for the other loss are clipped to the same per sample norm
  """
    # real data loss first:
    params = list(dis.parameters())
    loss_real = -pt.mean(dis(real_imgs))
    with backpack(BatchGrad(), BatchL2Grad()):
        loss_real.backward(retain_graph=True)

    squared_param_norms_real = [
        p.batch_l2 for p in params
    ]  # first we get all the squared parameter norms...
    global_norms_real = pt.sqrt(
        pt.sum(pt.stack(squared_param_norms_real),
               dim=0))  # ...then compute the global norms...
    global_clips_real = pt.clamp_max(
        clip_norm / global_norms_real,
        1.)  # ...and finally get a vector of clipping factors

    perturbed_grads = []
    for idx, param in enumerate(params):
        clipped_sample_grads = param.grad_batch * expand_vector(
            global_clips_real, param.grad_batch)
        clipped_grad = pt.sum(clipped_sample_grads,
                              dim=0)  # after clipping we sum over the batch

        noise_sdev = noise_factor * 2 * clip_norm  # gaussian noise standard dev is computed (sensitivity is 2*clip)...
        perturbed_grad = clipped_grad + pt.randn_like(
            clipped_grad, device=device) * noise_sdev  # ...and applied
        perturbed_grads.append(perturbed_grad)  # store perturbed grads

    dis_opt.zero_grad()
    # now add fake data loss gradients:
    loss_fake = pt.mean(dis(fake_imgs))
    with backpack(BatchGrad(), BatchL2Grad()):
        loss_fake.backward()

    squared_param_norms_fake = [
        p.batch_l2 for p in params
    ]  # first we get all the squared parameter norms...
    global_norms_fake = pt.sqrt(
        pt.sum(pt.stack(squared_param_norms_fake),
               dim=0))  # ...then compute the global norms...
    global_clips_fake = pt.clamp_max(
        clip_norm / global_norms_fake,
        1.)  # ...and finally get a vector of clipping factors

    for idx, param in enumerate(params):
        clipped_sample_grads = param.grad_batch * expand_vector(
            global_clips_fake, param.grad_batch)
        clipped_grad = pt.sum(clipped_sample_grads,
                              dim=0)  # after clipping we sum over the batch

        param.grad = clipped_grad + perturbed_grads[idx]

    ld = loss_real.item() + loss_fake.item()
    return global_norms_real, global_clips_real, global_norms_fake, global_clips_fake, ld
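The helper expand_vector is not shown in this example; a minimal sketch consistent with its usage above (reshaping the (B,) vector of clipping factors so it broadcasts over a (B, ...) per-sample gradient tensor, assuming the batch dimension comes first) might look like:

def expand_vector(vec, target_tensor):
    # (B,) -> (B, 1, ..., 1) so the factors broadcast over the trailing dims
    return vec.view(-1, *([1] * (target_tensor.dim() - 1)))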
Example #2
def polar_indices(
        positions,  # (N, L_s, L, 2)
        nray,
        nring,
        inner_radius):  # returns four tensors, each of shape (N, L_s, L)
    distances = torch.sqrt(positions[:, :, :, 0]**2 + positions[:, :, :, 1]**2)
    distance_indices = torch.clamp(distances / inner_radius,
                                   min=0,
                                   max=nring - 1).floor().long()
    angles = torch.atan2(positions[:, :, :, 1], positions[:, :, :,
                                                          0]) + math.pi
    # There is one angle value that can result in index of exactly nray, clamp it to nray-1
    angular_indices = torch.clamp_max(
        (angles / (2 * math.pi) * nray).floor().long(), nray - 1)

    distance_offsets = torch.clamp_max(distances / inner_radius -
                                       distance_indices.float() - 0.5,
                                       max=2)
    angular_offsets = angles / (2 *
                                math.pi) * nray - angular_indices.float() - 0.5

    assert angular_indices.min(
    ) >= 0, f'Negative angular index: {angular_indices.min()}'
    assert angular_indices.max(
    ) < nray, f'invalid angular index: {angular_indices.max()} >= {nray}'
    assert distance_indices.min(
    ) >= 0, f'Negative distance index: {distance_indices.min()}'
    assert distance_indices.max(
    ) < nring, f'invalid distance index: {distance_indices.max()} >= {nring}'

    return distance_indices, angular_indices, distance_offsets, angular_offsets
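A hypothetical shape check for polar_indices (tensor sizes chosen arbitrarily):

import math
import torch

positions = torch.randn(2, 3, 5, 2)  # (N, L_s, L, 2) Cartesian offsets
d_idx, a_idx, d_off, a_off = polar_indices(positions, nray=8, nring=4,
                                           inner_radius=0.5)
print(d_idx.shape, a_idx.shape, d_off.shape, a_off.shape)  # each (2, 3, 5)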
Example #3
def fixed_circle_loss_clean(output, target, mask, num_classes,
                            softplus: torch.nn.Softplus):
    """
    output: (B, N)
    target: (B, X)
    mask: (B, N)
    num_classes = N

    loss = log(1 + \sum_i exp(s^{neg}_i) * \sum_j exp(-s^{pos}_j))

    \gamma = 1, m = 0 (hence delta_p = 1, delta_n = 0 below)

    """
    output = output.float()

    # seq_len = output.size(1)
    target_mask = (target == -1)
    target[target_mask] = num_classes  # note: mutates `target` in place; -1 padding maps to an extra class

    one_hot_target = F.one_hot(target, num_classes + 1)
    one_hot_target = one_hot_target.sum(dim=1)
    one_hot_target = one_hot_target[:, :-1].to(dtype=output.dtype)

    mask = mask.to(dtype=output.dtype)
    all_mask = (one_hot_target.sum(dim=-1) == 0)

    mask_for_pos = torch.clamp_max(1 - one_hot_target + mask, max=1.0)
    mask_for_neg = torch.clamp_max(one_hot_target + mask, max=1.0)

    _flag = -1e12

    ap = torch.clamp_min(1 - output.detach(), min=0.)
    an = torch.clamp_min(output.detach(), min=0.)

    delta_p = 1
    delta_n = 0

    output = (1 - 2 * one_hot_target) * output  # Positive: -1 // Negative: +1

    logit_pos = ap * (delta_p + output) + mask_for_pos * _flag

    x = logit_pos.max(dim=-1, keepdim=True)[0].detach()
    x = torch.relu_(x)

    logit_pos = logit_pos - x
    # assert output_pos.size() == output.size(), (output_pos.size(), output.size())

    logit_neg = an * (output - delta_n) + mask_for_neg * _flag

    y = logit_neg.max(dim=-1, keepdim=True)[0].detach()
    y = torch.relu_(y)

    logit_neg = logit_neg - y

    loss = softplus((x + y).squeeze(-1) + torch.logsumexp(logit_pos, dim=-1) +
                    torch.logsumexp(logit_neg, dim=-1))  # shape (B,)

    masked_loss = loss.masked_fill(all_mask, 0.)

    return masked_loss
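A hedged usage sketch (shapes and label values are illustrative; remember the function writes into the -1 padding slots of target in place):

import torch
import torch.nn.functional as F

output = torch.randn(4, 10)                                  # (B, N) scores
target = torch.tensor([[1, 3], [0, -1], [-1, -1], [2, 5]])   # (B, X), -1 = padding
mask = torch.zeros(4, 10)
loss = fixed_circle_loss_clean(output, target, mask, 10, torch.nn.Softplus())
print(loss.shape)  # (4,); rows with no positive labels are zeroed out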
Example #4
 def bound_loss(self, mu):
     if self.bounds_loss_coef is not None:
         soft_bound = 1.1
         mu_loss_high = torch.clamp_min(mu - soft_bound, 0.0)**2  # penalize means above +soft_bound
         mu_loss_low = torch.clamp_max(mu + soft_bound, 0.0)**2  # penalize means below -soft_bound
         b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1)
     else:
         b_loss = 0
     return b_loss
Example #5
File: losses.py Project: yiyayada/nvae
def discretized_mix_logistic_loss(y_hat: torch.Tensor,
                                  y: torch.Tensor,
                                  num_classes=256,
                                  log_scale_min=-7.0):
    """Discretized mix of logistic distributions loss.

    Note that it is assumed that input is scaled to [-1, 1]



    :param y_hat: Tensor. shape=(batch_size, 3 * num_mixtures * img_channels, height, width), predict output.
    :param y: Tensor. shape=(batch_size, img_channels, height, width), Target.
    :return: Tensor loss
    """

    # unpack parameters, [batch_size, num_mixtures * img_channels, height, width] x 3
    logit_probs, means, log_scales = y_hat.chunk(3, dim=1)
    log_scales = torch.clamp_min(log_scales, log_scale_min)  # floor (not cap) the log-scales

    num_mixtures = y_hat.size(1) // y.size(1) // 3

    B, C, H, W = y.shape
    y = y.unsqueeze(1).repeat(1, num_mixtures, 1, 1,
                              1).permute(0, 2, 1, 3, 4).reshape(B, -1, H, W)

    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
    cdf_min = torch.sigmoid(min_in)

    log_cdf_plus = plus_in - F.softplus(plus_in)
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min

    mid_in = inv_stdv * centered_y
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    log_probs = torch.where(
        y < -0.999, log_cdf_plus,
        torch.where(
            y > 0.999, log_one_minus_cdf_min,
            torch.where(cdf_delta > 1e-5,
                        torch.log(torch.clamp_min(cdf_delta, 1e-12)),
                        log_pdf_mid - np.log((num_classes - 1) / 2))))

    # (batch_size, num_mixtures * img_channels, height, width)
    log_probs = log_probs + F.log_softmax(logit_probs, dim=1)  # add log mixture weights

    log_probs = [
        log_sum_exp(log_prob) for log_prob in log_probs.chunk(y.size(1), dim=1)
    ]
    log_probs = reduce(lambda a, b: a + b, log_probs)

    return -torch.sum(log_probs)
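log_sum_exp is an external helper; a plausible stand-in (an assumption based on the chunked (batch, num_mixtures, H, W) shapes above) reduces over the mixture dimension:

def log_sum_exp(x, dim=1):
    # numerically stable log(sum(exp(x))) over the mixture dimension
    return torch.logsumexp(x, dim=dim)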
Example #6
def _binary_cross_entropy(pred: Tensor, target: Tensor) -> Tensor:
    """二元交叉熵损失(F.binary_cross_entropy())

    :param pred: shape = (N,)
    :param target: shape = (N,) torch.float32
    :return: shape = ()"""
    # torch.mean(-torch.log(pred) * target + -torch.log(1 - pred) * (1 - target))
    return torch.mean(
        torch.clamp_max(-torch.log(pred), 100) * target +  # clamp to prevent inf
        torch.clamp_max(-torch.log(1 - pred), 100) * (1 - target))
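A quick numerical check against the built-in (hypothetical values):

import torch
import torch.nn.functional as F

pred = torch.tensor([0.9, 0.2, 0.7])
target = torch.tensor([1., 0., 1.])
print(_binary_cross_entropy(pred, target))   # tensor(0.2284)
print(F.binary_cross_entropy(pred, target))  # same value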
Example #7
File: hmc.py Project: jongharyu/BDMC
def accept_reject(current_z, current_v,
                  z, v,
                  epsilon,
                  accept_hist, hist_len,
                  U, K=lambda v: torch.sum(v * v, 1)):
  """Accept/reject based on Hamiltonians for current and propose.

  Args:
      current_z: position *before* leap-frog steps
      current_v: speed *before* leap-frog steps
      z: position *after* leap-frog steps
      v: speed *after* leap-frog steps
      epsilon: step size of leap-frog.
      accept_hist: running count of accepted proposals, used to adapt epsilon
      hist_len: number of accept/reject decisions accumulated in accept_hist
      U: function to compute potential energy
      K: function to compute kinetic energy
  """
  current_Hamil = K(current_v) + U(current_z)
  propose_Hamil = K(v) + U(z)

  prob = torch.clamp_max(torch.exp(current_Hamil - propose_Hamil), 1.)

  with torch.no_grad():
    uniform_sample = torch.rand(prob.size()).cuda()
    accept = (prob > uniform_sample).float().cuda()
    z = z.mul(accept.view(-1, 1)) + current_z.mul(1. - accept.view(-1, 1))

    accept_hist = accept_hist.add(accept)
    criteria = (accept_hist / hist_len > 0.65).float().cuda()
    adapt = 1.02 * criteria + 0.98 * (1. - criteria)
    epsilon = epsilon.mul(adapt).clamp(1e-4, .5)

  z.requires_grad_()

  return z, epsilon, accept_hist
Example #8
    def __call__(self, model: nn.Module, images: torch.Tensor,
                 labels: torch.Tensor) -> AdversaryOutput:
        pixel_epsilon = self.epsilon * 255
        n_iters = round(min(
            pixel_epsilon + 4, 1.25 *
            pixel_epsilon))  # according to the policy in the reference paper

        lo = torch.clamp_min(images - self.epsilon, 0)
        hi = torch.clamp_max(images + self.epsilon, 1)

        result = images.clone()
        result.requires_grad = True

        for _ in range(n_iters):
            if result.grad is not None:
                result.grad.detach_()
                result.grad.zero_()

            loss = self.compute_objective(model, result, labels, "mean")
            loss.backward()

            with torch.no_grad():
                result += self.step_size * torch.sign(result.grad)
                clamp_min_tensor(result, lo)
                clamp_max_tensor(result, hi)

        result.requires_grad = False
        return AdversaryOutput(result, result - images)
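clamp_min_tensor and clamp_max_tensor are external helpers that clamp one tensor elementwise against another, in place; a minimal sketch (an assumption consistent with the no_grad usage above):

def clamp_min_tensor(t: torch.Tensor, lo: torch.Tensor) -> None:
    # in-place elementwise lower bound against a tensor of the same shape
    t.copy_(torch.max(t, lo))

def clamp_max_tensor(t: torch.Tensor, hi: torch.Tensor) -> None:
    # in-place elementwise upper bound
    t.copy_(torch.min(t, hi))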
Example #9
def nn_transformation(model):
    b_new_list = [None for _ in range(len(model.layers))]
    act_shift_list = [None for _ in range(len(model.layers))]

    nn_model = model.get_nn_net()
    for i in range(len(model.layers)):
        cur_layer = model.layers[i]

        w_neg = T.clamp_max(cur_layer.weight, 0)
        b_tilde = cur_layer.bias - model.alpha[i] * T.sum(T.abs(w_neg), dim=1)
        b_new, act_shift = calc_b_new(cur_layer, b_tilde)

        b_new_list[i] = b_new
        act_shift_list[i] = act_shift

        w_neg_abs = T.abs(w_neg)
        w_pos = T.clamp_min(cur_layer.weight, 0)

        nn_model.add_layer(
            NNLinear(
                cur_layer.in_features,
                cur_layer.out_features,
                w_pos,
                w_neg_abs,
                b_new,
                model.alpha[i],
                act_shift,
            ))

    return nn_model
Example #10
    def step(self):
        """Performs a single optimization step.

        The function expects the gradients to have been computed by BackPACK
        and the parameters to have a ``batch_l2`` and ``grad_batch`` attribute.
        """
        l2_norms_all_params_list = []
        for group in self.param_groups:
            for p in group["params"]:
                l2_norms_all_params_list.append(p.batch_l2)

        l2_norms_all_params = torch.stack(l2_norms_all_params_list)
        total_norms = torch.sqrt(torch.sum(l2_norms_all_params, dim=0))
        scaling_factors = torch.clamp_max(self.max_norm / total_norms, 1.0)

        for group in self.param_groups:
            for p in group["params"]:
                clipped_grads = p.grad_batch * make_broadcastable(
                    scaling_factors, p.grad_batch)
                clipped_grad = torch.sum(clipped_grads, dim=0)

                noise_magnitude = self.stddev * self.max_norm
                noise = torch.randn_like(clipped_grad) * noise_magnitude

                perturbed_update = clipped_grad + noise

                p.data.add_(-self.lr * perturbed_update)
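make_broadcastable plays the same role as expand_vector in Example #1; a minimal sketch under that assumption:

def make_broadcastable(vec, like_tensor):
    # (B,) -> (B, 1, ..., 1) so the scaling factors broadcast over grad_batch
    return vec.view(-1, *([1] * (like_tensor.dim() - 1)))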
Example #11
    def forward(self, state):
        x = self.linear1(state)
        x = F.gelu(x)
        x = self.linear2(x)
        x = F.gelu(x)
        x = self.linear3(x)

        x_mean = F.gelu(x)
        x_mean = self.linear_mean_4(x_mean)
        x_mean = F.gelu(x_mean)
        x_mean = self.linear_mean_5(x_mean)
        x_mean = F.gelu(x_mean)

        x_std = F.gelu(x)
        x_std = self.linear_std_4(x_std)
        x_std = F.gelu(x_std)
        x_std = self.linear_std_5(x_std)
        x_std = F.gelu(x_std)

        mean = self.mean_layer(x_mean)
        log_std = self.log_std_layer(x_std)
        log_std = torch.clamp_min(self.log_std_max*torch.tanh(log_std/self.denominator),0) + \
                  torch.clamp_max(-self.log_std_min * torch.tanh(log_std / self.denominator), 0)
        #log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std
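The two clamps form a smooth interpolation: positive pre-activations are pulled towards log_std_max and negative ones towards log_std_min, replacing the hard clamp shown in the commented-out line. A quick sanity check with hypothetical bounds:

import torch

log_std_min, log_std_max, denominator = -20.0, 2.0, 1.0
t = torch.tanh(torch.tensor([-10.0, 0.0, 10.0]) / denominator)
out = torch.clamp_min(log_std_max * t, 0) + torch.clamp_max(-log_std_min * t, 0)
print(out)  # approximately tensor([-20., 0., 2.])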
Example #12
    def _sample_embd(norm_embd, labels, batch_size, adaptive_sampling,
                     sampling_angle_std):
        with torch.no_grad():
            unit_directions = F.normalize(torch.randn_like(norm_embd), dim=1)
            dot_prod = torch.sum(norm_embd * unit_directions,
                                 dim=1,
                                 keepdim=True)
            orthogonal_directions = unit_directions - dot_prod * norm_embd

            if adaptive_sampling and labels is not None:
                all_angle_std = sampling_angle_std.expand(batch_size, -1)
                class_indices = torch.arange(batch_size, device=labels.device)
                angle_std = all_angle_std[class_indices, labels].view(-1, 1)
            else:
                angle_std = sampling_angle_std

            angles = angle_std * torch.randn_like(dot_prod)
            alpha = torch.clamp_max(
                torch.where(angles > 0.0, angles, torch.neg(angles)),
                0.5 * np.pi)
            cos_alpha = torch.cos(alpha)
            sin_alpha = torch.sin(alpha)

        out_norm_embd = cos_alpha * norm_embd + sin_alpha * orthogonal_directions

        return out_norm_embd
Example #13
def overlay_img_with_density(img_path, density_map_path, output_path):
    """
    combine output density map with image to create the red heatmap overlay
    :param img_path:
    :param density_map_path: output .torch of density map
    :param output_path:
    :return:
    """
    img_tensor = read_image(img_path)
    density_map_tensor = torch.load(density_map_path)

    print(img_tensor.shape)
    print(density_map_tensor.shape)
    print(density_map_tensor.sum())
    density_map_tensor = torch.from_numpy(density_map_tensor).unsqueeze(
        dim=0).unsqueeze(dim=0)
    print("density_map_tensor.shape",
          density_map_tensor.shape)  # torch.Size([1, 1, 46, 82])
    upsampling_density_map_tensor = nn.functional.interpolate(
        density_map_tensor, scale_factor=8) / 64

    overlay_density_map = img_tensor.detach().clone()
    upsampling_density_map_tensor = (
        upsampling_density_map_tensor.squeeze(dim=0) /
        upsampling_density_map_tensor.max() * 255)
    overlay_density_map[0] = torch.clamp_max(
        img_tensor[0] + upsampling_density_map_tensor[0] * 2, max=255)

    write_jpeg(overlay_density_map.type(torch.uint8), output_path, quality=100)
Example #14
def _train_step(
    model,
    dataloader,
    loss_criterion,
    optimiser,
    lr_scheduler,
    l2_norm_clip=0,
    use_gpu=False,
):
    """One epoch of training"""
    outputs, losses = [], []
    for x, y in dataloader:
        if use_gpu:
            x, y = x.cuda(), y.cuda()

        out = model(x)
        loss = loss_criterion(out, y)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        if l2_norm_clip > 0:
            # Clip weights in final layer
            with torch.no_grad():
                norms = torch.norm(model.fc.out.weight, dim=1, keepdim=True)
                norms_clipped = torch.clamp_max(norms, l2_norm_clip)

                # Renormalise weights
                model.fc.out.weight.div_(norms).mul_(norms_clipped)

        outputs.append(out)
        losses.append(loss.item())

    lr_scheduler.step()
    return outputs, losses
Example #15
def regulate_len(durations, enc_out, pace: float = 1.0,
                 mel_max_len: Optional[int] = None):
    """If target=None, then predicted durations are applied"""
    reps = torch.round(durations.float() / pace).long()
    dec_lens = reps.sum(dim=1)

    max_len = dec_lens.max()
    bsz, _, hid = enc_out.size()

    reps_padded = torch.cat([reps, (max_len - dec_lens)[:, None]], dim=1)
    pad_vec = torch.zeros(bsz, 1, hid, dtype=enc_out.dtype,
                          device=enc_out.device)

    enc_rep = torch.cat([enc_out, pad_vec], dim=1)
    enc_rep = torch.repeat_interleave(
        enc_rep.view(-1, hid), reps_padded.view(-1), dim=0
    ).view(bsz, -1, hid)

    # enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
    #                         for o, r in zip(enc_out, reps)],
    #                        batch_first=True)
    if mel_max_len is not None:
        enc_rep = enc_rep[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return enc_rep, dec_lens
Example #16
    def __call__(self, model: nn.Module, images: torch.Tensor,
                 labels: torch.Tensor) -> AdversaryOutput:
        lo = torch.clamp_min(images - self.epsilon, 0)
        hi = torch.clamp_max(images + self.epsilon, 1)

        result = torch.clamp(
            images + random_float_like(images, -self.epsilon, self.epsilon), 0,
            1)
        result.requires_grad = True

        for _ in range(self.n_iters):
            if result.grad is not None:
                result.grad.detach_()
                result.grad.zero_()

            loss = self.compute_objective(model, result, labels, "mean")
            loss.backward()

            with torch.no_grad():
                result += self.step_size * torch.sign(result.grad)
                clamp_min_tensor(result, lo)
                clamp_max_tensor(result, hi)

        result.requires_grad = False
        return AdversaryOutput(result, result - images)
Example #17
 def expmap0(self, u, c):
     sqrt_c = c**0.5
     u_norm = torch.clamp_max(
         torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), self.min_norm),
         self.max_norm)
     gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
     return gamma_1
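This is the exponential map at the origin of the Poincaré ball of curvature -c (as in Ganea et al., "Hyperbolic Neural Networks", 2018), with the input norm clamped to [min_norm, max_norm] for numerical stability:

\exp_0^c(u) = \tanh\!\left(\sqrt{c}\,\lVert u \rVert\right) \frac{u}{\sqrt{c}\,\lVert u \rVert}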
Example #18
File: helpers.py Project: manneh/NeMo
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
    """A function that takes predicted durations per encoded token, and repeats enc_out according to the duration.
    NOTE: durations.shape[1] == enc_out.shape[1]

    Args:
        durations (torch.tensor): A tensor of shape (batch x enc_length) that represents how many times to repeat each
            token in enc_out.
        enc_out (torch.tensor): A tensor of shape (batch x enc_length x enc_hidden) that represents the encoded tokens.
        pace (float): The speaking pace. Higher values result in a faster speaking rate. Defaults to 1.0.
        mel_max_len (int): The maximum output length. If sum(durations, dim=1) > mel_max_len, values past
            mel_max_len are removed. Defaults to None, which imposes no maximum length.
    """

    dtype = enc_out.dtype
    reps = durations.float() / pace
    reps = (reps + 0.5).long()
    dec_lens = reps.sum(dim=1)

    max_len = dec_lens.max()
    reps_cumsum = torch.cumsum(torch.nn.functional.pad(reps, (1, 0, 0, 0),
                                                       value=0.0),
                               dim=1)[:, None, :]
    reps_cumsum = reps_cumsum.to(dtype)

    range_ = torch.arange(max_len).to(enc_out.device)[None, :, None]
    mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] >
                                                 range_)
    mult = mult.to(dtype)
    enc_rep = torch.matmul(mult, enc_out)

    if mel_max_len:
        enc_rep = enc_rep[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)

    return enc_rep, dec_lens
Example #19
def ascend_txt(model, perceptor, t, nom, lats, la, lb):
    out = model(lats())
    cutn, sideX, sideY = out.size()[1:]
    p_s = []
    for ch in range(cutn):
        size = int(sideX *
                   torch.zeros(1, ).normal_(mean=.8, std=.3).clip(.5, .95))
        offsetx = torch.randint(0, sideX - size, ())
        offsety = torch.randint(0, sideY - size, ())
        apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]
        apper = torch.nn.functional.interpolate(apper, (224, 224),
                                                mode='nearest')
        p_s.append(apper)
    into = torch.cat(p_s, 0)
    into = nom((into + 1) / 2)
    iii = perceptor.encode_image(into)

    llls = lats()
    lat_l = torch.abs(1 - torch.std(llls, dim=1)).mean() + torch.abs(
        torch.mean(llls)).mean() + 4 * torch.clamp_max(
            torch.square(llls).mean(), 1)

    for array in llls:
        mean = torch.mean(array)
        diffs = array - mean
        var = torch.mean(torch.pow(diffs, 2.0))
        std = torch.pow(var, 0.5)
        zscores = diffs / std
        skews = torch.mean(torch.pow(zscores, 3.0))
        kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0
        lat_l = lat_l + torch.abs(kurtoses) / llls.shape[0] + torch.abs(
            skews) / llls.shape[0]

    return la * lat_l, -lb * torch.cosine_similarity(t, iii, dim=-1)
Example #20
 def depth_map_to_rgbimg(depth_map):
     depth_map = np.asarray(
         np.squeeze(
             (255 - torch.clamp_max(depth_map * 4, 250)).byte().numpy()),
         np.uint8)
     depth_map = np.asarray(cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB),
                            np.uint8)
     return depth_map
Example #21
    def run_train_step(self, tensor_input, tensor_output, tensor_focal):
        tensor_input, tensor_output = tensor_input.to(
            device), tensor_output.to(device)
        # Get Models prediction and calculate loss
        model_output, depth2, depth4, depth8 = self.bts(
            tensor_input, tensor_focal)

        loss = self.criterion(model_output,
                              tensor_output) * 1 / self.backprop_frequency

        if USE_APEX:
            with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if self.current_step % self.backprop_frequency == 0:  # Make update once every x steps
            torch.nn.utils.clip_grad_norm_(self.bts.parameters(), 5)
            self.optimizer.step()
            self.optimizer.zero_grad()

        if self.current_step % 100 == 0:
            self.writer.add_scalar(
                "Loss",
                loss.item() * self.backprop_frequency / tensor_input.shape[0],
                self.current_step)

        if self.current_step % 1000 == 0:
            img = tensor_input[0].detach().transpose(0, 2).transpose(
                0, 1).cpu().numpy().astype(np.uint8)
            self.writer.add_image("Input", img, self.current_step, None, "HWC")

            visual_result = (255 - torch.clamp_max(
                torchvision.utils.make_grid(
                    [tensor_output[0], model_output[0]]) * 5, 250)).byte()

            self.writer.add_image("Output/Prediction", visual_result,
                                  self.current_step)
            depths = [depth2[0], depth4[0], depth8[0]]
            depths = [depth * MAX_DEPTH for depth in depths]
            depth_visual = (255 - torch.clamp_max(
                torchvision.utils.make_grid(depths) * 5, 250)).byte()

            self.writer.add_image("Depths", depth_visual, self.current_step)

        self.current_step += 1
Example #22
    def train(self, train_buffer, val_buffer, callback_fn):
        for epoch in range(self.args['max_epoch']):
            for i in range(self.args['steps_per_epoch']):
                batch_data = train_buffer.sample(self.batch_size)
                batch_data.to_torch(device=self.device)
                obs = batch_data['obs']
                action = batch_data['act']
                next_obs = batch_data['obs_next']
                reward = batch_data['rew']
                done = batch_data['done'].float()

                # update critic
                p = self.critic(obs, action)
                next_action = self.actor_target.get_action(next_obs)
                target_p = self.critic_target.get_target(
                    next_obs, next_action, reward, self.gamma * (1 - done))
                critic_loss = -(target_p * torch.log(p + 1e-8)).mean()

                self.critic_optim.zero_grad()
                critic_loss.backward()
                self.critic_optim.step()

                # update actor
                action_dist = self.actor(obs)
                log_prob = action_dist.log_prob(action)
                actions = torch.stack(
                    [action_dist.sample() for _ in range(self.m)], dim=0)
                repeat_obs = torch.repeat_interleave(obs.unsqueeze(0), self.m,
                                                     0)
                _, values = self.critic(repeat_obs, actions, with_q=True)
                _, value = self.critic(obs, action, with_q=True)

                if self.advantage_mode == 'mean':
                    advantage = value - values.mean(dim=0)
                elif self.advantage_mode == 'max':
                    advantage = value - values.max(dim=0)[0]

                if self.weight_mode == 'exp':
                    weight = torch.exp(advantage / self.beta)
                elif self.weight_mode == 'binary':
                    weight = (advantage > 0).float()

                weight = torch.clamp_max(weight, 20).detach()
                actor_loss = -torch.mean(weight * log_prob)

                self.actor_optim.zero_grad()
                actor_loss.backward()
                self.actor_optim.step()

                if i % self.args['update_frequency'] == 0:  # sync target networks every update_frequency steps
                    self._sync_weight(self.critic_target, self.critic, 1.0)
                    self._sync_weight(self.actor_target, self.actor, 1.0)

            res = callback_fn(self.get_policy())

            self.log_res(epoch, res)

        return self.get_policy()
Example #23
    def vtrace(self, target_policies_action, behavior_policies_action, rewards, dones, values, next_values):
        # * v-trace algorithm
        rho_s = torch.exp(target_policies_action.log() - behavior_policies_action.log())
        clip_rho_s = torch.clamp_max(rho_s, self.rho)
        c_s = torch.clamp_max(rho_s, self.c)
        deltas = clip_rho_s * (rewards + self.gamma * (1 - dones) * next_values - values)
        trajectory_size = rho_s.size(1)
        acc = 0
        vs_minus_v_xs = []
        for i in range(trajectory_size - 1, -1, -1):
            # * origin paper:
            # * acc = deltas[:, i] + self.gamma * (1 - dones)[:, i] * c_s[:, i] * (acc - next_values[:, i])
            acc = deltas[:, i] + self.gamma * (1 - dones)[:, i] * c_s[:, i] * acc
            vs_minus_v_xs.append(acc.view(-1, 1))

        vs_minus_v_xs = torch.cat(vs_minus_v_xs[::-1], dim=1).unsqueeze(-1)
        vs = vs_minus_v_xs + values
        return clip_rho_s, vs
Example #24
def weighted_binary_focal_loss(pred, target, alpha=0.25, gamma=2.):
    """f(x) = alpha * (1 - x)^a * -ln(sigmoid(pred))

    :param pred: shape = (N,)
    :param target: shape = (N,)
    :param alpha: float
        The weight of the negative sample and the positive sample. (alpha * positive + (1 - alpha) * negative)
    :param gamma: float
    :return: shape = ()"""

    # target == -1. It's neither a positive sample nor a negative sample.
    return torch.sum(
        torch.where(
            target == -1, torch.tensor(0., device=target.device),
            alpha * (1 - pred)**gamma * target *
            torch.clamp_max(-torch.log(pred), 100) +
            (1 - alpha) * pred**gamma *
            (1 - target) * torch.clamp_max(-torch.log(1 - pred), 100)))
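A hypothetical usage, with target == -1 marking ignored samples:

import torch

pred = torch.tensor([0.9, 0.3, 0.5])
target = torch.tensor([1., 0., -1.])  # last element contributes nothing
print(weighted_binary_focal_loss(pred, target))  # tensor(0.0243)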
Example #25
def bucketize(h_ts, t_ts, tau=21600, vmax=25):
    diff = (t_ts.unsqueeze(-1) - h_ts).float() / tau
    diff = diff.masked_fill_(diff < 0, -1) + 1
    bucket_ids = torch.floor(torch.log2(diff))
    bucket_ids += 1
    # bucket_ids = bucket_ids.masked_fill(~torch.isfinite(bucket_ids), 0)  # padding index = 0
    bucket_ids = torch.clamp_min(bucket_ids, 0)
    bucket_ids = torch.clamp_max(bucket_ids.long(), vmax)
    return bucket_ids
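A worked example with the default tau = 21600 (6 hours): gaps of 4, 3, and 2 tau land in log2-spaced buckets.

import torch

h_ts = torch.tensor([0, 21600, 43200])  # head timestamps, seconds
t_ts = torch.tensor([86400])            # tail timestamp
print(bucketize(h_ts, t_ts))            # tensor([[3, 3, 2]])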
Example #26
    def forward(self, pred, label):
        one_hot = label > 0.5
        sample_weight = label != self._ignore_label

        if not self._from_logits:
            pred = torch.sigmoid(pred)

        alpha = torch.where(one_hot, self._alpha * sample_weight,
                            (1 - self._alpha) * sample_weight)
        pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred),
                         torch.ones_like(pred))

        beta = (1 - pt)**self._gamma

        sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
        beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
        mult = sw_sum / (beta_sum + self._eps)
        if self._detach_delimeter:
            mult = mult.detach()
        beta = beta * mult
        if self._max_mult > 0:
            beta = torch.clamp_max(beta, self._max_mult)

        with torch.no_grad():
            ignore_area = torch.sum(label == self._ignore_label,
                                    dim=tuple(range(
                                        1, label.dim()))).cpu().numpy()
            sample_mult = torch.mean(mult, dim=tuple(range(
                1, mult.dim()))).cpu().numpy()
            if np.any(ignore_area == 0):
                self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[
                    ignore_area == 0].mean()

                beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
                beta_pmax = beta_pmax.mean().item()
                self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax

        loss = -alpha * beta * torch.log(
            torch.min(pt + self._eps,
                      torch.ones(1, dtype=torch.float).to(pt.device)))
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            bsum = torch.sum(sample_weight,
                             dim=misc.get_dims_with_exclusion(
                                 sample_weight.dim(), self._batch_axis))
            loss = torch.sum(loss,
                             dim=misc.get_dims_with_exclusion(
                                 loss.dim(),
                                 self._batch_axis)) / (bsum + self._eps)
        else:
            loss = torch.sum(loss,
                             dim=misc.get_dims_with_exclusion(
                                 loss.dim(), self._batch_axis))

        return loss
Example #27
def nn_transformation(model):
    b_new_list = [None for _ in range(len(model.layers))]
    act_shift_list = [None for _ in range(len(model.layers))]

    nn_model = model.get_nn_net()
    for i in range(len(model.layers)):
        cur_layer = model.layers[i]
        classname = cur_layer.__class__.__name__

        if classname.find('Linear') != -1:
            sum_dim = 1
        elif classname.find('Conv2d') != -1:
            sum_dim = (1, 2, 3)

        w_neg = T.clamp_max(cur_layer.weight, 0)
        b_tilde = cur_layer.bias - model.alpha[i] * T.sum(T.abs(w_neg),
                                                          dim=sum_dim)
        b_new, act_shift = calc_b_new(cur_layer, b_tilde)

        b_new_list[i] = b_new
        act_shift_list[i] = act_shift

        w_neg_abs = T.abs(w_neg)
        w_pos = T.clamp_min(cur_layer.weight, 0)

        if classname.find('Linear') != -1:
            new_layer = NNLinear(
                cur_layer.in_features,
                cur_layer.out_features,
                w_pos,
                w_neg_abs,
                b_new,
                model.alpha[i],
                act_shift,
            )
        elif classname.find('Conv2d') != -1:
            new_layer = NNConv2d(
                cur_layer.in_channels,
                cur_layer.out_channels,
                cur_layer.kernel_size,
                w_pos,
                w_neg_abs,
                b_new,
                model.alpha[i],
                act_shift,
                stride=cur_layer.stride,
                padding=cur_layer.padding,
                dilation=cur_layer.dilation,
                groups=cur_layer.groups,
                padding_mode=cur_layer.padding_mode,
            )
        else:
            raise Exception("This layer type cannot be converted")
        nn_model.add_layer(new_layer)
    return nn_model
Example #28
def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
    """
    Sample @N_importance samples from @bins with distribution defined by @weights.

    Inputs:
        bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
        weights: (N_rays, N_samples_)
        N_importance: the number of samples to draw from the distribution
        det: deterministic or not
        eps: a small number to prevent division by zero

    Outputs:
        samples: the sampled samples
    """
    N_rays, N_samples_ = weights.shape
    weights = weights + eps  # prevent division by zero (don't do inplace op!)
    pdf = weights / torch.sum(weights, -1,
                              keepdim=True)  # (N_rays, N_samples_)
    cdf = torch.cumsum(
        pdf, -1)  # (N_rays, N_samples), cumulative distribution function
    cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf],
                    -1)  # (N_rays, N_samples_+1)
    # padded to 0~1 inclusive
    if det:
        u = torch.linspace(0, 1, N_importance, device=bins.device)
        u = u.expand(N_rays, N_importance)
    else:
        u = torch.rand(N_rays, N_importance, device=bins.device)
    u = u.contiguous()

    inds = searchsorted(cdf, u, side='right')
    below = torch.clamp_min(inds - 1, 0)
    above = torch.clamp_max(inds, N_samples_)

    inds_sampled = torch.stack([below, above],
                               -1).view(N_rays, 2 * N_importance)

    cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
    bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)

    print(u[0], inds[0], cdf[0], below[0], above[0])

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    # denom == 0 means the bin has weight 0, in which case it will not be
    # sampled anyway, so any value for it is fine (set to 1 here)
    denom[denom < eps] = 1

    samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (bins_g[..., 1] -
                                                              bins_g[..., 0])

    print(inds_sampled[0], cdf_g[0], bins_g[0], denom[0])
    print(samples[0])
    return samples
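searchsorted here presumably comes from the torchsearchsorted extension; on recent PyTorch the built-in is equivalent (an assumption about the import):

inds = torch.searchsorted(cdf, u, right=True)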
Example #29
 def get_deformation(self, coords, feat, c=None, h_sample=1e-3):
     coords_neighbor = coords + \
         (torch.rand_like(coords) * h_sample - (h_sample / 2.))
     coords_neighbor = torch.clamp(coords_neighbor, 0, 1)  # keep perturbed coords in [0, 1]
     weighted_feat = coords[:, :, None] * feat
     coords_encoded = self.position_encoding(coords_neighbor, self.B)
     deform_neighbor = self.decoder(torch.cat(
         [coords_encoded, weighted_feat], dim=-1),
                                    c=c,
                                    only_displacment=True)
     return deform_neighbor
Example #30
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
    """If target=None, then predicted durations are applied"""
    reps = torch.round(durations.float() / pace).long()  # higher pace -> faster speech -> fewer repeats
    dec_lens = reps.sum(dim=1)

    enc_rep = pad_sequence(
        [torch.repeat_interleave(o, r, dim=0) for o, r in zip(enc_out, reps)],
        batch_first=True)
    if mel_max_len:
        enc_rep = enc_rep[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return enc_rep, dec_lens
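A hypothetical shape check (pad_sequence from torch.nn.utils.rnn is assumed imported by the module):

import torch
from torch.nn.utils.rnn import pad_sequence

durations = torch.tensor([[2, 1, 3], [1, 1, 0]])
enc_out = torch.randn(2, 3, 8)  # (batch, enc_length, enc_hidden)
enc_rep, dec_lens = regulate_len(durations, enc_out)
print(enc_rep.shape, dec_lens)  # torch.Size([2, 6, 8]) tensor([6, 2])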