def correlation_max(dump_list):
    """Rank the rows of each correlation matrix by the magnitude of their row maximum.

    For every ``(corr, idx)`` pair yielded by ``correlate_dumps(dump_list)``:
    the maximum of each row of ``corr`` (reduction over ``dim=1``) is taken, and
    its absolute value is used as that row's score. The rows of ``corr`` and the
    entries of ``idx`` are then reordered by descending score.

    NOTE(review): the score is ``|max(row)|``, not ``max(|row|)``. This matches the
    original expression. If strongly negative correlations should rank high,
    use ``corr.abs().max(dim=1).values`` instead. Confirm intent.

    Args:
        dump_list: passed straight through to ``correlate_dumps``.

    Returns:
        list of ``(scores, corr, idx)`` tuples. All three members of each tuple
        are in the same descending-score order.
    """
    results = []
    for corr, idx in correlate_dumps(dump_list):
        # torch.max with ``dim`` returns a (values, indices) pair; score the values.
        scores = torch.max(corr, dim=1).values.abs()
        scores, order = torch.sort(scores, descending=True)
        # torch.select takes a single integer index; reordering rows needs index_select.
        corr = corr.index_select(0, order)
        if isinstance(idx, torch.Tensor):
            idx = idx.index_select(0, order.to(idx.device))
        else:
            # assumes idx is an indexable sequence aligned with corr's rows -- TODO confirm
            idx = [idx[i] for i in order.tolist()]
        results.append((scores, corr, idx))
    return results
def hls_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an HLS image to RGB.

    Channel 0 (hue) is treated as an angle in radians. It is scaled by
    ``6 / pi``, so ``2 * pi`` maps to 12. Lightness and saturation (channels
    1 and 2) are assumed to be in the range (0, 1).

    Outside TorchScript this reads the function attribute ``hls_to_rgb.HLS2RGB``
    and overwrites it with a copy matched to ``image``'s device and dtype. The
    attribute must therefore be set at module level. It is presumably the same
    3x1x1 offsets ``[0, 8, 4]`` used in the scripted branch (confirm).

    Args:
        image: HLS image to be converted to RGB with shape :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`.

    Raises:
        TypeError: if ``image`` is not a ``torch.Tensor``.
        ValueError: if ``image`` has fewer than 3 dims or dim -3 is not of size 3.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = hls_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    if not torch.jit.is_scripting():
        # The cached constant lives on the function object so the module stays
        # TorchScript-compilable; __setattr__ is used because plain attribute
        # assignment here trips the JIT on PyTorch <= 1.6.
        hls_to_rgb.__setattr__('HLS2RGB', hls_to_rgb.HLS2RGB.to(image))  # type: ignore
        _HLS2RGB: torch.Tensor = hls_to_rgb.HLS2RGB  # type: ignore
    else:
        _HLS2RGB = torch.tensor([[[0.0]], [[8.0]], [[4.0]]], device=image.device, dtype=image.dtype)  # 3x1x1

    # Insert a singleton axis in front of the channel axis. Each HLS component
    # then keeps shape (*, 1, H, W) and broadcasts against the 3x1x1 offsets.
    expanded: torch.Tensor = image.unsqueeze(-4)
    hue: torch.Tensor = torch.select(expanded, -3, 0) * (6 / math.pi)  # radians -> units of 30 degrees
    light: torch.Tensor = torch.select(expanded, -3, 1)
    sat: torch.Tensor = torch.select(expanded, -3, 2)

    amp = sat * torch.min(light, 1.0 - light)

    # One k per output channel: (R, G, B) use offsets (0, 8, 4), taken modulo 12.
    k: torch.Tensor = (hue + _HLS2RGB) % 12
    # channel = l - a * clamp(min(k - 3, 9 - k), -1, 1)
    ramp = torch.min(k - 3.0, 9.0 - k).clamp_(min=-1.0, max=1.0)
    return torch.addcmul(light, amp, ramp, value=-1)
def as_identity(self):
    """Check that view_as_complex -> select -> view_as_real acts like an identity for autograd.

    The gradient reaching ``z`` through the complex round trip must equal the
    gradient of the same selection done directly on the real tensor. ``self``
    only needs an ``assertEqual`` method.
    """
    def through_complex(t):
        as_complex = torch.view_as_complex(t)
        picked = as_complex.select(as_complex.dim() - 1, 0)
        return torch.view_as_real(picked).sum()

    z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True)
    gradcheck(through_complex, [z])
    through_complex(z).backward()

    # Reference: the equivalent selection on the real view (second-to-last dim).
    reference = z.clone().detach().requires_grad_(True)
    reference.select(reference.dim() - 2, 0).sum().backward()

    self.assertEqual(z.grad, reference.grad)
def extract_map_statedict(
    m_b: Union[ALEBOGP, ModelListGP], num_outputs: int
) -> List[MutableMapping[str, Tensor]]:
    """Split a batch-mode ALEBO GP state dict into one MAP state dict per output.

    Every tensor in ``m_b.state_dict()`` is reduced to its first batch element
    (index 0 along dim 0). With ``num_outputs > 1`` the model is treated as a
    ModelListGP: keys must look like ``models.<i>.<param>`` and are routed to
    output ``i`` under the name ``<param>``. Otherwise every key goes to output 0
    unchanged.

    Args:
        m_b: Batch-mode GP (a single ALEBO GP or a ModelListGP of them).
        num_outputs: Number of outputs being modeled.

    Returns:
        A list of ``num_outputs`` ordered state dicts.
    """
    is_modellist = num_outputs > 1
    map_sds: List[MutableMapping[str, Tensor]] = [OrderedDict() for _ in range(num_outputs)]

    for key, value in m_b.state_dict().items():
        if not is_modellist:
            model_idx, param_name = 0, key
        else:
            match = re.match(r"^models\.([0-9]+)\.(.*)$", key)
            if match is None:
                raise Exception(
                    "Unable to parse ModelList structure"
                )  # pragma: no cover
            model_idx, param_name = int(match.group(1)), match.group(2)
        map_sds[model_idx][param_name] = value.select(0, 0)

    return map_sds
def move(self, x: Tensor, es: Tensor, z: Tensor) -> Tensor:
    """Predict the next node states from typed edge messages.

    Args:
        x: [node, batch, step, dim]
        es: [2, E]
        z: [E, batch, K], per-edge weights over the message functions in self.msgs

    Return:
        x: [node, batch, step, dim], future node states (x plus a predicted delta)
    """
    num_steps = x.size(2)
    # [E, batch, K] -> [step, E, batch, K] -> [E, batch, step, K]
    edge_weights = z.repeat(num_steps, 1, 1, 1).permute(1, 2, 0, 3).contiguous()

    msg, col, size = self.message(x, es)

    # With skip_first, message type 0 is ignored and the average runs over the
    # remaining types only.
    first_type = 1 if self.skip_first else 0
    n_types = len(self.msgs) - first_type
    weighted = (
        self.msgs[k](msg) * edge_weights[..., k].unsqueeze(-1) / n_types
        for k in range(first_type, len(self.msgs))
    )
    edge_msgs = sum(weighted)

    # Sum the messages arriving from incoming edges at each node.
    node_msgs = self.aggregate(edge_msgs, col, size, 'add')

    # Skip connection: the delta is predicted from [state, aggregated messages].
    return x + self.out(torch.cat([x, node_msgs], dim=-1))
def forward(self, *x):
    """Encode support and query images and return full and per-class binary logits.

    Args:
        *x: ``(Is, Ys, Iq, label)``.
            ``Is`` is indexed as ``(N, way, shot, 3, h, w)``; its first three
            sizes are read as N, way and shot.
            ``Ys`` is passed unchanged to ``self.estimator``.
            ``Iq`` is the query batch.
            ``label`` is expected as ``(N, way)``; when way == 1 the trailing
            axis is added here.

    Returns:
        ``(logits_full, logits_binary)``. For way > 1 the per-class binary
        logits are stacked on dim 1 and then merged into the batch dim.
    """
    Is, Ys, Iq, label = x
    N, way, shot = Is.size()[:3]

    # support: (N, way, shot, 3, h, w) -> (N * way * shot, 3, h, w)
    support = merge_first_k_dim(Is, dims=(0, 1, 2))
    # guarantee that `label` has shape (N, way)
    if way == 1:
        label = label.unsqueeze(-1)

    Fs = self.deeplab_head(self.embedding(self.backbone(support)['layer3']), relu=True)
    # (N * way * shot, c, h, w) -> (N * way, shot, c, h, w)
    Fs = depart_first_dim(Fs, dims=(N * way, shot))
    if self.visualize and way == 1:
        self.vis.update({'hidden_Fs': Fs.select(1, 0).clone().detach()})
    # (N * way, shot, c, h, w) -> (N, way, shot, c, h, w)
    Fs = depart_first_dim(Fs, dims=(N, way))

    # query branch
    Fq = self.deeplab_head(self.embedding(self.backbone(Iq)['layer3']), relu=True)
    if self.visualize:
        self.vis.update({'hidden_Fq': Fq.clone().detach()})

    # knowledge logits
    logits_full = self.estimator(Fq, Fs, Ys, label, mode='training')

    # pattern (binary) logits
    if way <= 1:
        label = label.select(1, 0)
        return logits_full, get_binary_logits(logits_full, label, base=True)

    per_class = [get_binary_logits(logits_full, label.select(1, i)) for i in range(way)]
    logits_binary = merge_first_k_dim(torch.stack(per_class, dim=1), (0, 1))
    return logits_full, logits_binary
def forward(self, x, adj, idx, norm):
    """Accumulate multi-step, adjacency-weighted messages from each message function.

    For every message function ``i`` in ``range(idx, len(self.msgs))``:
      1. ``h = self.msgs[i](x) * w_i``, where ``w_i`` is channel ``i`` of the last
         axis of ``adj`` (a trailing singleton axis is added for broadcasting);
      2. ``self.gdep`` times, ``h = alpha * h + (1 - alpha) * h * w_i``;
      3. all intermediate ``h`` are concatenated on the last axis, projected
         with ``self.mlp``, divided by ``norm``, and added to the result.

    Args:
        x: input tensor; its first three sizes define the output's leading shape.
        adj: tensor whose last axis is indexed by message-function index.
        idx: first message function to use (earlier ones are skipped).
        norm: divisor applied to each function's contribution.

    Returns:
        tensor of shape ``(x.size(0), x.size(1), x.size(2), self.msg_out)``.
        It is all zeros when no message function is used.
    """
    # Allocate on x's device instead of consulting a global GPU flag, so the
    # in-place accumulation below never mixes CPU and CUDA tensors.
    # (Variable() was a no-op wrapper since PyTorch 0.4 and is dropped.)
    msgs = torch.zeros(x.size(0), x.size(1), x.size(2), self.msg_out, device=x.device)
    for i in range(idx, len(self.msgs)):
        weight = torch.select(adj, -1, i).unsqueeze(-1)  # invariant across the hops below
        h = self.msgs[i](x) * weight
        hops = [h]
        for _ in range(self.gdep):
            h = self.alpha * h + (1 - self.alpha) * h * weight
            hops.append(h)
        out = self.mlp(torch.cat(hops, dim=-1))
        out = out / norm
        msgs += out
    return msgs
def tensor_indexing_ops(self):
    """Evaluate a broad set of tensor indexing, slicing, joining and scatter ops.

    Builds small random inputs, applies each op once, and returns ``len(...)``
    over the results. The individual op results are otherwise discarded, so
    this is an op-coverage fixture rather than a computation.

    NOTE(review): the builtin ``len`` is called with many positional arguments.
    In eager Python ``len`` accepts exactly one argument, so this would raise
    ``TypeError``. Presumably the method is only ever compiled with
    ``torch.jit.script``. Confirm, or wrap the arguments in a tuple.
    """
    x = torch.randn(2, 4)
    y = torch.randn(4, 4)
    t = torch.tensor([[0, 0], [1, 0]])  # index tensor for gather/scatter/take below
    mask = x.ge(0.5)
    i = [0, 1]  # split indices for dsplit/hsplit/vsplit below
    return len(
        torch.cat((x, x, x), 0),
        torch.concat((x, x, x), 0),
        torch.conj(x),
        torch.chunk(x, 2),
        torch.dsplit(torch.randn(2, 2, 4), i),
        torch.column_stack((x, x)),
        torch.dstack((x, x)),
        torch.gather(x, 0, t),
        torch.hsplit(x, i),
        torch.hstack((x, x)),
        torch.index_select(x, 0, torch.tensor([0, 1])),
        # NOTE(review): presumably resolves to aten::index under TorchScript;
        # confirm it exists as an eager Tensor method in the pinned version.
        x.index(t),
        torch.masked_select(x, mask),
        torch.movedim(x, 1, 0),
        torch.moveaxis(x, 1, 0),
        torch.narrow(x, 0, 0, 2),
        torch.nonzero(x),
        torch.permute(x, (0, 1)),
        torch.reshape(x, (-1, )),
        torch.row_stack((x, x)),
        torch.select(x, 0, 0),
        torch.scatter(x, 0, t, x),
        x.scatter(0, t, x.clone()),
        torch.diagonal_scatter(y, torch.ones(4)),
        torch.select_scatter(y, torch.ones(4), 0, 0),
        torch.slice_scatter(x, x),
        torch.scatter_add(x, 0, t, x),
        # The next two are in-place (trailing underscore), so every argument
        # evaluated after them sees the modified x.
        x.scatter_(0, t, y),
        x.scatter_add_(0, t, y),
        # torch.scatter_reduce(x, 0, t, reduce="sum"),
        torch.split(x, 1),
        torch.squeeze(x, 0),
        torch.stack([x, x]),
        torch.swapaxes(x, 0, 1),
        torch.swapdims(x, 0, 1),
        torch.t(x),
        torch.take(x, t),
        torch.take_along_dim(x, torch.argmax(x)),
        torch.tensor_split(x, 1),
        torch.tensor_split(x, [0, 1]),
        torch.tile(x, (2, 2)),
        torch.transpose(x, 0, 1),
        torch.unbind(x),
        torch.unsqueeze(x, -1),
        torch.vsplit(x, i),
        torch.vstack((x, x)),
        torch.where(x),
        torch.where(t > 0, t, 0),
        torch.where(t > 0, t, t),
    )
def get_time_from_intervals(I):
    """Convert interval lengths along the last dim into start times (an exclusive cumulative sum).

    ``out[..., k] = sum(I[..., :k])``, so ``out[..., 0] == 0``.

    Args:
        I: tensor of interval lengths; the last dimension is the sequence axis.

    Returns:
        contiguous tensor with the same shape as ``I``. An empty last dimension
        yields an empty result instead of an error.
    """
    last_dim_size = I.shape[-1]
    # Prepend one zero on the last dim (left width 1, right width 0) so the
    # cumulative sum starts at 0.
    times = F.pad(I, (1, 0)).cumsum(dim=-1)
    # The padded cumsum has one extra trailing element; keep the first
    # last_dim_size. narrow replaces the per-index select+stack loop, which
    # also failed on an empty last dimension.
    return times.narrow(-1, 0, last_dim_size).contiguous()
def objective(p, Theta, Tx, Rx, sigma2):
    """Per-sample sum rate, in bits, of the effective channel ``h = Rx @ diag(Theta) @ Tx``.

    The trailing axis of size 2 on every channel tensor holds (real, imag).
    ``Tx`` is unpacked as ``(n, L, K, Ml, N, 2)`` and ``M`` is read from
    ``Rx.shape[3]``. ``Theta`` must be viewable as ``(n, L*Ml, 2)``. ``p`` must
    be viewable as ``(n, -1, K*N)``, and it is broadcast across the rows of the
    gain matrix.

    With ``g[i, j] = p[j] * |h[i, j]|^2``, the rate of row ``i`` is
    ``log2(1 + g[i, i] / (sum_{j != i} g[i, j] + sigma2))``, and rows are summed.
    Subtracting ``diag_embed(diagonal(g))`` from ``g`` requires ``h`` to be square.

    Returns:
        tensor of shape ``(n,)``.
    """
    n, L, K, Ml, N, _ = Tx.shape
    M = Rx.shape[3]

    tx = Tx.permute(0, 1, 3, 2, 4, 5).contiguous().view(n, L * Ml, K * N, 2)
    rx = Rx.permute(0, 1, 3, 2, 4, 5).contiguous().view(n, K * M, L * Ml, 2)
    theta = Theta.view(n, L * Ml, 2)

    tx_re, tx_im = tx[..., 0], tx[..., 1]
    rx_re, rx_im = rx[..., 0], rx[..., 1]
    th_re = theta[..., 0].diag_embed()
    th_im = theta[..., 1].diag_embed()

    # A = Rx @ diag(Theta), in real arithmetic.
    a_re = rx_re @ th_re - rx_im @ th_im
    a_im = rx_re @ th_im + rx_im @ th_re
    # h = A @ Tx
    h_re = a_re @ tx_re - a_im @ tx_im
    h_im = a_re @ tx_im + a_im @ tx_re

    gain = p.view(n, -1, K * N) * (h_re ** 2 + h_im ** 2)
    signal = gain.diagonal(dim1=1, dim2=2)
    interference_plus_noise = (gain - signal.diag_embed()).sum(dim=2) + sigma2
    return torch.sum(torch.log(1 + signal / interference_plus_noise), dim=1) / math.log(2)
def forward(self, x):
    """Run an exponential moving average over linearly projected segments.

    For each segment k along dim 1:
        h_k = alpha * h_{k-1} + (1 - alpha) * linear(x[:, k]),  with h_{-1} = 0.

    Args:
        x: tensor unpacked as (num_samples, num_segments, num_features).

    Returns:
        (h_last, hiddens): the final state, of shape (num_samples, out_features),
        and all states stacked on dim 1.
    """
    num_samples, _, _ = x.shape
    state = x.new_zeros((num_samples, self.out_features))
    alpha = self.alpha

    states = []
    for segment in x.unbind(dim=1):
        state = alpha * state + (1 - alpha) * self.linear(segment)
        states.append(state)

    return state, torch.stack(states, dim=1)
def move(self, x, es, z, h_att, h_node=None, h_edge=None):
    """Predict the next node states, optionally with recurrent state and temporal attention.

    Args:
        x: [node, batch, step, dim]
        es: [2, E]
        z: [E, batch, K]
        h_att: [step_att, node, batch, step, dim], past node hidden states for
            temporal attention. Ignored (and rebuilt) when h_node is None.
        h_node: [node, batch, step, dim], node hidden states, default: None
        h_edge: [E, batch, step, dim], edge hidden states, default: None

    Return:
        x: [node, batch, step, dim], future node states
        h_att: [step_att + 1, node, batch, step, dim], accumulated attention history
        cat: [node, batch, step, dim], node hidden states
        msgs: [E, batch, step, dim], edge hidden states
    """
    embedded = torch.cat([self.input_emb(x), x], dim=-1)

    num_steps = x.size(2)
    # [E, batch, K] -> [step, E, batch, K] -> [E, batch, step, K]
    z = z.repeat(num_steps, 1, 1, 1).permute(1, 2, 0, 3).contiguous()
    msg, col, size = self.message(embedded, es)

    start = 1 if self.skip_first else 0
    norm = len(self.msgs) - start
    msgs = sum(
        self.msgs[k](msg) * z[..., k].unsqueeze(-1) / norm
        for k in range(start, len(self.msgs))
    )
    if h_edge is not None:
        msgs = self.gru_edge(msgs, h_edge)

    # aggregate all msgs at their receivers, then join with the raw state
    cat = torch.cat([x, self.aggregate(msgs, col, size)], dim=-1)

    if h_node is not None:
        cat = self.gru_node(cat, h_node)
        cur_hidden, _ = self.temporalAttention(h_att, cat)
        h_att = torch.cat([h_att, cur_hidden.unsqueeze(0)], dim=0)
        delta = self.out(cur_hidden)
    else:
        delta = self.out(cat)
        h_att = cat.unsqueeze(0)

    return x + delta, h_att, cat, msgs
def move(self, x: Tensor, es: Tensor, z: Tensor, h_node: Tensor = None, h_edge: Tensor = None):
    """Predict the next node states with optional node- and/or edge-level recurrence.

    ``self.option`` selects which recurrent units are used: 'node', 'edge' or
    'both'. With 'node' the returned msgs is None; with 'edge' the returned
    cat is None.

    Args:
        x: [node, batch, step, dim]
        es: [2, E]
        z: [E, batch, K]
        h_node: [node, batch, step, dim], node hidden states, default: None
        h_edge: [E, batch, step, dim], edge hidden states, default: None

    Return:
        x: [node, batch, step, dim], future node states
        cat: [node, batch, step, dim], node hidden states (or None)
        msgs: [E, batch, step, dim], edge hidden states (or None)
    """
    # [E, batch, K] -> [step, E, batch, K] -> [E, batch, step, K]
    z = z.repeat(x.size(2), 1, 1, 1).permute(1, 2, 0, 3).contiguous()
    msg, col, size = self.message(x, es)

    start = 1 if self.skip_first else 0
    norm = len(self.msgs) - start
    msgs = sum(
        self.msgs[k](msg) * z[..., k].unsqueeze(-1) / norm
        for k in range(start, len(self.msgs))
    )
    if h_edge is not None and self.option in {'edge', 'both'}:
        msgs = self.gru_edge(msgs, h_edge)

    # skip connection: raw state next to the aggregated incoming messages
    cat = torch.cat([x, self.aggregate(msgs, col, size)], dim=-1)
    if h_node is not None and self.option in {'node', 'both'}:
        cat = self.gru_node(cat, h_node)

    x_next = x + self.out(cat)

    if self.option == 'node':
        msgs = None
    if self.option == 'edge':
        cat = None
    return x_next, cat, msgs
print(x.expand(-1, 4)) # -1 means not changing the size of that dimension """re-shape""" x = torch.randn(3, 2, 1) print(x.shape) print(torch.movedim(x, 1, 0).shape) print('unflatten', x.unflatten(1, (1, 2)).shape) # torch.Size([3, 1, 2, 1]), 第2维变成 1*2 print(t.view(16).shape) # torch.Size([16]) print(t.view(-1, 8).shape) # torch.Size([2, 8]), -1表示该维度计算来 print(torch.flatten(t)) # print(torch.ravel(t)) print(t) print(torch.narrow(t, 0, 0, 2)) # 按行,取前两行 print(torch.narrow(t, 1, 0, 2)) # 按列,取前两列 print(torch.select(t, 0, 1)) # 按行,取第两行 print(torch.select(t, 1, 1)) # 按列,取第两列 print(torch.index_select(t, 0, torch.tensor([0, 2]))) # 按行取 print(torch.index_select(t, 1, torch.tensor([0, 2]))) # 按列取 print('unbind') print(torch.unbind(t, dim=0)) # 按行拆成tuple print(torch.unbind(t, dim=1)) # 按列拆成tuple print(torch.split(t, 2, dim=0)) # 按行拆, 2行一组 print(torch.chunk(t, 2, dim=0)) # 按行拆, 2行一组 # print(torch.tensor_split(torch.arange(8), 3)) """转置""" print(torch.t(t)) # 前两维转置 print(t.T) # 转置, x.permute(n-1, n-2, ..., 0) x = torch.randn(2, 3, 5) print(x.size()) # torch.Size([2, 3, 5])
def forward(self, x_context, y_context, x_target, y_target=None):
    """
    Forward pass on the context and target set

    Arguments:
        x_context {torch.Tensor} -- Shape (batch_size, num_context, x_dim)
        y_context {torch.Tensor} -- Shape (batch_size, num_context, y_dim)
        x_target {torch.Tensor} -- Shape (batch_size, num_target, x_dim).
            Assumes this is a superset of x_context.

    Keyword Arguments:
        y_target {torch.Tensor} -- Shape (batch_size, num_target, y_dim).
            Assumes this is a superset of y_context. (default: {None})

    Returns:
        (y_pred_mu, y_pred_sigma, loss, None, None) --
        - y_pred_mu, y_pred_sigma: mean and sigma of the predictions at x_target.
        - loss: negative mean Normal log-likelihood of y_target, or None when
          y_target is not given.
        - two trailing None placeholders.
    """
    num_batches, num_context, y_dim = y_context.shape
    _, num_targets, x_dim = x_target.shape

    # For each x dimension, build a uniformly spaced 1-D grid with density
    # self.unit_density. The grid covers the targets' range, padded by 10% of
    # that range on each side and rounded outward to integers.
    t_grid_i = []
    for i in range(x_dim):
        x_i = torch.select(x_target, dim=-1, index=i)
        x_min_i = torch.min(x_i)
        x_max_i = torch.max(x_i)
        x_range_i = x_max_i - x_min_i
        x_min_i = torch.floor(x_min_i - 0.1 * x_range_i)
        x_max_i = torch.ceil(x_max_i + 0.1 * x_range_i)
        # linspace defaults to the CPU and the global default dtype. Build the
        # grid where the data lives so the kernel evaluations below do not mix
        # devices/dtypes.
        t_grid_i.append(
            torch.linspace(x_min_i.item(), x_max_i.item(),
                           int(self.unit_density * (x_max_i - x_min_i)),
                           device=x_target.device, dtype=x_target.dtype))

    # NOTE(review): newer PyTorch warns unless `indexing=` is passed; the
    # current default is 'ij'.
    t_grid = torch.meshgrid(*t_grid_i)
    t_grid = torch.stack(t_grid, dim=-1)
    t_grid_shape = t_grid.shape
    t_grid = t_grid.view(-1, x_dim)
    # Expand the t_grid to match the number of batches
    t_grid = t_grid.unsqueeze(0).repeat_interleave(num_batches, dim=0)

    # Evaluate the representation function at each grid location. With
    # keep_density_channel=True, h_grid carries y_dim + 1 channels (see the
    # view below).
    h_grid = kernel_interpolate(y_context, x_context, t_grid, self.kernel_x, keep_density_channel=True)

    # The decoder's conv ops expect channels before the grid dims, while the
    # kernel output has them last; hence the transpose.
    y_mu_grid, y_sigma_grid = self.rho_cnn(
        h_grid.transpose(1, 2).view(num_batches, y_dim + 1, *list(t_grid_shape)[:-1]))

    y_mu_grid = y_mu_grid.view(num_batches, 1, -1).transpose(1, 2)
    y_sigma_grid = y_sigma_grid.view(num_batches, 1, -1).transpose(1, 2)
    y_grid = torch.cat((y_mu_grid, y_sigma_grid), dim=-1)

    y_pred_target = kernel_interpolate(y_grid, t_grid, x_target, self.kernel_rho, keep_density_channel=False)
    y_pred_target_mu, y_pred_target_sigma = torch.chunk(y_pred_target, 2, dim=-1)

    # If we have a y_target, then return the loss. Else do not.
    if y_target is not None:
        loss = -Normal(y_pred_target_mu, y_pred_target_sigma).log_prob(y_target).mean()
    else:
        loss = None

    return y_pred_target_mu, y_pred_target_sigma, loss, None, None
def rgb_to_hls(image: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Convert an RGB image to HLS.

    .. image:: _static/img/rgb_to_hls.png

    The image data is assumed to be in the range of (0, 1). The output hue
    channel is in radians, [0, 2*pi].

    Outside TorchScript this reads the function attribute
    ``rgb_to_hls.RGB2HSL_IDX`` and overwrites it with a copy matched to
    ``image``'s device and dtype. The attribute must therefore be set at module
    level. It is presumably the same 3x1x1 channel indices ``[0, 1, 2]`` used
    in the scripted branch (confirm).

    When ``image.requires_grad`` is set, the result is built out-of-place and
    stacked. Otherwise it is written in place into views of a preallocated
    output tensor.

    NOTE: this method cannot be compiled with JIT in PyTorch < 1.7.0

    Args:
        image: RGB image to be converted to HLS with shape :math:`(*, 3, H, W)`.
        eps: epsilon value to avoid div by zero.

    Returns:
        HLS version of the image with shape :math:`(*, 3, H, W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_hls(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    if not torch.jit.is_scripting():
        # weird way to use globals compiling with JIT even in the code not used by JIT...
        # __setattr__ can be removed if pytorch version is > 1.6.0 and then use:
        # rgb_to_hls.RGB2HSL_IDX = rgb_to_hls.RGB2HSL_IDX.to(image.device)
        rgb_to_hls.__setattr__(
            'RGB2HSL_IDX', rgb_to_hls.RGB2HSL_IDX.to(image))  # type: ignore
        _RGB2HSL_IDX: torch.Tensor = rgb_to_hls.RGB2HSL_IDX  # type: ignore
    else:
        _RGB2HSL_IDX = torch.tensor([[[0.0]], [[1.0]], [[2.0]]], device=image.device, dtype=image.dtype)  # 3x1x1

    # Per-pixel max over the channel axis, plus which channel held it
    # (0 = R, 1 = G, 2 = B); and the per-pixel min.
    # maxc: torch.Tensor  # not supported by JIT
    # imax: torch.Tensor  # not supported by JIT
    maxc, imax = image.max(-3)
    minc: torch.Tensor = image.min(-3)[0]

    # h: torch.Tensor  # not supported by JIT
    # l: torch.Tensor  # not supported by JIT
    # s: torch.Tensor  # not supported by JIT
    # image_hls: torch.Tensor  # not supported by JIT
    if image.requires_grad:
        # autograd path: plain out-of-place tensors, stacked at the end
        l_ = maxc + minc
        s = maxc - minc
        # weird behaviour with undefined vars in JIT...
        # scripting requires image_hls be defined even if it is not used :S
        h = l_  # assign to any tensor...
        image_hls = l_  # assign to any tensor...
    else:
        # define the resulting image to avoid the torch.stack([h, l, s])
        # so, h, l and s require inplace operations
        # NOTE: stack() increases in a 10% the cost in colab
        image_hls = torch.empty_like(image)
        h = torch.select(image_hls, -3, 0)
        l_ = torch.select(image_hls, -3, 1)
        s = torch.select(image_hls, -3, 2)
        torch.add(maxc, minc, out=l_)  # l = max + min
        torch.sub(maxc, minc, out=s)  # s = max - min

    # precompute image / (max - min); must happen before s is rescaled below
    im: torch.Tensor = image / (s + eps).unsqueeze(-3)

    # epsilon cannot be inside the torch.where to avoid precision issues.
    # l_ still holds max + min here, so `l_ < 1.0` is the "lightness < 0.5" case.
    s /= torch.where(l_ < 1.0, l_, 2.0 - l_) + eps  # saturation
    l_ /= 2  # luminance

    # note that r,g and b were previously div by (max - min)
    r: torch.Tensor = torch.select(im, -3, 0)
    g: torch.Tensor = torch.select(im, -3, 1)
    b: torch.Tensor = torch.select(im, -3, 2)
    # h[imax == 0] = (((g - b) / (max - min)) % 6)[imax == 0]
    # h[imax == 1] = (((b - r) / (max - min)) + 2)[imax == 1]
    # h[imax == 2] = (((r - g) / (max - min)) + 4)[imax == 2]
    # cond[..., c, :, :] is True where channel c was the maximum
    cond: torch.Tensor = imax.unsqueeze(-3) == _RGB2HSL_IDX
    if image.requires_grad:
        h = torch.mul((g - b) % 6, torch.select(cond, -3, 0))
    else:
        torch.mul((g - b).remainder(6), torch.select(cond, -3, 0), out=h)
    h += torch.add(b - r, 2) * torch.select(cond, -3, 1)
    h += torch.add(r - g, 4) * torch.select(cond, -3, 2)
    # h = 2.0 * math.pi * (60.0 * h) / 360.0
    h *= math.pi / 3.0  # hue [0, 2*pi]

    if image.requires_grad:
        return torch.stack([h, l_, s], dim=-3)
    # h, l_ and s are views into image_hls, which now holds the result
    return image_hls
def forward(self, x):
    """Combine slices of a fixed 2x3 constant with the input.

    Returns ``(rows[:-1] + x, col1 + col0)``:
    - ``rows[:-1]`` is every row of the constant except the last;
    - ``col1`` and ``col0`` are columns 1 (picked as index -2 on the last dim)
      and 0 of the constant.
    """
    base = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
    head = base[:-1]  # index relative to the end
    second_col = base.select(-1, -2)
    first_col = base.select(1, 0)
    return head + x, second_col + first_col
from model.base import MLP, CirConv1d
def run_DENet_COCO_2way(config):
    """Run DENet on random 2-way COCO test episodes and save inputs and predictions as images.

    Reads ``config['inference']`` keys ``root``, ``fold``, ``id``, ``ckpt`` and
    ``save_size``, plus ``config['model']['arch']``. It draws ``save_size``
    random episodes from the test split; the same index can be drawn more than
    once. For each episode it writes these images into ``results/<arch>/<id>``:
    - the two support images and the query image;
    - the two support masks and the full query mask;
    - the two binary predictions and the full prediction.
    """
    ROOT = config['inference']['root']
    roots_path = join_path(ROOT, 'COCO')
    fold = config['inference']['fold']
    result_dir = join_path('results', config['model']['arch'], config['inference']['id'])
    # exist_ok replaces the racy exists()-then-makedirs() check.
    os.makedirs(result_dir, exist_ok=True)

    CHECKPOINT_DIR = config['inference']['ckpt']
    model = get_model(config, checkpoint_dir=CHECKPOINT_DIR)
    # NOTE(review): img_size and transforms are not defined in this function;
    # they must exist at module level. Confirm.
    dataset = WrappedCOCOStuff20i(root=roots_path, fold=fold, split='test', way=2,
                                  img_size=img_size, transforms=transforms)

    save_size = config['inference']['save_size']
    cnt = 0
    while cnt < save_size:
        rand_idx = random.randint(0, len(dataset) - 1)
        print('Getting output of DENet from rand_idx: `%d`' % rand_idx)
        Is, Ys, Iq, Yq, sample_class, Ys_full, Yq_full = dataset[rand_idx]
        cls_1 = sample_class[0]
        cls_2 = sample_class[1]
        # NOTE(review): Iq and Yq_full are not passed through tensor() like the
        # others. Presumably the dataset already returns them as tensors; confirm.
        Is, Ys, Yq, sample_class, Ys_full = tensor(Is), tensor(Ys), tensor(
            Yq), tensor(sample_class), tensor(Ys_full)
        Is, Ys, Iq, Yq, sample_class, Ys_full, Yq_full = \
            unsqueeze_zero_dim(Is, Ys, Iq, Yq, sample_class, Ys_full, Yq_full)
        if torch.cuda.is_available():
            Is, Ys, Iq, Yq, sample_class, Ys_full, Yq_full = to_cuda(
                Is, Ys, Iq, Yq, sample_class, Ys_full, Yq_full)

        # Inference only: no autograd graph is needed for the saved outputs.
        with torch.no_grad():
            Yq_full_pre, Yq_pre = model(Is, Ys, Iq, sample_class)

        # Split the 2-way episode into its two classes.
        Is_1 = torch.select(Is, dim=1, index=0)
        Is_2 = torch.select(Is, dim=1, index=1)
        Ys_1 = torch.select(Ys, dim=1, index=0)
        Ys_2 = torch.select(Ys, dim=1, index=1)
        Yq_1 = torch.select(Yq, dim=1, index=0)
        Yq_2 = torch.select(Yq, dim=1, index=1)
        Yq_pre_1 = torch.select(Yq_pre, dim=0, index=0).unsqueeze(0)
        Yq_pre_2 = torch.select(Yq_pre, dim=0, index=1).unsqueeze(0)

        cnt += 1
        print('cnt / save_size: ', cnt, ' / ', save_size)
        Yq_pre_1, Yq_pre_2 = upsample(Yq_pre_1, Yq_1), upsample(Yq_pre_2, Yq_2)
        Yq_full_pre = upsample(Yq_full_pre, Yq_full)

        print("sample cls: ", cls_1, " ", cls_2)
        print('Saving result in: ', result_dir)
        imsave(
            join_path(result_dir, 'coco_%d_Is_1_2way_cls%d.png' % (rand_idx, cls_1)),
            query_rgb(Is_1.squeeze()))
        imsave(
            join_path(result_dir, 'coco_%d_Is_2_2way_cls%d.png' % (rand_idx, cls_2)),
            query_rgb(Is_2.squeeze()))
        imsave(join_path(result_dir, 'coco_%d_Iq_2way.png' % rand_idx), query_rgb(Iq))
        imsave(join_path(result_dir, 'coco_%d_Ys_1_2way.png' % rand_idx),
               mask_gray(Ys_1.squeeze(), pre=False))
        imsave(join_path(result_dir, 'coco_%d_Ys_2_2way.png' % rand_idx),
               mask_gray(Ys_2.squeeze(), pre=False))
        imsave(join_path(result_dir, 'coco_%d_Yq_full.png' % rand_idx),
               mask_color(Yq_full, 21, pre=False))
        imsave(join_path(result_dir, 'coco_%d_Yq_pre_1_2way.png' % rand_idx),
               mask_gray(Yq_pre_1.squeeze(), pre=True))
        imsave(join_path(result_dir, 'coco_%d_Yq_pre_2_2way.png' % rand_idx),
               mask_gray(Yq_pre_2.squeeze(), pre=True))
        imsave(join_path(result_dir, 'coco_%d_Yq_full_pre.png' % rand_idx),
               mask_color(Yq_full_pre.squeeze(), 21, pre=True))
def test(self):
    """Evaluate ``self.model`` on ``self.test_loader`` and write results to disk.

    Per batch it:
    - updates the F1/precision/recall/accuracy meters;
    - records the predictions and column 1 of the model output (for a P-R curve);
    - keeps one decoded sample at a fixed batch position for printing.

    Afterwards it:
    - prints the kept samples and the final metrics;
    - writes the predictions via ``write_predictions``;
    - pickles the ``precision_recall_curve`` output to ``<record_path>/pr.values``;
    - if ``arg.image`` is set, plots the curve and saves ``<record_path>/P-R.png``.
    """
    # Batch position of the sample shown for each batch (clamped per batch below).
    t_idx = random.randint(0, 5)
    with torch.no_grad():
        self.fixed_randomness()  # for reproduction
        self.model.eval()

        # for writing the total predictions to disk
        data_idxs = list()
        all_preds = list()
        # for plotting the P-R curve
        predicts = list()
        truths = list()
        # for showing predicted samples
        show_ctxs = list()
        pred_lbls = list()
        targets = list()

        f1_meter = AverageMeter()
        p_meter = AverageMeter()
        r_meter = AverageMeter()
        accuracy_meter = AverageMeter()

        test_generator = tqdm(enumerate(self.test_loader, 1))
        for idx, data in test_generator:
            input_ids, label, _, mask, data_idx = data
            if self.gpu:
                input_ids, mask, label = self.to_cuda(input_ids, mask, label)
            pre = self.model((input_ids, mask))

            lbl = label.cpu().numpy()
            yp = pre.argmax(1).cpu().numpy()
            self.update_metrics(lbl, yp, f1_meter, p_meter, r_meter, accuracy_meter)

            test_generator.set_description(
                "Test %d/%d, f1 %.4f, p %.4f, r %.4f, acc %.4f"
                % (idx, len(self.test_loader), f1_meter.avg, p_meter.avg,
                   r_meter.avg, accuracy_meter.avg)
            )

            data_idxs.append(data_idx.numpy())
            all_preds.append(yp)
            predicts.append(torch.select(pre, dim=1, index=1).cpu().numpy())
            truths.append(lbl)

            # Show one sample per batch. randint's upper bound is inclusive and
            # the last batch may be short, so clamp instead of raising IndexError.
            show_idx = min(t_idx, input_ids.size(0) - 1)
            ctx = torch.select(input_ids, dim=0, index=show_idx).detach()
            ctx = self.model.tokenizer.convert_ids_to_tokens(ctx)
            ctx = "".join(tok for tok in ctx if tok not in [PAD, CLS])
            show_ctxs.append(ctx)
            pred_lbls.append(yp[show_idx])
            targets.append(lbl[show_idx])

        print("*" * 25, " SAMPLE BEGINS ", "*" * 25)
        for c, t, est in zip(show_ctxs, targets, pred_lbls):
            print("ctx: ", c, " gt: ", t, " est: ", est)
        print("*" * 25, " SAMPLE ENDS ", "*" * 25)
        print("Test, FINAL f1 %.4f, "
              "p %.4f, r %.4f, acc %.4f\n" % (f1_meter.avg, p_meter.avg, r_meter.avg, accuracy_meter.avg))

        # output the final results to disk
        data_idxs = np.concatenate(data_idxs, axis=0)
        all_preds = np.concatenate(all_preds, axis=0)
        write_predictions(
            self.val_path, os.path.join(self.record_path, "results.txt"), data_idxs, all_preds,
            delimiter=self.delimiter, skip_first=self.skip_first
        )

        # output the p-r values for future plotting P-R Curve
        predicts = np.concatenate(predicts, axis=0)
        truths = np.concatenate(truths, axis=0)
        values = precision_recall_curve(truths, predicts)
        with open(os.path.join(self.record_path, "pr.values"), "wb") as f:
            pickle.dump(values, f)
        p_value, r_value, _ = values

        # plot P-R Curve if specified
        if arg.image:
            plt.figure()
            plt.plot(
                p_value, r_value,
                label="%s (ACC: %.2f, F1: %.2f)" % (self.model_name, accuracy_meter.avg, f1_meter.avg)
            )
            plt.legend(loc="best")
            plt.title("2-Classes P-R curve")
            plt.xlabel("precision")
            plt.ylabel("recall")
            plt.savefig(os.path.join(self.record_path, "P-R.png"))
            plt.show()
def forward(self, p, Theta, Tx, Rx, num=1):
    """Closed-form power update followed by a per-sample phase update via ``self.cvx_opt``.

    Shapes (from the original notes; the trailing axis of size 2 is (real, imag)):
        Tx: (train_sample_num, L, K, Ml, N, 2)   transmitter -> IRS
        Rx: (train_sample_num, K, L, M, Ml, 2)   IRS -> receiver
        Theta: viewable as (train_sample_num, L*Ml, 2)
    Here K = users, L = number of IRS, Ml = elements per IRS, M = receive
    antennas and N = transmit antennas. ``self.Ksize`` unpacks as
    (K, L, Ml, M, N).

    Args:
        p: transmit amplitudes, squared to give powers; viewable as
            (train_sample_num, -1, K*N).
        Theta, Tx, Rx: see above.
        num: unused; kept for interface compatibility.

    Returns:
        (p_out, theta_out):
        - ``p_out`` is the updated amplitude, capped at ``self.P_max`` and
          detached from the graph.
        - ``theta_out`` has shape (train_sample_num, L, Ml, 2), lives on
          ``Tx.device``, and holds the per-sample results of ``self.cvx_opt``.
    """
    device = Tx.device
    train_sample_num = Tx.shape[0]
    K, L, Ml, M, N = self.Ksize

    # Flatten to matrices: Tx -> (n, L*Ml, K*N, 2), Rx -> (n, K*M, L*Ml, 2).
    Tx_matrix = Tx.permute(0, 1, 3, 2, 4, 5).contiguous().view(train_sample_num, L * Ml, K * N, 2)
    Rx_matrix = Rx.permute(0, 1, 3, 2, 4, 5).contiguous().view(train_sample_num, K * M, L * Ml, 2)
    Theta_matrix = Theta.view(train_sample_num, L * Ml, 2)

    Tx_real, Tx_imag = torch.select(Tx_matrix, -1, 0), torch.select(Tx_matrix, -1, 1)
    Rx_real, Rx_imag = torch.select(Rx_matrix, -1, 0), torch.select(Rx_matrix, -1, 1)
    Theta_real = torch.select(Theta_matrix, -1, 0).diag_embed()
    Theta_imag = torch.select(Theta_matrix, -1, 1).diag_embed()

    # Effective channel h = Rx @ diag(Theta) @ Tx, in real/imag arithmetic.
    Rx_Theta_real = torch.matmul(Rx_real, Theta_real) - torch.matmul(Rx_imag, Theta_imag)
    Rx_Theta_imag = torch.matmul(Rx_real, Theta_imag) + torch.matmul(Rx_imag, Theta_real)
    h_real = torch.matmul(Rx_Theta_real, Tx_real) - torch.matmul(Rx_Theta_imag, Tx_imag)
    h_imag = torch.matmul(Rx_Theta_real, Tx_imag) + torch.matmul(Rx_Theta_imag, Tx_real)
    h_square = h_real ** 2 + h_imag ** 2

    # mu = diagonal gain / (off-diagonal gain + sigma2), i.e. the per-stream SINR.
    h_gain = (p ** 2).view(train_sample_num, -1, K * N) * h_square
    numerator = h_gain.diagonal(dim1=1, dim2=2)
    denominator = (h_gain - numerator.diag_embed()).sum(dim=2) + self.sigma2
    mu = numerator / denominator

    numerator = torch.sqrt((1 + mu) * h_gain.diagonal(dim1=1, dim2=2))
    denominator = h_gain.sum(dim=2) + self.sigma2
    alpha = numerator / denominator / math.sqrt(2)

    # Closed-form amplitude, capped at P_max.
    numerator = (1 + mu) * alpha ** 2 * h_square.diagonal(dim1=1, dim2=2)
    denominator = (((alpha ** 2).view(train_sample_num, K * N, -1) * h_square).sum(dim=1)) ** 2
    p_out = torch.min(torch.ones_like(p) * self.P_max, torch.sqrt(numerator / (2 * denominator)))
    p_out = p_out.detach()

    numerator_real = 1 / math.sqrt(2) * torch.sqrt(1 + mu) * p_out * h_real.diagonal(dim1=1, dim2=2)
    numerator_imag = 1 / math.sqrt(2) * torch.sqrt(1 + mu) * p_out * h_imag.diagonal(dim1=1, dim2=2)
    h_gain = (p_out ** 2).view(train_sample_num, -1, K * N) * h_square
    denominator = h_gain.sum(dim=2) + self.sigma2
    beta_real = numerator_real / denominator
    beta_imag = numerator_imag / denominator

    theta_out = torch.zeros(train_sample_num, L, Ml, 2, device=device)

    # Copy what cvx_opt needs to host memory ONCE; the old code re-copied every
    # full tensor for every sample. detach() is needed because .numpy()
    # rejects tensors that require grad (mu and beta do whenever p does).
    def _host(t):
        return t.detach().cpu().numpy()

    p_np = _host(p_out)
    Tx_real_np, Tx_imag_np = _host(Tx_real), _host(Tx_imag)
    Rx_real_np, Rx_imag_np = _host(Rx_real), _host(Rx_imag)
    beta_real_np, beta_imag_np = _host(beta_real), _host(beta_imag)
    mu_np = _host(mu)

    for sample in range(train_sample_num):
        if sample % 100 == 0:
            print('%5d' % sample, end='')
        theta_out[sample] = self.cvx_opt(cp_p=p_np[sample] ** 2,
                                         cp_Tx_real=Tx_real_np[sample],
                                         cp_Tx_imag=Tx_imag_np[sample],
                                         cp_Rx_real=Rx_real_np[sample],
                                         cp_Rx_imag=Rx_imag_np[sample],
                                         cp_beta_real=beta_real_np[sample],
                                         cp_beta_imag=beta_imag_np[sample],
                                         cp_mu=mu_np[sample])
    print()
    return p_out, theta_out
def make_registration_image_summary(source_image,
                                    target_image,
                                    warped_source_image,
                                    disp_field,
                                    deform_field,
                                    source_seg=None,
                                    target_seg=None,
                                    warped_source_seg=None,
                                    n_slices=1,
                                    n_samples=1):
    """
    make image summary for tensorboard

    For each of the first ``n_samples`` batch elements, the middle slice along
    each spatial axis (D, H, W) is taken, giving HW, DW and DH slices.

    - image/seg grids: each row holds source, warped source, target for one
      slice; rows go HW, DW, DH per sample.
    - disp-field and deform-grid grids: one slice per row, same HW, DW, DH order.

    :param source_image: torch.tensor, NxCxDxHxW, 3D image volume (C:channels)
    :param target_image: torch.tensor, NxCxDxHxW, 3D image volume (C:channels)
    :param warped_source_image: torch.tensor, NxCxDxHxW, 3D image volume (C:channels)
    :param disp_field: torch.tensor, Nx3xDxHxW, 3D image volume, =deform_field -identity_transform
    :param deform_field: torch.tensor, Nx3xDxHxW, 3D image volume normalized in range [-1,1]
    :param source_seg: optional; indexed as [n, :, :, :] (presumably NxDxHxW label
        volume -- confirm). Segmentations are drawn only when all three segs are given.
    :param target_seg: optional; same layout as source_seg
    :param warped_source_seg: optional; same layout as source_seg
    :param n_slices: int, currently unused (only the middle slice is taken)
    :param n_samples: int, number of samples in a batch used for the summary
    :return: dict with keys 'images', 'disp_field', 'deform_grid', and 'masks'
        when segmentations are given
    """
    n_samples = min(n_samples, source_image.size()[0])
    grids = {}
    image_slices = []
    disp_slices = []
    seg_slices = []
    deform_grid_slices = []
    draw_segs = (source_seg is not None) and (target_seg is not None) and (
        warped_source_seg is not None)

    for n in range(n_samples):
        # axis indexes the per-sample (C, D, H, W) volume: 1 -> D, 2 -> H, 3 -> W.
        for axis in range(1, 4):
            slice_ind = source_image.size()[axis + 1] // 2  # middle slice
            source_image_slice = torch.select(source_image[n, :, :, :, :], axis, slice_ind)
            warped_source_image_slice = torch.select(
                warped_source_image[n, :, :, :, :], axis, slice_ind)
            target_image_slice = torch.select(target_image[n, :, :, :, :], axis, slice_ind)
            image_slices += [
                source_image_slice, warped_source_image_slice, target_image_slice
            ]

            disp_field_slice = torch.select(disp_field[n, :, :, :, :], axis, slice_ind)
            disp_slices += [disp_field_slice]

            deform_field_slice = torch.select(deform_field[n, :, :, :, :], axis, slice_ind)
            deform_grid_slice = torch.from_numpy(
                generate_deform_grid(deform_field_slice, axis - 1, warped_source_image_slice))
            deform_grid_slices += [deform_grid_slice]

            if draw_segs:
                # Segmentations are indexed without a channel axis, hence axis - 1.
                source_seg_slice = torch.select(source_seg[n, :, :, :], axis - 1, slice_ind)
                source_seg_slice = labels2colors(
                    source_seg_slice, images=source_image_slice.squeeze(0), overlap=True)
                target_seg_slice = torch.select(target_seg[n, :, :, :], axis - 1, slice_ind)
                target_seg_slice = labels2colors(
                    target_seg_slice, images=target_image_slice.squeeze(0), overlap=True)
                warped_source_seg_slice = torch.select(
                    warped_source_seg[n, :, :, :], axis - 1, slice_ind)
                warped_source_seg_slice = labels2colors(
                    warped_source_seg_slice,
                    images=warped_source_image_slice.squeeze(0),
                    overlap=True)
                seg_slices += [
                    source_seg_slice, warped_source_seg_slice, target_seg_slice
                ]

    # NOTE(review): newer torchvision renamed make_grid's `range=` to
    # `value_range=`; confirm against the pinned torchvision version.
    grids['images'] = vision_utils.make_grid(slices_padding(image_slices),
                                             pad_value=1,
                                             nrow=3,
                                             normalize=True,
                                             range=(0, 1))
    if seg_slices:
        grids['masks'] = vision_utils.make_grid(slices_padding(seg_slices),
                                                pad_value=1,
                                                nrow=3)
    grids['disp_field'] = vision_utils.make_grid(
        slices_padding(disp_slices),
        pad_value=1,
        nrow=1,
        normalize=True,
        range=(-0.1, 0.1))
    grids['deform_grid'] = vision_utils.make_grid(
        slices_padding(deform_grid_slices), pad_value=1, nrow=1)
    return grids
def func(z):
    """Sum of the real view of the first complex entry along the last complex dim of ``z``.

    ``z`` is reinterpreted as complex (its last axis must have size 2). Index 0
    of the resulting last dim is taken, and the real/imag view of that slice is
    summed. Numerically this equals ``z[..., 0, :].sum()``.
    """
    complex_view = torch.view_as_complex(z)
    first = complex_view.select(complex_view.dim() - 1, 0)
    return torch.view_as_real(first).sum()
def forward(self, x):
    """Return the slice of ``x`` at index ``self.pos`` along dimension ``self.dim`` (that dimension removed)."""
    dim, pos = self.dim, self.pos
    return x.select(dim, pos)
def forward(self, x):
    """Mix a fixed 2x3 constant's column 1 and its rows with the input.

    Returns ``(col1 + 1, rows + x)``, where ``col1`` is picked as index -2 on
    dim 1 and ``rows`` is both rows, gathered by index_select on dim -2.
    """
    table = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
    middle = table.select(1, -2)
    rows = table.index_select(-2, torch.tensor([0, 1]))
    return middle + 1, rows + x
def forward(self, x):
    """Return 12 slices of ``x`` as a flat tuple.

    The tuple holds:
    - indices 0..2 along dim 0;
    - then indices 0..3 along dim 1;
    - then indices 0..4 along dim 2.
    """
    along0 = tuple(x.select(0, i) for i in range(3))
    along1 = tuple(x.select(1, j) for j in range(4))
    along2 = tuple(x.select(2, k) for k in range(5))
    return along0 + along1 + along2