def check_invariance(self, r2_act, out_type, data, func, data_type=None):
    """Assert that ``func`` is invariant under every test element of ``r2_act``.

    For each group element ``g`` the input is transformed, passed through
    ``func`` and compared (``atol=1e-5``) with ``func(data)``. The output for
    each transformed input is printed before the comparison.

    Args:
        r2_act: gspace whose ``testing_elements`` are iterated.
        out_type: unused; kept for signature compatibility.
        data: GeometricTensor of shape (B, C, H, W) with one trivial field
            per channel.
        func: callable mapping a GeometricTensor to a GeometricTensor.
        data_type: unused; kept for signature compatibility.

    Raises:
        AssertionError: carrying the offending group element.
    """
    _, c, h, w = data.shape
    # One trivial field per channel of the input.
    input_type = enn.FieldType(r2_act, c * [r2_act.trivial_repr])
    # Evaluate transformed inputs on the caller's device instead of assuming CUDA.
    device = data.tensor.device
    y = func(data)
    # The group action is applied to a CPU copy of the input; that copy is the
    # same for every group element, so build it once.
    data_cpu = enn.GeometricTensor(data.tensor.view(-1, c, h, w).cpu(), input_type)
    for g in r2_act.testing_elements:
        x_transformed = enn.GeometricTensor(
            data_cpu.transform(g).tensor.view(-1, c, h, w).to(device), input_type)
        y_from_x_transformed = func(x_transformed)
        print_y = y_from_x_transformed.tensor.detach().to('cpu').numpy().squeeze()
        print("{:4d} : {}".format(g, print_y))
        # Invariance condition: f(g . x) == f(x)
        assert torch.allclose(y_from_x_transformed.tensor.squeeze(),
                              y.tensor.squeeze(), atol=1e-5), g
    print("Passed Invariance Test")
def check_equivariance(r2_act, out_type, data, func, data_type=None):
    """Assert ``func(g . x) == g . func(x)`` for every test element of ``r2_act``.

    Each sample is viewed as a single trivial field on a 1x2 grid
    (``view(-1, 1, 1, 2)``).

    Args:
        r2_act: gspace whose ``testing_elements`` are iterated.
        out_type: field type of ``func``'s output.
        data: input tensor.
        func: callable under test.
        data_type: ``'GeomTensor'`` when ``func`` consumes and returns
            GeometricTensors; otherwise ``func`` works on raw tensors.

    Raises:
        AssertionError: carrying the offending group element.
    """
    in_field = enn.FieldType(r2_act, [r2_act.trivial_repr])
    wrap_geom = data_type == 'GeomTensor'
    if wrap_geom:
        data = enn.GeometricTensor(data.view(-1, 1, 1, 2), in_field)

    for g in r2_act.testing_elements:
        out = func(data)
        out_raw = out.tensor if wrap_geom else out
        data_raw = data.tensor if wrap_geom else data

        # g acting on the output: rho_out(g) f(x)
        rg_output = enn.GeometricTensor(out_raw.view(-1, 1, 1, 2).cpu(), out_type).transform(g)
        # f acting on the transformed input: f(rho_in(g) x)
        data = enn.GeometricTensor(data_raw.view(-1, 1, 1, 2).cpu(), in_field)
        moved = data.transform(g).tensor.view(-1, 1, 1, 2)
        x_transformed = enn.GeometricTensor(moved, in_field) if wrap_geom else moved
        output_rg = func(x_transformed)

        # Normalise both sides to CPU GeometricTensors and restore `data`.
        if wrap_geom:
            output_rg = enn.GeometricTensor(output_rg.tensor.cpu(), out_type)
            data = enn.GeometricTensor(data.tensor.squeeze().view(-1, 1, 1, 2), in_field)
        else:
            output_rg = enn.GeometricTensor(output_rg.view(-1, 1, 1, 2).cpu(), out_type)
            data = data.tensor.squeeze()

        assert torch.allclose(rg_output.tensor.cpu().squeeze(),
                              output_rg.tensor.squeeze(), atol=1e-5), g
def inverse(self, z):
    """Map ``z`` back through the convolution-exponential flow.

    Blocks are visited in ascending index order. For each one, the activation
    is undone with ``self.invertible_tanh(..., inverse=True)`` and then
    ``inv_conv_exp`` is applied with the block's expanded filter.

    Args:
        z: tensor viewable as (-1, self.c, self.h, self.w).

    Returns:
        Tuple ``(x, log_det_J)``:
        - ``x``: the squeezed reconstruction.
        - ``log_det_J``: shape ``(z.shape[0],)``; accumulates
          ``log_det(kernel) * H * W`` over the blocks.
    """
    log_det_J = z.new_zeros(z.shape[0])
    x = z.view(-1, self.c, self.h, self.w)
    _, _, H, W = x.size()
    x = enn.GeometricTensor(x, self.input_type)
    for i in range(self.n_blocks):
        # named_modules() yields the module itself first, so this is simply the
        # first layer of block i; no need to enumerate all of its submodules.
        kernel = self.flow_model[i][0].expand_parameters()[0]
        x = self.invertible_tanh(x.tensor, inverse=True)
        x = inv_conv_exp(x, kernel, terms=self.n_terms)
        x = enn.GeometricTensor(x, self.input_type)
        log_det_J += log_det(kernel) * H * W
    return x.tensor.squeeze(), log_det_J
def dummy_func(self, z):
    """Apply the masked affine coupling layers to the GeometricTensor ``z``.

    Blocks are visited from the last index down to 0. Masked entries pass
    through unchanged; the remaining entries become ``(z - t) * exp(-s)``,
    where ``s`` and ``t`` are computed from the masked input. No
    log-determinant is tracked, so the function can be used as ``func`` in
    the equivariance/invariance checks.
    """
    for blk in range(self.n_blocks - 1, -1, -1):
        raw = z.tensor
        batch = raw.shape[0]
        keep = self.mask[blk].view(self.c, 1, 1).repeat(batch, 1, 1, 1)
        frozen = enn.GeometricTensor(keep * raw, self.input_type)
        scale = self.s[blk](frozen).tensor
        shift = self.t[blk](frozen).tensor
        update = (1 - self.mask[blk]).view(self.c, 1, 1).repeat(batch, 1, 1, 1)
        raw = update * (raw - shift) * torch.exp(-1. * scale) + frozen.tensor
        z = enn.GeometricTensor(raw, self.input_type)
    return z
def check_invariance(r2_act, out_type, data, func):
    """Assert ``func`` gives the same output for every rotated copy of ``data``.

    ``data`` is viewed as (-1, 1, 1, 2): each sample becomes one trivial field
    on a 1x2 grid. Outputs are compared with ``atol=1e-5``.

    Args:
        r2_act: gspace whose ``testing_elements`` are iterated.
        out_type: unused; kept for signature compatibility.
        data: tensor whose total number of elements is even.
        func: callable taking and returning a GeometricTensor.

    Raises:
        AssertionError: carrying the offending group element.
    """
    input_type = enn.FieldType(r2_act, [r2_act.trivial_repr])
    data = enn.GeometricTensor(data.view(-1, 1, 1, 2), input_type)
    # Evaluate transformed inputs on the caller's device instead of forcing CUDA.
    device = data.tensor.device
    log_prob = func(data)
    # The group action is applied on a CPU copy, which is the same for every g.
    data_cpu = enn.GeometricTensor(data.tensor.view(-1, 1, 1, 2).cpu(), input_type)
    for g in r2_act.testing_elements:
        x_transformed = enn.GeometricTensor(
            data_cpu.transform(g).tensor.to(device).view(-1, 1, 1, 2), input_type)
        invar_new_log_prob = func(x_transformed)
        # Bring both sides to the CPU so allclose compares tensors on one device.
        assert torch.allclose(log_prob.tensor.cpu().squeeze(),
                              invar_new_log_prob.tensor.cpu().squeeze(),
                              atol=1e-5), g
def forward(self, x, logpx=None):
    """Residual step ``y = x + g(x)``, optionally updating a log-density.

    Args:
        x: input of shape (B, C, H, W), either a raw tensor or a GeometricTensor.
        logpx: optional log-density of ``x``.

    Returns:
        ``x + self.nnet(x)`` when ``logpx`` is None. Otherwise
        ``(x + g, logpx - logdetgrad)``, with ``g`` and ``logdetgrad`` taken
        from ``self._logdetgrad(x)``.
    """
    _, c, h, w = x.shape
    if logpx is None:
        return x + self.nnet(x)

    g, logdetgrad = self._logdetgrad(x)
    # The fallbacks handle ``self.nnet`` being indexable itself or being a
    # wrapper that exposes an indexable ``.nnet``.
    if torch.is_tensor(g):
        try:
            g = enn.GeometricTensor(g, self.nnet[-1].out_type)
        except Exception:
            g = enn.GeometricTensor(g, self.nnet.nnet[-1].out_type)
    if torch.is_tensor(x):
        try:
            x = enn.GeometricTensor(x.view(-1, c, h, w), self.nnet[0].in_type)
        except Exception:
            x = enn.GeometricTensor(x.view(-1, c, h, w), self.nnet.nnet[0].in_type)
    return x + g, logpx - logdetgrad
def forward(self, input):
    """Encode with the equivariant blocks, then decode DCGAN-style.

    Args:
        input: GeometricTensor accepted by ``self.block1``.

    Returns:
        GeometricTensor of type ``self.input_type`` wrapping ``self.gen``'s output.
    """
    # Each layer has an input and an output type, so consecutive layers need
    # matching input/output types.
    x = self.block1(input)
    x = self.pool1(x)
    x = self.block2(x)
    x = self.pool2(x)
    x = self.block3(x)
    # pool over the spatial dimensions
    x = self.pool3(x)
    # pool over the group
    x = self.gpool(x)
    # unwrap the GeometricTensor (discard the associated representation)
    x = x.tensor
    # Drop only the two trailing singleton dims; a bare squeeze() would also
    # drop the batch dimension when the batch size is 1.
    out = self.l1(x.squeeze(-1).squeeze(-1))
    x = out.view(out.shape[0], 128, self.init_size, self.init_size)
    x = self.gen(x)
    return enn.GeometricTensor(x, self.input_type)
def forward(self, x):
    """Push ``x`` (viewed as (-1, 1, 1, 2)) through ``self.flow_model``.

    Returns:
        Tuple ``(z, delta_logp)``: the squeezed output tensor and the second
        value returned by the flow.
    """
    geo_x = enn.GeometricTensor(x.view(-1, 1, 1, 2), self.input_type)
    # Per-sample zero column passed to the flow as its initial log-density term.
    init_logp = torch.zeros(geo_x.shape[0], 1).to(self.args.dev)
    z, delta_logp = self.flow_model(geo_x, init_logp)
    return z.tensor.squeeze(), delta_logp
def check_equivariance(self, r2_act, out_type, data, func, data_type=None):
    """Assert ``func(g . x) == g . func(x)`` for every test element of ``r2_act``.

    Args:
        r2_act: gspace whose ``testing_elements`` are iterated.
        out_type: field type of ``func``'s output. The output is viewed with
            the input's (C, H, W).
        data: GeometricTensor of shape (B, C, H, W) with one trivial field
            per channel.
        func: callable mapping a GeometricTensor to a GeometricTensor.
        data_type: unused; kept for signature compatibility.

    Raises:
        AssertionError: carrying the offending group element.
    """
    _, c, h, w = data.shape
    input_type = enn.FieldType(r2_act, c * [r2_act.trivial_repr])
    # Evaluate transformed inputs on the caller's device instead of forcing CUDA.
    device = data.tensor.device
    # f(x) and the CPU copy of x do not depend on g; compute them once.
    output_cpu = enn.GeometricTensor(func(data).tensor.view(-1, c, h, w).cpu(), out_type)
    data_cpu = enn.GeometricTensor(data.tensor.view(-1, c, h, w).cpu(), input_type)
    for g in r2_act.testing_elements:
        # rho_out(g) f(x)
        rg_output = output_cpu.transform(g)
        # f(rho_in(g) x)
        x_transformed = enn.GeometricTensor(
            data_cpu.transform(g).tensor.view(-1, c, h, w).to(device), input_type)
        output_rg = enn.GeometricTensor(func(x_transformed).tensor.cpu(), out_type)
        # Equivariance condition
        assert torch.allclose(rg_output.tensor.squeeze(),
                              output_rg.tensor.squeeze(), atol=1e-5), g
    print("Passed Equivariance Test")
def forward(self, x, inverse=False):
    """Run the masked affine coupling flow, or delegate to ``self.inverse``.

    Blocks are visited from the last index down to 0. Masked entries pass
    through; the remaining entries become ``(z - t) * exp(-s)``, and the sum
    of ``s`` over those entries is subtracted from the log-determinant.

    Args:
        x: tensor viewable as (-1, self.c, self.h, self.w).
        inverse: when True, return ``self.inverse(x)`` instead.

    Returns:
        Tuple ``(z, log_det_J)``: the squeezed output and a ``(-1, 1)`` view of
        the accumulated log-determinant.
    """
    if inverse:
        return self.inverse(x)

    log_det_J = x.new_zeros(x.shape[0])
    z = x.view(-1, self.c, self.h, self.w)
    for blk in reversed(range(self.n_blocks)):
        n = z.shape[0]
        keep = self.mask[blk].view(self.c, 1, 1).repeat(n, 1, 1, 1)
        frozen = enn.GeometricTensor(keep * z, self.input_type)
        scale = self.s[blk](frozen).tensor
        shift = self.t[blk](frozen).tensor
        # Complement of the mask: the entries that are actually transformed.
        update = (1 - self.mask[blk]).view(self.c, 1, 1).repeat(n, 1, 1, 1)
        z = update * (z - shift) * torch.exp(-1. * scale) + frozen.tensor
        log_det_J -= (update * scale).sum(dim=(1, 2, 3))
    return z.squeeze(), log_det_J.view(-1, 1)
def forward(self, x, logpx=None):
    """Equivariant ActNorm: ``y = (x + bias) * exp(weight)``.

    While ``self.initialized`` is false, bias and weight are first set from
    the per-channel statistics of this batch: ``bias = -mean`` and
    ``weight = -0.5 * log(max(var, 0.2))``.

    Args:
        x: GeometricTensor of shape (B, C, H, W).
        logpx: optional log-density of ``x``.

    Returns:
        ``y`` (same field type as ``x``), or ``(y, logpx - self._logdetgrad(x))``.
    """
    field_type = x.type
    _, _, h, w = x.shape
    raw = x.tensor
    n_ch = raw.size(1)
    if not self.initialized:
        with torch.no_grad():
            # Put channels first and flatten everything else.
            per_channel = raw.transpose(0, 1).contiguous().view(n_ch, -1)
            mean = torch.mean(per_channel, dim=1)
            var = torch.var(per_channel, dim=1)
            # Floor the variance for numerical stability.
            var = torch.max(var, torch.tensor(0.2).to(var))
            self.bias.data.copy_(-mean)
            self.weight.data.copy_(-0.5 * torch.log(var))
            self.initialized.fill_(1)
    shift = self.bias.view(*self.shape).expand_as(raw)
    log_scale = self.weight.view(*self.shape).expand_as(raw)
    y = enn.GeometricTensor(((raw + shift) * torch.exp(log_scale)).view(-1, n_ch, h, w),
                            field_type)
    if logpx is None:
        return y
    return y, logpx - self._logdetgrad(raw)
def forward(self, x_in, logpx=None, inverse=False, classify=False):
    """Run the multi-scale flow on ``x_in``, or delegate to ``self.inverse``.

    Args:
        x_in: tensor viewable as (-1, self.c, self.h, self.w).
        logpx: optional log-density; when given, each transform updates it.
        inverse: when True, return ``self.inverse(x, logpx)``.
        classify: when True, also return logits from the classification heads.

    Returns:
        ``out``: the per-sample flattened concatenation of every factored-out
        part plus the final output. Returned as ``(out, logpx)`` when
        ``logpx`` is given, and as ``(output, logits)`` when ``classify`` is True.
    """
    x = enn.GeometricTensor(x_in.view(-1, self.c, self.h, self.w), self.input_type)
    if inverse:
        return self.inverse(x, logpx)

    out = []
    if classify:
        class_outs = []
    last_idx = len(self.transforms) - 1
    for idx, transform in enumerate(self.transforms):
        if logpx is not None:
            x, logpx = transform.forward(x, logpx)
        else:
            x = transform.forward(x)
        if self.factor_out and idx < last_idx:
            # Factor out the second half of the channels.
            d = x.shape[1] // 2
            x, f = x[:, :d], x[:, d:]
            out.append(f)
        if classify:
            # NOTE(review): on the last transform (not factored out) this reuses
            # `f` from the previous iteration when factor_out is set — confirm.
            head_in = f if self.factor_out else x
            class_outs.append(self.classification_heads[idx](head_in).tensor)

    # Keep the batch dimension: squeeze() would drop it for a batch of one.
    out.append(x.tensor)
    # Factored-out parts may still be GeometricTensors; unwrap before flattening.
    flat = [o.tensor if isinstance(o, enn.GeometricTensor) else o for o in out]
    out = torch.cat([o.reshape(o.size(0), -1) for o in flat], 1)
    output = out if logpx is None else (out, logpx)
    if classify:
        h = torch.cat(class_outs, dim=1).squeeze(-1).squeeze(-1)
        logits = self.logit_layer(h)
        return output, logits
    return output
def forward(self, input: torch.Tensor):
    """Classify ``input`` with two equivariant blocks and ``self.fully_net``.

    Args:
        input: tensor whose channels match ``self.input_type``.

    Returns:
        Output of ``self.fully_net``, one row per sample.
    """
    # Wrap the input tensor in a GeometricTensor (associate it with the input type).
    x = nn_e2.GeometricTensor(input, self.input_type)
    # Each layer takes a GeometricTensor of its input type and returns one of
    # its output type, so consecutive layers must have matching types.
    x = self.block1(x)
    x = self.pool1(x)
    x = self.block2(x)
    x = self.pool2(x)
    # Unwrap the GeometricTensor (discard the associated representation).
    x = x.tensor
    # Flatten per sample (24 * 13 * 13 features expected). Keeping the batch
    # size fixed means a size mismatch cannot silently regroup samples into
    # extra or fewer rows.
    x = x.view(x.size(0), -1)
    return self.fully_net(x)
def forward(self, x):
    """Equivariant wide-ResNet forward pass.

    Runs the stem, the three stages (with field restrictions between them),
    batch-norm and ReLU. Then average-pools over the whole spatial extent and
    applies ``self.linear``.
    """
    feats = self.conv1(enn.GeometricTensor(x, self.in_type))
    feats = self.layer1(feats)
    feats = self.layer2(self.restrict1(feats))
    feats = self.layer3(self.restrict2(feats))
    feats = self.relu(self.bn(feats))
    # Leave the GeometricTensor world to use plain PyTorch ops.
    raw = feats.tensor
    _, _, rows, cols = raw.shape
    pooled = F.avg_pool2d(raw, (rows, cols))
    return self.linear(pooled.view(pooled.size(0), -1))
def forward(self, *args):
    """Call ``self._update_u_v()``, then run the wrapped module on ``args[0]``.

    Only the first positional argument is forwarded. A raw tensor is first
    wrapped as a GeometricTensor of ``self.module.in_type``; anything else is
    passed through unchanged.
    """
    self._update_u_v()
    first = args[0]
    inp = enn.GeometricTensor(first, self.module.in_type) if torch.is_tensor(first) else first
    return self.module.forward(inp)
def forward(self, x, inverse=False):
    """Run the convolution-exponential flow, or delegate to ``self.inverse``.

    Blocks are visited from the last index down to 0. Each one applies
    ``conv_exp`` with the block's expanded filter, followed by
    ``self.invertible_tanh``.

    Args:
        x: tensor viewable as (-1, self.c, self.h, self.w).
        inverse: when True, return ``self.inverse(x)`` instead.

    Returns:
        Tuple ``(z, log_det_J)``:
        - ``z``: the squeezed output.
        - ``log_det_J``: a ``(-1, 1)`` view of the accumulated
          ``-log_det(kernel) * H * W`` terms.
    """
    if inverse:
        return self.inverse(x)

    log_det_J = x.new_zeros(x.shape[0])
    z = x.view(-1, self.c, self.h, self.w)
    _, _, H, W = z.size()
    z = enn.GeometricTensor(z, self.input_type)
    for i in reversed(range(self.n_blocks)):
        # named_modules() yields the module itself first, so this is simply the
        # first layer of block i; no need to enumerate all of its submodules.
        kernel = self.flow_model[i][0].expand_parameters()[0]
        z = conv_exp(z.tensor, kernel, terms=self.n_terms)
        z = self.invertible_tanh(z)
        z = enn.GeometricTensor(z, self.input_type)
        log_det_J -= log_det(kernel) * H * W
    return z.tensor.squeeze(), log_det_J.view(-1, 1)
def forward(self, x):
    """U-Net-style pass: equivariant encoder, decoder with skip concatenations.

    Returns:
        The raw tensor of ``self.output_layer``'s output.
    """
    inp = nn.GeometricTensor(x, self.feat_type_in)
    # Encoder
    enc1 = self.conv1(inp)
    enc2 = self.conv2_1(self.conv2(enc1))
    enc3 = self.conv3_1(self.conv3(enc2))
    enc4 = self.conv4_1(self.conv4(enc3))
    # Decoder; skip connections are concatenated along the channel dim.
    # NOTE(review): the deconv layers receive raw tensors, yet their outputs are
    # read via `.tensor` — confirm their input/output interface.
    dec3 = self.deconv3(enc4.tensor)
    dec2 = self.deconv2(torch.cat((enc3.tensor, dec3.tensor), 1))
    dec1 = self.deconv1(torch.cat((enc2.tensor, dec2.tensor), 1))
    merged = nn.GeometricTensor(torch.cat((inp.tensor, dec1.tensor), 1),
                                self.feat_type_hid_out)
    return self.output_layer(merged).tensor
def inverse(self, y, logpy=None):
    """Inverse logit transform: ``x = (sigmoid(y) - alpha) / (1 - 2 * alpha)``.

    Args:
        y: GeometricTensor of shape (B, C, H, W).
        logpy: optional log-density of ``y``.

    Returns:
        ``x`` (GeometricTensor with ``y``'s field type), or
        ``(x, logpy + logdet)`` where ``logdet`` is ``self._logdetgrad(x)``
        summed per sample (shape (B, 1)).
    """
    in_type = y.type
    _, c, h, w = y.shape
    x = (torch.sigmoid(y.tensor) - self.alpha) / (1 - 2 * self.alpha)
    x = enn.GeometricTensor(x.view(-1, c, h, w), in_type)
    if logpy is None:
        return x
    # Take the batch size from the raw tensor. NOTE(review): GeometricTensor.size
    # may not accept a dimension argument in the e2cnn version in use.
    raw = x.tensor
    logdet = self._logdetgrad(raw).view(raw.size(0), -1).sum(1, keepdim=True)
    return x, logpy + logdet
def forward(self, x1, x2):
    """Crop ``x1`` to ``x2``'s spatial size, direct-sum it with ``x2``, apply two convs.

    Args:
        x1: GeometricTensor (skip connection); its raw tensor is center-cropped
            and re-typed as ``self.mid_type``.
        x2: GeometricTensor whose ``shape[2:]`` gives the crop size.
    """
    cropped = self.center_crop(x1.tensor, x2.shape[2:])
    skip = enn.GeometricTensor(cropped, self.mid_type)
    merged = enn.tensor_directsum([skip, x2])
    return self.conv2(self.conv1(merged))
def forward(self, x, logpx=None):
    """Logit transform: ``y = logit(alpha + (1 - 2 * alpha) * x)``.

    Args:
        x: GeometricTensor of shape (B, C, H, W).
        logpx: optional log-density of ``x``.

    Returns:
        ``y`` (same field type as ``x``), or ``(y, logpx - logdet)`` where
        ``logdet`` is ``self._logdetgrad(x)`` summed per sample (shape (B, 1)).
    """
    field_type = x.type
    _, c, h, w = x.shape
    raw = x.tensor
    squashed = self.alpha + (1 - 2 * self.alpha) * raw
    logits = torch.log(squashed) - torch.log(1 - squashed)
    y = enn.GeometricTensor(logits.view(-1, c, h, w), field_type)
    if logpx is None:
        return y
    logdet = self._logdetgrad(raw).view(raw.size(0), -1).sum(1, keepdim=True)
    return y, logpx - logdet
def forward(self, X):
    """Decode ``X`` with the equivariant decoder.

    Args:
        X: torch.tensor of shape (batch_size, n_in_channels, m, n), wrapped as
            a GeometricTensor of ``self.feature_emb``.

    Returns:
        torch.tensor of shape (batch_size, n_out_channels, m, n).
    """
    decoded = self.decoder(G_CNN.GeometricTensor(X, self.feature_emb))
    return decoded.tensor
def build_steer_cnn_2d(
    in_field_type,
    hidden_field_types,
    kernel_sizes,
    out_field_type,
    gspace,
    activation="relu",
    padding_mode="zeros",
    modify_init=1.0,
):
    """Build a steerable 2-D CNN wrapped as a tensor-to-tensor ``nn.Sequential``.

    The returned module does three things:
    1. Wraps the raw input as a GeometricTensor of ``in_field_type``.
    2. Runs a chain of ``gnn.R2Conv`` layers ("same"-style padding of
       ``(k - 1) / 2``), with an activation between consecutive layers but
       none after the last.
    3. Unwraps the result back to a raw tensor.

    Args:
        in_field_type: field type of the input data.
        hidden_field_types: field types of the hidden layers.
        kernel_sizes: kernel size per conv layer, or one int used for all layers.
        out_field_type: field type of the output layer.
        gspace: the gspace the data lives in (not used by this function).
        activation: key into ``activations`` selecting the nonlinearity.
        padding_mode: padding mode passed to every ``gnn.R2Conv``.
        modify_init: factor applied to every parameter after initialisation.

    Returns:
        ``nn.Sequential`` mapping a tensor to a tensor.
    """
    layer_field_types = [in_field_type, *hidden_field_types, out_field_type]
    n_layers = len(layer_field_types) - 1
    if isinstance(kernel_sizes, int):
        kernel_sizes = [kernel_sizes] * n_layers

    layers = []
    for i in range(n_layers):
        layers.append(
            gnn.R2Conv(
                layer_field_types[i],
                layer_field_types[i + 1],
                kernel_sizes[i],
                padding=int((kernel_sizes[i] - 1) / 2),
                padding_mode=padding_mode,
                initialize=True,
            ))
        # No activation after the output layer.
        if i != n_layers - 1:
            layers.append(activations[activation](layer_field_types[i + 1]))

    cnn = gnn.SequentialModule(*layers)

    # TODO: dirty fix to alleviate weird initialisations
    for p in cnn.parameters():
        if p.dim() == 0:
            p.data = p.data * modify_init
        else:
            p.data[:] = p.data * modify_init

    return nn.Sequential(
        Expression(lambda X: gnn.GeometricTensor(X, in_field_type)),
        cnn,
        Expression(lambda X: X.tensor),
    )
def _inverse_fixed_point(self, y, atol=1e-5, rtol=1e-5):
    """Invert ``y = x + g(x)`` by iterating ``x <- y - g(x)``.

    Stops when ``(x - x_prev) ** 2 < atol + |y| * rtol`` holds elementwise.
    If more than 1000 iterations are needed, a message is printed and the
    current iterate is returned.

    Args:
        y: tensor or GeometricTensor of shape (B, C, H, W).
        atol: absolute tolerance.
        rtol: tolerance relative to ``|y|``.

    Returns:
        GeometricTensor typed with the last layer's ``out_type``.
    """
    # `self.nnet` may be indexable itself or a wrapper exposing an indexable
    # `.nnet`; these helpers try both (in the original lookup order).
    def _in_type():
        try:
            return self.nnet[0].in_type
        except Exception:
            return self.nnet.nnet[0].in_type

    def _out_type():
        try:
            return self.nnet.nnet[-1].out_type
        except Exception:
            return self.nnet[-1].out_type

    _, c, h, w = y.shape
    if torch.is_tensor(y):
        y = enn.GeometricTensor(y.view(-1, c, h, w), _in_type())
    x, x_prev = y.tensor - self.nnet(y).tensor, y.tensor
    tol = atol + y.tensor.abs() * rtol
    in_type = None  # resolved lazily, only if an iteration is needed
    i = 0
    while not torch.all((x - x_prev) ** 2 / tol < 1):
        if in_type is None:
            in_type = _in_type()
        x_geo = enn.GeometricTensor(x.view(-1, c, h, w), in_type)
        x, x_prev = y.tensor - self.nnet(x_geo).tensor, x_geo.tensor
        i += 1
        if i > 1000:
            print('Iterations exceeded 1000 for inverse.')
            break
    return enn.GeometricTensor(x.view(-1, c, h, w), _out_type())
def features(self, x):
    """Return the outputs of the three residual stages for ``x``.

    Args:
        x: raw tensor, or a GeometricTensor whose type is ``self.in_type``.

    Returns:
        Tuple ``(x1, x2, x3)``: the outputs of ``layer1``, ``layer2`` and
        ``layer3`` (with ``restrict1``/``restrict2`` applied before the latter two).

    Raises:
        ValueError: if a GeometricTensor of a different type is passed.
    """
    if isinstance(x, enn.GeometricTensor):
        if x.type != self.in_type:
            raise ValueError("input GeometricTensor type does not match self.in_type")
    else:
        x = enn.GeometricTensor(x, self.in_type)
    out = self.conv1(x)
    x1 = self.layer1(out)
    x2 = self.layer2(self.restrict1(x1))
    x3 = self.layer3(self.restrict2(x2))
    return x1, x2, x3
def forward(self, x):
    """Equivariant conv stack, group-invariant pooling, then a two-layer MLP head.

    Returns:
        Output of ``self.lin_out``.
    """
    feats = nn.GeometricTensor(x, self.feat_type_in)
    feats = self.relu(self.bn(self.convin(feats)))
    # NOTE(review): `self.convhid` and `self.bn` are applied more than once, so
    # their weights are shared across stages — confirm this is intended.
    for _ in range(2):
        feats = self.relu(self.bn(self.convhid(feats)))
    feats = self.invariant_map(self.avgpool(self.convout(feats)))
    # Average the raw tensor over its last two dims (presumably spatial).
    hidden = self.elu(self.lin_in(feats.tensor.mean(-1).mean(-1)))
    return self.lin_out(hidden)
def forward(self, x):
    """Equivariant ResNet forward pass on a 3-channel (trivial-field) input.

    Returns:
        Output of ``self.fc`` applied to the flattened pooled features.
    """
    in_type = enn.FieldType(self.gspace, 3 * [self.gspace.trivial_repr])
    h = enn.GeometricTensor(x, in_type)
    h = self.relu1(self.bn1(self.conv1(h)))
    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
        h = stage(h)
    pooled = self.pool(h.tensor)
    return self.fc(pooled.view(pooled.size(0), -1))
def forward(self, input):
    """Convolve with the weight and bias from ``self.compute_weight(update=False)``.

    Accepts a raw tensor or a GeometricTensor and always returns a
    GeometricTensor of ``self.out_type``. While ``self.initialized`` is
    false, the input's spatial size (``shape[2:4]``) is copied into
    ``self.spatial_dims``.
    """
    if not self.initialized:
        self.spatial_dims.copy_(torch.tensor(input.shape[2:4]).to(self.spatial_dims))
    weight, expanded_bias = self.compute_weight(update=False)
    raw = input if torch.is_tensor(input) else input.tensor
    output = F.conv2d(raw, weight, expanded_bias, self.stride, self.padding, 1, 1)
    return enn.GeometricTensor(output, self.out_type)
def feature_shape(self):
    """Print the shape after each feature stage for a zero probe; return final (h, w).

    The probe is a single-sample zero tensor built from ``self.input_shape[:3]``.
    """
    print("Printing network feature fields shape")
    probe = nn.GeometricTensor(torch.zeros(1, *self.input_shape[:3]), self.input_type)
    stages = (self.feature_field1, self.feature_field2,
              self.feature_field3, self.equivariant_features)
    for stage in stages:
        probe = stage(probe)
        print(probe.shape)
    _, _, h, w = probe.tensor.shape
    return h, w
def features(self, x):
    """Return the outputs of the three residual stages for ``x``.

    A raw tensor is wrapped with ``self.in_type``. A GeometricTensor must
    already carry that type (checked with ``assert``).

    Returns:
        Tuple of the ``layer1``, ``layer2`` and ``layer3`` outputs.
    """
    if not isinstance(x, enn.GeometricTensor):
        x = enn.GeometricTensor(x, self.in_type)
    else:
        assert x.type == self.in_type
    stem = self.conv1(x)
    stage1 = self.layer1(stem)
    stage2 = self.layer2(self.restrict1(stage1))
    stage3 = self.layer3(self.restrict2(stage2))
    return stage1, stage2, stage3
def generate_2d_rot8(out_path, n_tasks=10000, batch_size=20, image_size=32):
    """Generate random C8-equivariant convolutions applied to noise and save them.

    For each task the weights of one ``R2Conv`` (1 trivial field in, 3 regular
    C8 fields out, kernel 3, no bias) are re-drawn with ``generalized_he_init``.
    The layer is then applied to a batch of standard-normal single-channel
    images.

    Args:
        out_path: destination passed to ``np.savez``.
        n_tasks: number of weight draws.
        batch_size: images per task.
        image_size: height and width of each input image.

    The ``.npz`` file contains:
    - ``x``: inputs, shape (n_tasks, batch_size, 1, image_size, image_size)
    - ``y``: the conv outputs per task
    - ``w``: the conv weight vector used for each task
    """
    r2_act = gspaces.Rot2dOnR2(N=8)
    feat_type_in = gnn.FieldType(r2_act, [r2_act.trivial_repr])
    feat_type_out = gnn.FieldType(r2_act, 3 * [r2_act.regular_repr])
    conv = gnn.R2Conv(feat_type_in, feat_type_out, kernel_size=3, bias=False)
    xs, ys, ws = [], [], []
    # No gradients are needed; this also permits re-initialising the leaf
    # parameter in place.
    with torch.no_grad():
        for task_idx in range(n_tasks):
            gnn.init.generalized_he_init(conv.weights, conv.basisexpansion)
            inp = gnn.GeometricTensor(
                torch.randn(batch_size, 1, image_size, image_size), feat_type_in)
            result = conv(inp).tensor.detach().cpu().numpy()
            xs.append(inp.tensor.detach().cpu().numpy())
            ys.append(result)
            # Copy: for a CPU tensor .numpy() shares memory with the parameter,
            # which is presumably overwritten in place by the next
            # re-initialisation; without the copy every stored `w` would
            # alias one buffer.
            ws.append(conv.weights.detach().cpu().numpy().copy())
            if task_idx % 100 == 0:
                print(f"Finished generating task {task_idx}")
    xs, ys, ws = np.stack(xs), np.stack(ys), np.stack(ws)
    np.savez(out_path, x=xs, y=ys, w=ws)