def loss(self, policy, R, V, actions_onehot):
    advantage = R - V
    value_loss = 0.5 * torch.mean(torch.square(advantage))
    log_policy = torch.log(torch.clip(policy, 1e-6, 0.999999))
    log_policy_actions = torch.sum(log_policy * actions_onehot, dim=1)
    # Detach the advantage so the policy gradient does not flow into the critic.
    policy_loss = torch.mean(-log_policy_actions * advantage.detach())
    entropy = torch.mean(torch.sum(policy * -log_policy, dim=1))
    loss = policy_loss + self.value_coeff * value_loss - self.entropy_coeff * entropy
    return loss

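# Hedged usage sketch for the actor-critic loss above. The coefficient values
# on the stand-in object are placeholders, not the author's settings.
import torch

class _DummyAgent:
    value_coeff = 0.5
    entropy_coeff = 0.01

batch, n_actions = 8, 4
policy = torch.softmax(torch.randn(batch, n_actions), dim=1)
R = torch.randn(batch)  # discounted returns
V = torch.randn(batch)  # value estimates
actions_onehot = torch.nn.functional.one_hot(
    torch.randint(0, n_actions, (batch,)), n_actions).float()
print(loss(_DummyAgent(), policy, R, V, actions_onehot))
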
def forward(self, z, inverse_mode=False, context=None):
    if inverse_mode:
        # Map from [low, low + range] back to [-1, 1], then invert the tanh.
        squashed = 2 * (z - self.low) / self.range - 1
        # Clip into the open interval (-1, 1) so atanh stays finite.
        unsquashed = torch.atanh(
            torch.clip(squashed, -1 + SMALL_NUMBER, 1 - SMALL_NUMBER))
        return unsquashed, -self._log_abs_det_jac(squashed)
    else:
        squashed = torch.tanh(z)
        squashed01 = (squashed + 1) / 2
        return squashed01 * self.range + self.low, self._log_abs_det_jac(squashed)

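# Hedged round-trip check for the squashing bijector above. The stand-in class,
# its low/range values, the zero Jacobian stub, and SMALL_NUMBER (in case the
# surrounding module does not define it) are all assumptions for this demo.
import torch

SMALL_NUMBER = 1e-6

class _Squash:
    low, range = torch.tensor(-2.0), torch.tensor(4.0)  # squashes into [-2, 2]
    forward = forward

    @staticmethod
    def _log_abs_det_jac(squashed):
        return torch.zeros_like(squashed)  # stub; enough for a round trip

sq = _Squash()
z = torch.randn(5) * 0.5
y, _ = sq.forward(z)                          # unsquashed -> [-2, 2]
z_back, _ = sq.forward(y, inverse_mode=True)  # back to the unsquashed space
print(torch.allclose(z, z_back, atol=1e-4))   # True
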
def forward(self, ref, dist) -> torch.Tensor:
    """Per-sample PSNR; inputs may be 4-D or 5-D tensors.

    :param ref: reference images
    :param dist: distorted images
    :return: PSNR in dB per sample, clipped to [0, 100]
    """
    batch_sz = ref.shape[0]
    diff = torch.sub(ref, dist)
    mse: torch.Tensor = torch.mean(torch.square(diff).reshape(batch_sz, -1), dim=1)
    psnr: torch.Tensor = torch.mul(10, torch.log10(255 ** 2 / mse))
    return torch.clip(psnr, 0, 100)

def se3_dist(sample_1, sample_2, beta=1., eps=1e-5):
    r"""Compute the matrix of all distances between two sample sets, where each
    sample is an element of SE(3) expressed as a 4x4 matrix. The distance is

        || sample_1.t - sample_2.t ||_2 + beta * acos((tr(sample_1.R^T sample_2.R) - 1) / 2)

    i.e. the Euclidean translation distance plus the angle between the poses
    in radians.

    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, of shape ``(n_1, 4, 4)``.
    sample_2 : torch.Tensor or Variable
        The second sample, of shape ``(n_2, 4, 4)``.
    beta : float
        Weight on the angular term of the distance.
    eps : float
        Small constant guarding sqrt/acos against non-finite gradients.

    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape ``(n_1, n_2)`` whose [i, j]-th entry is the SE(3)
        distance above between the i-th and j-th samples.
    """
    assert len(sample_1.shape) == 3 and sample_1.shape[1:] == (4, 4)
    assert len(sample_2.shape) == 3 and sample_2.shape[1:] == (4, 4)
    n_1, n_2 = sample_1.size(0), sample_2.size(0)
    beta = float(beta)

    # Compute translation distances via ||a||^2 + ||b||^2 - 2 a.b.
    sample_1_t = sample_1[:, :3, 3]
    sample_2_t = sample_2[:, :3, 3]
    norms_1 = torch.sum(sample_1_t ** 2, dim=1, keepdim=True)
    norms_2 = torch.sum(sample_2_t ** 2, dim=1, keepdim=True)
    t_norms = (norms_1.expand(n_1, n_2) +
               norms_2.transpose(0, 1).expand(n_1, n_2))
    t_distances_squared = t_norms - 2 * sample_1_t.mm(sample_2_t.t())
    sample_t_distances = torch.sqrt(eps + torch.abs(t_distances_squared))

    # Compute rotation distances.
    sample_1_R = sample_1[:, :3, :3]
    sample_2_R = sample_2[:, :3, :3]
    # Prepare a batch matrix multiply to get the R1^T R2 terms.
    sample_1_R_expanded = sample_1_R.transpose(1, 2).unsqueeze(1).expand(n_1, n_2, 3, 3)
    sample_2_R_expanded = sample_2_R.unsqueeze(1).transpose(0, 1).expand(n_1, n_2, 3, 3)
    sample_R1tR2 = torch.matmul(sample_1_R_expanded, sample_2_R_expanded)
    # Clip the cosine into the open interval (-1, 1) so acos stays differentiable.
    sample_angle_distances = torch.abs(torch.acos(
        torch.clip(
            (sample_R1tR2.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1) - 1) / 2,
            -1 + eps, 1 - eps)))

    return sample_t_distances + sample_angle_distances * beta

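# Hedged usage sketch for se3_dist. The QR-based random rotation helper is
# illustrative, not part of the original module.
import torch

def _random_poses(n):
    q, _ = torch.linalg.qr(torch.randn(n, 3, 3))
    q = q * torch.det(q).reshape(n, 1, 1)  # flip improper rotations to det = +1
    poses = torch.eye(4).repeat(n, 1, 1)
    poses[:, :3, :3] = q
    poses[:, :3, 3] = torch.randn(n, 3)
    return poses

D = se3_dist(_random_poses(5), _random_poses(7), beta=0.5)
print(D.shape)  # torch.Size([5, 7])
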
def q_sample(self, x_start, t, noise):
    # Sample from q(x_t | x_0) for a categorical distribution via the
    # Gumbel-max trick, see https://en.wikipedia.org/wiki/Categorical_distribution
    zero_idx = t < 0
    t[zero_idx] = 0
    logits = torch.log(self.q_probs(x_start, t) + self.eps)
    logits[zero_idx] = 0.
    noise = torch.clip(noise, min=self.eps, max=1.)
    gumbel_noise = -torch.log(-torch.log(noise))  # noise ~ Uniform(0, 1)
    return torch.argmax(logits + gumbel_noise, dim=-1)

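# Standalone sketch of the Gumbel-max trick q_sample relies on: adding
# Gumbel(0, 1) noise to log-probabilities and taking the argmax draws exact
# categorical samples. Independent of the class above.
import torch

probs = torch.tensor([0.1, 0.2, 0.7])
u = torch.rand(10000, 3).clamp_(min=1e-12)  # u ~ Uniform(0, 1)
gumbel = -torch.log(-torch.log(u))
samples = torch.argmax(torch.log(probs) + gumbel, dim=-1)
print(torch.bincount(samples, minlength=3) / 10000.)  # ~ [0.1, 0.2, 0.7]
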
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, steps, steps)
    alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

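# Hedged sanity check: the schedule returns `timesteps` betas, and the implied
# cumulative product of alphas decays smoothly from ~1 toward 0.
import torch

betas = cosine_beta_schedule(1000)
alphas_cumprod = torch.cumprod(1. - betas, dim=0)
print(betas.shape, float(alphas_cumprod[0]), float(alphas_cumprod[-1]))
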
def write_image(self, image: torch.Tensor, img_caption: str = "sample_image", step: int = 0):
    image = torch.clip(tr.inv_norm(image).to(torch.float), 0, 1)  # [-1, 1] -> [0, 1]
    image *= 255.  # [0, 1] -> [0, 255]
    image = image.permute(1, 2, 0).to(dtype=torch.uint8)  # CHW -> HWC
    self.writer.add_image(img_caption, image, step, dataformats='HWC')
    self.writer.flush()

def _predict(self, batch: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
    if 'x' in batch.keys():
        # Style-content separation
        xp: Tensor = self._scaler.scaling(batch['x'])
        z: List[Tensor] = self._glow(xp)
        c, s = self._latent_to_cs(z)
        output = {'z': z, 's': s, 'c': c}
    else:
        assert ('x1' in batch.keys() and 'x2' in batch.keys())
        # Style change
        xp1: Tensor = self._scaler.scaling(batch['x1'])
        xp2: Tensor = self._scaler.scaling(batch['x2'])
        z1: List[Tensor] = self._glow(xp1)
        z2: List[Tensor] = self._glow(xp2)
        c1, s1 = self._latent_to_cs(z1)
        c2, s2 = self._latent_to_cs(z2)

        # Identity reconstructions
        xp1_idt: Tensor = self._glow.reverse(z1)
        xp2_idt: Tensor = self._glow.reverse(z2)
        x1_idt: Tensor = torch.clip(self._scaler.unscaling(xp1_idt), 0., 1.)
        x2_idt: Tensor = torch.clip(self._scaler.unscaling(xp2_idt), 0., 1.)

        # Cross-domain translations
        xp12: Tensor = self._glow.reverse(self._cs_to_latent(c1, s2))
        xp21: Tensor = self._glow.reverse(self._cs_to_latent(c2))
        # xp21: Tensor = self._glow.reverse(self._cs_to_latent(c2, s1))
        x12: Tensor = torch.clip(self._scaler.unscaling(xp12), 0., 1.)
        x21: Tensor = torch.clip(self._scaler.unscaling(xp21), 0., 1.)

        # Cycle reconstructions
        c12, s12 = self._latent_to_cs(self._glow(xp12))
        c21, s21 = self._latent_to_cs(self._glow(xp21))
        xp1_cycle: Tensor = self._glow.reverse(self._cs_to_latent(c12))
        # xp1_cycle: Tensor = self._glow.reverse(self._cs_to_latent(c12, s1))
        xp2_cycle: Tensor = self._glow.reverse(self._cs_to_latent(c21, s2))
        x1_cycle: Tensor = torch.clip(self._scaler.unscaling(xp1_cycle), 0., 1.)
        x2_cycle: Tensor = torch.clip(self._scaler.unscaling(xp2_cycle), 0., 1.)

        # Unconditional samples
        xp_sample: Tensor = self._glow.sample(len(batch['x1']), device=batch['x1'].device)
        x_sample: Tensor = torch.clip(self._scaler.unscaling(xp_sample), 0., 1.)

        output = {
            'z1': flatten_latent_vector(z1),
            'z2': flatten_latent_vector(z2),
            's1': flatten_latent_vector(s1),
            's2': flatten_latent_vector(s2),
            'c1': flatten_latent_vector(c1),
            'c2': flatten_latent_vector(c2),
            'x1_idt': x1_idt, 'x2_idt': x2_idt,
            'x12': x12, 'x21': x21,
            'x1_cycle': x1_cycle, 'x2_cycle': x2_cycle,
            'x_sample': x_sample
        }
    return output

def learn(self):
    self.train_it += 1
    s, a, s_, r, done = self.memory.get_sample()

    with torch.no_grad():
        # Select the next action according to the target policy and add clipped noise
        noise = torch.randn_like(a) * self.policy_noise
        noise = torch.clip(noise, -self.noise_clip, self.noise_clip)
        a_ = self.actor_target(s_) + noise
        a_ = torch.clip(a_, -1., 1.)

        # Compute the target Q value from the smaller of the two target critics
        target_Q1, target_Q2 = self.critic_target(s_, a_)
        target_Q = torch.min(target_Q1, target_Q2)
        td_target = r + (1 - done) * self.gamma * target_Q

    # Update critic
    q1, q2 = self.critic(s, a)
    critic_loss = F.mse_loss(q1, td_target) + F.mse_loss(q2, td_target)
    self.opt_critic.zero_grad()
    critic_loss.backward()
    self.opt_critic.step()

    # Delayed policy updates
    if self.train_it % self.policy_freq == 0:
        # Update actor. Both variants work: take the min over the two critics,
        # or use Q1 alone.
        q1, q2 = self.critic(s, self.actor(s))
        q = torch.min(q1, q2)
        # q = self.critic.Q1(s, self.actor(s))
        actor_loss = -torch.mean(q)
        self.opt_actor.zero_grad()
        actor_loss.backward()
        self.opt_actor.step()

        # Update target networks
        self.soft_update(self.critic_target, self.critic)
        self.soft_update(self.actor_target, self.actor)

    # Update exploration variance
    self.var = max(self.var * self.var_decay, self.var_min)

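# Hedged sketch of the Polyak averaging that learn() invokes via
# self.soft_update(target, source); the tau value here is an assumption,
# not the author's setting.
import torch

def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float = 0.005):
    with torch.no_grad():
        for tp, sp in zip(target.parameters(), source.parameters()):
            tp.mul_(1.0 - tau).add_(tau * sp)
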
def cwpdis_ope(pibs, pies, rewards, length, max_time=MAX_TIME):
    # Computes consistent weighted per-decision importance sampling (CWPDIS),
    # following the POPCORN paper / Thomas' thesis.
    n = pibs.shape[0]
    weights = torch.ones((n, max_time))
    wis_weights = torch.ones(n)

    for i in range(n):
        last = 1
        for t in range(int(length[i])):
            assert pibs[i, t] != 0
            # changed these two lines so gradient can flow...
            last = last * (pies[i, t] / pibs[i, t])
            weights[i, t] = last
        wis_weights[i] = last

    with torch.no_grad():
        masks = (weights != 0).detach().numpy()
        masks = torch.FloatTensor(masks)

    weights = torch.clip(weights, 1e-16, 1e3)
    weights_norm = weights.sum(dim=0)
    # step 1: \sum_n r_nt * w_nt
    weighted_r = (rewards * weights).sum(dim=0)
    # step 2: (\sum_n r_nt * w_nt) / \sum_n w_nt
    score = weighted_r / weights_norm
    # step 3: \sum_t ((\sum_n r_nt * w_nt) / \sum_n w_nt)
    score = score.sum()  # summing over the trajectory gives CWPDIS(θ)

    wis_weights = torch.clip(wis_weights, 1e-16, 1e3)
    weights_norm = wis_weights.sum(dim=0)
    wis_weights = wis_weights / weights_norm  # normalized weights for ESS(θ)
    return score, wis_weights

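# Hedged usage sketch for cwpdis_ope on synthetic trajectories; the
# (n_trajectories, max_time) tensor layout is read off the function body.
import torch

T, n = 10, 4
pibs = torch.full((n, T), 0.5)   # behavior policy action probabilities
pies = torch.full((n, T), 0.6)   # evaluation policy action probabilities
rewards = torch.rand(n, T)
length = torch.full((n,), T)
score, wis_w = cwpdis_ope(pibs, pies, rewards, length, max_time=T)
print(float(score), float(wis_w.sum()))  # normalized WIS weights sum to 1
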
def save_vp(self):
    if self.imgs is not None:
        theta = float(self.ui.lineEdit_theta.text())
        phi = float(self.ui.lineEdit_phi.text())
        tp_vec = torch.Tensor([[phi, theta], [phi, theta]]).type(
            torch.float32).to('cuda:0').contiguous()
        y = self.vp(self.imgs, tp_vec)
        y = torch.clip(y, 0, 255)
        ref, tar = convert_img(y[0]), convert_img(y[1])
        pd = self.ui.slineEdit.text()
        cv2.imwrite('{}/ref_vp.png'.format(pd), ref)
        cv2.imwrite('{}/tar_vp.png'.format(pd), tar)

def sample(self, batch_latent, batch_label):
    '''
    :param batch_latent: a tensor of size (batch_size, self.latent_size)
    :param batch_label: a tensor of size (batch_size, self.label_dim)
    :return: a tensor of size (batch_size, C, H, W), each value in range [0, 1]
    '''
    with torch.no_grad():
        # Decode samples from the latent codes and labels.
        y = self.decode(batch_latent, batch_label)
        # y += self.prior.sample(y.shape).to(y.device)
        y = torch.clip(y, 0, 1)
    return y

def jitter_point_cloud(xyz, sigma=0.001, prob=0.95):
    """ Randomly jitter point positions.
        Input:  Nx3 tensor, original point cloud
        Return: Nx3 tensor, jittered point cloud
    """
    if torch.rand([]) < prob:
        noise = torch.randn(xyz.shape) * sigma
        noise = torch.clip(noise, min=-3 * sigma, max=3 * sigma)  # truncate at ±3σ
        xyz += noise
    return xyz

def jitter_color(color, sigma=0.05, prob=0.95):
    """ Randomly jitter colors.
        Input:  Nx3 tensor, original point colors
        Return: Nx3 tensor, jittered point colors
    """
    if torch.rand([]) < prob:
        noise = torch.randn(color.shape) * sigma
        color += noise
        color = torch.clip(color, min=0., max=1.)
    return color

def shift_color(color, shift_range=0.1, prob=0.95):
    """ Randomly shift colors.
        Input:  Nx3 tensor, original point colors
        Return: Nx3 tensor, shifted point colors
    """
    if torch.rand([]) < prob:
        shifts = torch.rand([3]) * shift_range
        color += shifts
        color = torch.clip(color, min=0., max=1.)
    return color

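# Hedged usage sketch chaining the three augmentations above on a random
# point cloud; purely illustrative.
import torch

xyz = torch.rand(1024, 3)
rgb = torch.rand(1024, 3)
xyz = jitter_point_cloud(xyz)
rgb = shift_color(jitter_color(rgb))
print(xyz.shape, bool(rgb.min() >= 0), bool(rgb.max() <= 1))
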
def _step_symplectic(self, func, y, t, h):
    dy = torch.zeros(y.size(), dtype=self.dtype, device=self.device)
    n = y.size(-1) // 2
    dy[..., n:] = y[..., :n] - y[..., n:]
    k_ = func(t + self.eps, y[..., :n])
    sin_q_delta = torch.sin(y[..., :n] - y[..., n:]) + (h ** 2) * k_
    # Clip into the open interval (-1, 1) so arcsin stays finite.
    dy[..., :n] = torch.arcsin(torch.clip(sin_q_delta, -(1. - 1e-4), 1 - 1e-4))
    return dy

def score(self, loss_dict):
    # Standardize each per-attribute loss and combine into a single anomaly
    # score per sample.
    scaled_loss = {}
    for attr in loss_dict:
        scaled_loss[attr] = (loss_dict[attr] - self.scorer_parameters[attr]["mu"]
                             ) / self.scorer_parameters[attr]["sigma"]
        scaled_loss[attr] = torch.clip(scaled_loss[attr], 0, 10)
    return torch.sum(torch.stack([v for k, v in scaled_loss.items()]), dim=0)

def generate_pos_field(pos, size):
    # Normalized coordinate grids, offset by the target position.
    h = torch.arange(size[0]).reshape((1, 1, -1)).repeat(1, size[1], 1) / size[0]
    w = torch.arange(size[1]).reshape((1, -1, 1)).repeat(1, 1, size[0]) / size[1]
    h -= pos[0]
    w -= pos[1]
    field = torch.vstack((h, w))
    # Normalize to unit vectors, guarding against division by ~0 near `pos`.
    field = field / torch.clip(torch.norm(field, dim=0), min=0.01)
    return field

def forward(self, predicted_score, true_score, n=None):
    # score_diff = predicted_score - true_score
    score_diff = predicted_score * true_score
    loss = self.threshold - score_diff
    loss = torch.clip(loss, min=0)
    loss = torch.square(loss)
    if self.weight is not None:
        loss = loss * self.weight
    return 0.5 * loss.mean()

def forward(self, x, clip=None):
    # Zero out the attention of empty (padding) sequence elements.
    x_key_padding_mask = (x == 0).clone().detach()
    x = self.embedding(x.transpose(1, 0).int())  # [seq, batch]

    for i in range(self.encoder_layers):
        # Self-attention block
        residue = x.clone()
        x = self.encoder_norms1[i](x)
        if not self.relative_attention:
            x = self.self_attn_layers[i](
                x, x, x, key_padding_mask=x_key_padding_mask)[0]
        else:
            # Pairwise relative position encoding embedded in the
            # self-attention block.
            x = self.self_attn_layers[i](x.transpose(1, 0)).transpose(1, 0)
        x = self.encoder_dropouts[i](x)
        x = x + residue

        # Dense block
        residue = x.clone()
        x = self.encoder_linear1[i](x)
        x = self.encoder_norms2[i](x)
        x = self.encoder_activations[i](x)
        x = self.encoder_linear2[i](x)
        x = x + residue

    if self.aggregation_mode == 'mean':
        x = x.mean(dim=0)  # mean aggregation
    elif self.aggregation_mode == 'sum':
        x = x.sum(dim=0)  # sum aggregation
    elif self.aggregation_mode == 'max':
        x = x.max(dim=0).values  # max aggregation; torch.max(dim=...) returns (values, indices)
    else:
        print(self.aggregation_mode + ' is not a valid aggregation mode!')

    for i in range(self.decoder_layers):
        if i != 0:
            residue = x.clone()
        x = self.decoder_linear[i](x)
        x = self.decoder_norms[i](x)
        x = self.decoder_dropouts[i](x)
        x = self.decoder_activations[i](x)
        if i != 0:
            x += residue

    x = self.output_layer(x)
    if clip is not None:
        x = torch.clip(x, max=clip)
    return x

def generate_permutation_for_numerical_all_dim(input_data: torch.Tensor,
                                               num_samples,
                                               variance=0.1,
                                               clip_permutation: bool = True):
    '''
    [input_data]: normalised data; should be a 1-D tensor.
    --------------------------
    Return: all permutations.
    '''
    if clip_permutation:
        max_range = torch.clip(input_data + variance, -0.999999, 1).float()
        min_range = torch.clip(input_data - variance, -1, 0.999999).float()
    else:
        max_range = input_data + variance
        min_range = input_data - variance

    TempRecord.input_data = input_data
    TempRecord.max_range = max_range
    TempRecord.min_range = min_range

    dist = Uniform(min_range, max_range)
    return dist.sample((num_samples, ))

def p_sample(self, model_fn, x, t, noise):
    model_logits, pred_x_start_logits = self.p_logits(model_fn=model_fn, x=x, t=t)
    assert noise.shape == model_logits.shape, noise.shape

    # Suppress the Gumbel noise where t == 0, so the final step is deterministic.
    nonzero_mask = (t != 0).reshape(x.shape[0], *((1, ) * len(x.shape))).float()
    noise = torch.clip(noise, self.eps, 1.)
    gumbel_noise = -torch.log(-torch.log(noise))
    sample = torch.argmax(model_logits + nonzero_mask * gumbel_noise, dim=-1)
    return sample, torch.softmax(pred_x_start_logits, dim=-1)

def compute_strong_fwd_bwd_loss(self, y_fwd, y_bwd, targets):
    if self.label_smoothing > 0.:
        targets = torch.clip(targets, min=self.label_smoothing,
                             max=1 - self.label_smoothing)

    # A frame counts as active in the forward direction once any earlier frame
    # was active, and in the backward direction once any later frame was active.
    strong_targets_fwd = torch.cummax(targets, dim=-1)[0]
    strong_targets_bwd = torch.cummax(targets.flip(-1), dim=-1)[0].flip(-1)

    loss = nn.BCELoss(reduction='none')(y_fwd, strong_targets_fwd)
    if y_bwd is not None:
        loss = (loss / 2
                + nn.BCELoss(reduction='none')(y_bwd, strong_targets_bwd) / 2)
    return loss

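# Worked example of the cummax trick above: an event active at frames 2-3
# becomes an "active so far" mask in the forward direction and an
# "active from here on" mask in the backward direction.
import torch

targets = torch.tensor([[0., 0., 1., 1., 0., 0.]])
fwd = torch.cummax(targets, dim=-1)[0]
bwd = torch.cummax(targets.flip(-1), dim=-1)[0].flip(-1)
print(fwd)  # tensor([[0., 0., 1., 1., 1., 1.]])
print(bwd)  # tensor([[1., 1., 1., 1., 0., 0.]])
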
def slice_imgs(imgs, count, size=224, transform=None, align='uniform', micro=1.):
    def map(x, a, b):
        return x * (b - a) + a

    rnd_size = torch.rand(count)
    if align == 'central':  # normally distributed around the center
        rnd_offx = torch.clip(torch.randn(count) * 0.2 + 0.5, 0., 1.)
        rnd_offy = torch.clip(torch.randn(count) * 0.2 + 0.5, 0., 1.)
    else:  # uniform
        rnd_offx = torch.rand(count)
        rnd_offy = torch.rand(count)

    sz = [img.shape[2:] for img in imgs]
    sz_max = [torch.min(torch.tensor(s)) for s in sz]
    if align == 'overscan':  # add space
        sz = [[2 * s[0], 2 * s[1]] for s in list(sz)]
        imgs = [pad_up_to(imgs[i], sz[i], type='centr') for i in range(len(imgs))]

    sliced = []
    for i, img in enumerate(imgs):
        cuts = []
        sz_max_i = max(size, 0.25 * sz_max[i]) if micro is True else sz_max[i]
        if micro is True:  # both scales, micro mode
            sz_min_i = size // 4
        elif micro is False:  # both scales, macro mode
            sz_min_i = 0.5 * sz_max[i]
        else:  # single scale
            sz_min_i = size if torch.rand(1) < micro else 0.9 * sz_max[i]
        for c in range(count):
            csize = map(rnd_size[c], sz_min_i, sz_max_i).int()
            offsetx = map(rnd_offx[c], 0, sz[i][1] - csize).int()
            offsety = map(rnd_offy[c], 0, sz[i][0] - csize).int()
            cut = img[:, :, offsety:offsety + csize, offsetx:offsetx + csize]
            cut = F.interpolate(cut, (size, size), mode='bicubic',
                                align_corners=False)  # bilinear
            if transform is not None:
                cut = transform(cut)
            cuts.append(cut)
        sliced.append(torch.cat(cuts, 0))
    return sliced

def forward(self, img, mask):
    x = torch.cat([mask, img], dim=1)
    x = self.head(x)
    # Attention branch
    attn = self.body_attn_1(x)
    attn = self.body_attn_2(attn)
    attn = self.body_attn_3(attn)
    attn = self.body_attn_4(attn)
    attn = self.body_attn_attn(attn, attn, mask)
    attn = self.body_attn_5(attn)
    attn = self.body_attn_6(attn)
    # Convolution branch
    conv = self.body_conv(x)
    x = self.tail(torch.cat([conv, attn], dim=1))
    return torch.clip(x, -1, 1)

def update_beliefs(self, b, t, response_outcomes, mask=None):
    if mask is None:
        mask = ones(self.runs)

    mu = []
    pi = []

    mu_pre = self.mu[-1]
    pi_pre = self.pi[-1]

    # 1st level
    # Prediction
    muhat = mu_pre[..., 0].sigmoid()

    # Update
    # 2nd level
    # Precision of the prediction
    w1 = torch.exp(self.kappa * mu_pre[..., -1])

    # Updates
    pihat1 = pi_pre[..., 0] / (1 + pi_pre[..., 0] * w1)
    pi1 = pihat1 + mask * muhat * (1 - muhat)

    o = response_outcomes[-1][:, -2]  # observation
    wda = mask * ((o + 1) / 2 - muhat) / pi1
    mu.append(mu_pre[..., 0] + wda)
    pi.append(pi1)

    # Volatility prediction error
    da1 = mask * ((1 / pi1 + wda ** 2) * pihat1 - 1)

    # 3rd level
    # Precision of the prediction
    pihat2 = pi_pre[..., 1] / (1. + pi_pre[..., 1] * self.eta)

    # Weighting factor
    w2 = w1 * pihat1 * mask

    # Updates
    pi2 = torch.clip(
        pihat2 + self.kappa ** 2 * w2 * (w2 + (2 * w2 - 1) * da1) / 2,
        min=1e-2)
    mu.append(mu_pre[..., 1] + self.kappa * w2 * da1 / (2 * pi2))
    pi.append(pi2)

    self.mu.append(torch.stack(mu, dim=-1))
    self.pi.append(torch.stack(pi, dim=-1))

def random_crop(img, num_crops, crop_size=224, normalize=True):
    def map(x, a, b):
        return x * (b - a) + a

    rnd_size = torch.rand(num_crops)
    # Crop offsets, normally distributed around the image center.
    rnd_offx = torch.clip(torch.randn(num_crops) * 0.2 + 0.5, 0., 1.)
    rnd_offy = torch.clip(torch.randn(num_crops) * 0.2 + 0.5, 0., 1.)

    img_size = img.shape[2:]
    min_img_size = min(img_size)

    cuts = []
    for c in range(num_crops):
        current_crop_size = map(rnd_size[c], crop_size, min_img_size).int()
        offsetx = map(rnd_offx[c], 0, img_size[1] - current_crop_size).int()
        offsety = map(rnd_offy[c], 0, img_size[0] - current_crop_size).int()
        cut = img[:, :, offsety:offsety + current_crop_size,
                  offsetx:offsetx + current_crop_size]
        cut = F.interpolate(cut, (crop_size, crop_size), mode='bicubic',
                            align_corners=False)  # bilinear
        if normalize:  # `normalize` is a boolean flag
            cut = img_norm(cut)
        cuts.append(cut)
    return torch.cat(cuts, dim=0)

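# Hedged usage sketch for random_crop; img_norm is whatever normalizer the
# surrounding module provides, stubbed as identity here.
import torch
import torch.nn.functional as F

img_norm = lambda x: x  # stand-in for the module's normalization
img = torch.rand(1, 3, 512, 640)
crops = random_crop(img, num_crops=4, crop_size=224)
print(crops.shape)  # torch.Size([4, 3, 224, 224])
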
def evaluate(self):
    """
    Evaluate the model on the validation set

    Returns:
        dict: evaluation metrics
    """
    self.model.eval()
    sigmoid = nn.Sigmoid()
    loss_val, top_val, num_samples = 0, 0, 0
    for x, target in self.val_loader:
        x, target = self.to_cuda(x, target)

        # Forward
        out = self.model(x)

        # Loss computation
        loss_val += 1 - self.criterion(out, target).item()
        num_samples += x.shape[0]

    self.val_loss_recorder.append(loss_val / num_samples)
    loss_val /= len(self.val_loader)
    acc_val = 0

    # Save the first reconstruction/target pair of the last batch for inspection.
    inv_normalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
    tp = transforms.ToPILImage()
    res = tp(torch.clip(inv_normalize(out[0, :, :, :]), 0, 1))
    res.save('res_' + str(self.epoch).zfill(4) + '.jpg')
    tar = tp(torch.clip(inv_normalize(target[0, :, :, :]), 0, 1))
    tar.save('tar_' + str(self.epoch).zfill(4) + '.jpg')

    return dict(train_loss=self.train_loss, loss_val=loss_val, acc_val=acc_val)

def show_side_by_side_loss(original, reconstructed):
    """Shows two images side by side with their MSE loss printed above.
    Useful for autoencoder validation."""
    batchsize = original.shape[0]
    original = torch.clip(original, 0, 1).detach().cpu()
    reconstructed = torch.clip(reconstructed, 0, 1).detach().cpu()
    for i in range(batchsize):
        fig, axs = plt.subplots(1, 2, figsize=(10, 20))
        fig.tight_layout()
        mseloss = torch.nn.functional.mse_loss(original[i], reconstructed[i])
        print("The MSE loss is: ", mseloss.item())
        axs[0].imshow(original[i].permute(1, 2, 0))
        axs[1].imshow(reconstructed[i].permute(1, 2, 0))
        axs[0].tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        axs[1].tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        plt.show()
        print("\n\n\n")

def forward(self, img, seg):
    """
    Args:
        img (Tensor): Input image.

    Returns:
        Tensor: Color jittered image.
    """
    assert isinstance(
        img, t.Tensor
    ), "BUG CHECK: Only 'torch.Tensor' type of 'img' is supported."

    # Apply the four jitter operations in random order.
    fn_idx = t.randperm(4)
    for fn_id in fn_idx:
        if fn_id == 0 and self.brightness is not None:
            brightness = self.brightness
            brightness_factor = t.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
            img = F.adjust_brightness(img, brightness_factor)

        if fn_id == 1 and self.contrast is not None:
            contrast = self.contrast
            contrast_factor = t.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
            img = F.adjust_contrast(img, contrast_factor)

        if fn_id == 2 and self.saturation is not None:
            saturation = self.saturation
            saturation_factor = t.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
            img = F.adjust_saturation(img, saturation_factor)

        if fn_id == 3 and self.hue is not None:
            hue = self.hue
            hue_factor = t.tensor(1.0).uniform_(hue[0], hue[1]).item()
            hue_factor_radians = hue_factor * 2.0 * np.pi

            # Rotation matrix about the RGB diagonal (grey axis) by the hue angle.
            cosA = np.cos(hue_factor_radians)
            sinA = np.sin(hue_factor_radians)
            hue_rotation_matrix = [
                [cosA + (1.0 - cosA) / 3.0,
                 1. / 3. * (1.0 - cosA) - np.sqrt(1. / 3.) * sinA,
                 1. / 3. * (1.0 - cosA) + np.sqrt(1. / 3.) * sinA],
                [1. / 3. * (1.0 - cosA) + np.sqrt(1. / 3.) * sinA,
                 cosA + 1. / 3. * (1.0 - cosA),
                 1. / 3. * (1.0 - cosA) - np.sqrt(1. / 3.) * sinA],
                [1. / 3. * (1.0 - cosA) - np.sqrt(1. / 3.) * sinA,
                 1. / 3. * (1.0 - cosA) + np.sqrt(1. / 3.) * sinA,
                 cosA + 1. / 3. * (1.0 - cosA)],
            ]
            img = img.permute(1, 2, 0) @ t.as_tensor(hue_rotation_matrix, dtype=img.dtype)
            img = t.clip(img.permute(2, 0, 1), min=0., max=1.)

    return img, seg