def extract(self, im: np.ndarray, debug_save_name=None):
    """Run the backbone on ``im`` and cache the IoU-net inputs and features.

    The input is divided by 255 and normalized with ``self.mean`` and
    ``self.std`` before being fed to ``self.net``. As side effects, the raw
    backbone features of ``self.iounet_feature_layers`` are stored (as numpy)
    on ``self.iounet_backbone_features``, and the features produced from them
    by ``self.iou_predictor.get_iou_feat`` on ``self.iounet_features``.

    args:
        im: Input image array. The caller's array is not modified.
        debug_save_name: If given, the raw input is saved there with
            ``np.savez`` before any processing.

    returns:
        TensorList of numpy arrays, one per entry of ``self.output_layers``.
    """
    with fluid.dygraph.guard():
        if debug_save_name is not None:
            np.savez(debug_save_name, im)

        # `im / 255.` allocates a new array, so the in-place updates below
        # only touch this copy and never the caller's input.
        normalized = im / 255.
        normalized -= self.mean
        normalized /= self.std

        features = self.net.extract_features(n2p(normalized),
                                             self.feature_layers)

        backbone_for_iou = TensorList(
            [features[name] for name in self.iounet_feature_layers])
        self.iounet_backbone_features = backbone_for_iou.numpy()

        iou_feats = self.iou_predictor.get_iou_feat(backbone_for_iou)
        self.iounet_features = TensorList([feat.numpy() for feat in iou_feats])

        return TensorList(
            [features[name].numpy() for name in self.output_layers])
def extract(self, im: np.ndarray, debug_save_name=None):
    """Run the backbone on ``im`` and cache the estimator's backbone inputs.

    The input is divided by 255 and normalized with ``self.mean`` and
    ``self.std``. The raw features of ``self.estimator_feature_layers`` are
    stored (as numpy) on ``self.estimator_backbone_features``.

    args:
        im: Input image array. The caller's array is not modified.
        debug_save_name: If given, the raw input is saved there with
            ``np.savez`` before any processing.

    returns:
        TensorList of numpy arrays, one per entry of ``self.output_layers``.
    """
    with fluid.dygraph.guard():
        if debug_save_name is not None:
            np.savez(debug_save_name, im)

        # Out-of-place division first, so the in-place ops act on a copy.
        net_input = im / 255.
        net_input -= self.mean
        net_input /= self.std

        layer_outputs = self.net.extract_features(n2p(net_input),
                                                  self.feature_layers)

        estimator_inputs = TensorList(
            [layer_outputs[name] for name in self.estimator_feature_layers])
        self.estimator_backbone_features = estimator_inputs.numpy()

        return TensorList(
            [layer_outputs[name].numpy() for name in self.output_layers])
def run_newton_iter(self, num_cg_iter): self.x.requires_grad_(True) # Evaluate function at current estimate self.f0 = self.problem(self.x) if self.debug and not self.analyze_convergence: self.losses = torch.cat( (self.losses, self.f0.detach().cpu().view(-1))) # Gradient of loss self.g = TensorList( torch.autograd.grad(self.f0, self.x, create_graph=True)) # Get the right hand side self.b = -self.g.detach() # Run CG delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps) self.x.detach_() self.x += delta_x if self.debug: self.residuals = torch.cat((self.residuals, res))
def run_GN_iter(self, num_cg_iter):
    """Runs a single GN iteration.

    Evaluates the residuals at ``self.x``, forms the right-hand side
    ``-(df/dx)^T f`` with a differentiable graph, runs ``num_cg_iter`` CG
    iterations and adds the resulting step to ``self.x`` in place.

    args:
        num_cg_iter: Number of CG iterations passed to ``self.run_CG``.
    """
    self.x.requires_grad_(True)

    # Evaluate function at current estimate
    self.f0 = self.problem(self.x)

    # Create copy with graph detached
    self.g = self.f0.detach()

    if self.debug and not self.analyze_convergence:
        loss = self.problem.ip_output(self.g, self.g)
        # print('Loss:',loss)
        self.losses = torch.cat(
            (self.losses, loss.detach().cpu().view(-1)))

    # g is the vector of the vector-Jacobian product below. Making it require
    # grad (with create_graph=True) lets the product be differentiated w.r.t. g.
    self.g.requires_grad_(True)

    # Get df/dx^t @ f0
    self.dfdxt_g = TensorList(
        torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

    # Get the right hand side
    self.b = -self.dfdxt_g.detach()

    # Run CG
    delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

    # Detach in place first so the update itself is not tracked by autograd.
    self.x.detach_()
    self.x += delta_x

    if self.debug:
        self.residuals = torch.cat((self.residuals, res))
def size(self, im_sz):
    """Return the spatial output size of every output layer.

    args:
        im_sz: Size of the input image.

    returns:
        TensorList with one entry per output layer. If ``self.output_size``
        is None, every entry is ``floordiv(im_sz, stride)``. Otherwise each
        layer whose ``output_size`` entry is None also gets
        ``floordiv(im_sz, stride)``, and the others get their fixed size as
        a numpy array ``[h, w]``.
    """
    if self.output_size is None:
        return TensorList([floordiv(im_sz, s) for s in self.stride()])
    # No type check on im_sz: floordiv is applied to any im_sz in the branch
    # above as well. Restricting this branch to PTensor inputs made every
    # other input type fall through and silently return None.
    return TensorList([
        floordiv(im_sz, s) if sz is None else np.array([sz[0], sz[1]])
        for sz, s in zip(self.output_size, self.stride())
    ])
def run(self, num_cg_iter, joint_var):
    """Run the optimizer with the provided number of iterations.

    The residuals are linearized once at ``joint_var``. Then ``num_cg_iter``
    conjugate-gradient iterations are run and the resulting step is added to
    the variable in place.

    args:
        num_cg_iter: Number of CG iterations. Nothing is done when 0.
        joint_var: Variable to optimize. It is stored as ``self.x`` and
            updated in place.
    """
    if num_cg_iter == 0:
        return

    self.x = joint_var

    lossvec = None
    if self.debug:
        # Loss before (index 0) and after (index -1) the update.
        lossvec = torch.zeros(2)

    self.x.requires_grad_(True)

    # Evaluate function at current estimate
    self.f0 = self.problem(self.x)

    # Create copy with graph detached
    self.g = self.f0.detach()

    if self.debug:
        lossvec[0] = self.problem.ip_output(self.g, self.g)

    # g is the vector of the vector-Jacobian product below; it must require
    # grad so that the product can later be differentiated w.r.t. g.
    self.g.requires_grad_(True)

    # Get df/dx^t @ f0
    self.dfdxt_g = TensorList(
        torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

    # Get the right hand side
    self.b = -self.dfdxt_g.detach()

    # Run CG
    delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

    # Detach in place first so the update itself is not tracked by autograd.
    self.x.detach_()
    self.x += delta_x

    if self.debug:
        self.f0 = self.problem(self.x)
        lossvec[-1] = self.problem.ip_output(self.f0, self.f0)
        self.residuals = torch.cat((self.residuals, res))
        self.losses = torch.cat((self.losses, lossvec))
        # print('Loss:', self.losses)
        if self.visdom is not None:
            self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
            self.visdom.register(self.residuals, 'lineplot', 3,
                                 'CG residuals')
        elif self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.residuals, self.fig_num[1], title='CG residuals')

    self.x.detach_()
    self.clear_temp()
def get_attribute(self, name: str, ignore_missing: bool=False):
    """Collect attribute ``name`` from every feature accepted by ``_return_feature``.

    args:
        name: Attribute name to read from each feature.
        ignore_missing: If True, features lacking the attribute are skipped.
            Otherwise they contribute ``None`` at their position.

    returns:
        TensorList of the collected attribute values.
    """
    selected = (feat for feat in self.features if self._return_feature(feat))
    if not ignore_missing:
        return TensorList([getattr(feat, name, None) for feat in selected])
    return TensorList(
        [getattr(feat, name) for feat in selected if hasattr(feat, name)])
def evaluate_CG_iteration(self, delta_x):
    """Record the loss and gradient norm at ``self.x + delta_x``.

    Does nothing unless ``self.analyze_convergence`` is set. ``self.x`` is
    not modified; the evaluation uses a detached copy.
    """
    if not self.analyze_convergence:
        return

    probe = (self.x + delta_x).detach()
    probe.requires_grad_(True)

    # Loss at the probe point and its gradient
    probe_loss = self.problem(probe)
    probe_grad = TensorList(torch.autograd.grad(probe_loss, probe))

    grad_norm = sum(probe_grad.view(-1) @ probe_grad.view(-1)).cpu().sqrt().detach().view(-1)
    self.losses = torch.cat((self.losses, probe_loss.detach().cpu().view(-1)))
    self.gradient_mags = torch.cat((self.gradient_mags, grad_norm))
def run(self, num_iter, dummy=None):
    """Run ``num_iter`` steps of gradient descent with momentum.

    Each step evaluates the residuals ``f = problem(x)``, the loss
    ``ip_output(f, f)`` and its gradient. It updates the direction
    ``dir = grad + momentum * dir`` and moves ``self.x`` in place by
    ``-step_legnth * dir``.

    args:
        num_iter: Number of descent steps. Nothing is done when 0.
        dummy: Unused; presumably kept so the call signature matches the
            other optimizers.
    """
    if num_iter == 0:
        return

    lossvec = None
    if self.debug:
        # One entry per step plus one for the final estimate.
        lossvec = torch.zeros(num_iter + 1)
        grad_mags = torch.zeros(num_iter + 1)

    for i in range(num_iter):
        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Compute loss
        loss = self.problem.ip_output(self.f0, self.f0)

        # Compute grad
        grad = TensorList(torch.autograd.grad(loss, self.x))

        # Update direction (momentum accumulation)
        if self.dir is None:
            self.dir = grad
        else:
            self.dir = grad + self.momentum * self.dir

        # Detach in place first so the step is not tracked by autograd.
        self.x.detach_()
        self.x -= self.step_legnth * self.dir

        if self.debug:
            lossvec[i] = loss.item()
            grad_mags[i] = sum(grad.view(-1) @ grad.view(-1)).sqrt().item()

    if self.debug:
        # Loss and gradient magnitude at the final estimate.
        self.x.requires_grad_(True)
        self.f0 = self.problem(self.x)
        loss = self.problem.ip_output(self.f0, self.f0)
        grad = TensorList(torch.autograd.grad(loss, self.x))
        # Reuse the loss just computed instead of evaluating ip_output again.
        lossvec[-1] = loss.item()
        grad_mags[-1] = sum(
            grad.view(-1) @ grad.view(-1)).cpu().sqrt().item()
        self.losses = torch.cat((self.losses, lossvec))
        self.gradient_mags = torch.cat((self.gradient_mags, grad_mags))
        if self.visdom is not None:
            self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
            self.visdom.register(self.gradient_mags, 'lineplot', 4,
                                 'Gradient magnitude')
        elif self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.gradient_mags, self.fig_num[1],
                       title='Gradient magnitude')

    self.x.detach_()
    self.clear_temp()
def A(self, x):
    """Apply the Gauss-Newton operator to ``x`` using two autograd passes.

    Assumes ``self.dfdxt_g`` is ``(df/dx)^T g``, built with
    ``create_graph=True`` from ``self.f0`` and a grad-requiring ``self.g``.
    Differentiating it w.r.t. ``self.g`` in direction ``x`` gives
    ``(df/dx) x``. A second vector-Jacobian product of ``self.f0`` w.r.t.
    ``self.x`` then gives ``(df/dx)^T (df/dx) x``. Both graphs are retained
    so the operator can be applied repeatedly.
    """
    jac_times_x = torch.autograd.grad(
        self.dfdxt_g, self.g, x, retain_graph=True)
    jt_jac_x = torch.autograd.grad(
        self.f0, self.x, jac_times_x, retain_graph=True)
    return TensorList(jt_jac_x)
def get_fparams(self, name: str=None):
    """Return the feature parameters of the features accepted by ``_return_feature``.

    args:
        name: If None, return a plain list of the ``fparams`` objects.
            Otherwise read attribute ``name`` from each ``fparams`` and
            return the unrolled TensorList of those values.
    """
    active = (feat for feat in self.features if self._return_feature(feat))
    if name is None:
        return [feat.fparams for feat in active]
    values = TensorList([getattr(feat.fparams, name) for feat in active])
    return values.unroll()
def extract_transformed(self, im, pos, scale, image_sz, transforms, debug_save_name=None): """Extract features from a set of transformed image samples. args: im: Image. pos: Center position for extraction. scale: Image scale to extract features from. image_sz: Size to resize the image samples to before extraction. transforms: A set of image transforms to apply. """ # Get image patche im_patch = sample_patch(im, pos, scale * image_sz, image_sz) # Apply transforms with fluid.dygraph.guard(fluid.CPUPlace()): im_patches = np.stack([T(im_patch) for T in transforms]) if debug_save_name is not None: np.save(debug_save_name, im_patches) im_patches = np.transpose(im_patches, (0, 3, 1, 2)) # Compute features feature_map = TensorList( [f.get_feature(im_patches) for f in self.features]).unroll() return feature_map
def forward(self,
            meta_parameter: TensorList,
            feat,
            label,
            sample_weight=None):
    """Compute the residuals of the linear filter learning problem.

    args:
        meta_parameter: TensorList whose first element is the filter.
        feat: Training features. ``feat.shape[0]`` is the number of images;
            for 5-D input ``feat.shape[1]`` is the number of sequences.
        label: Target scores, viewed to the shape of the predicted scores.
        sample_weight: Optional sample weights, defaulting to
            ``sqrt(1 / num_images)``. A tensor with as many elements as the
            scores is viewed to the scores' shape. Another 1-D tensor is
            viewed as ``(-1, 1, 1, 1, 1)``. Anything else is used as given.

    returns:
        TensorList ``[data_residual, reg_residual]``.
    """
    # Assumes multiple filters, i.e. (sequences, filters, feat_dim, fH, fW)
    filt = meta_parameter[0]

    n_images = feat.shape[0]
    n_seq = 1 if feat.dim() != 5 else feat.shape[1]

    scores = filter_layer.apply_filter(
        feat, filt, dilation_factors=self.filter_dilation_factors)

    weight = sample_weight
    if weight is None:
        weight = math.sqrt(1.0 / n_images)
    elif isinstance(weight, torch.Tensor):
        if weight.numel() == scores.numel():
            weight = weight.view(scores.shape)
        elif weight.dim() == 1:
            weight = weight.view(-1, 1, 1, 1, 1)

    residual_data = weight * (scores - label.view(scores.shape))

    # Regularization residual, with the batch placed in the second dimension.
    residual_reg = self.filter_reg * filt.view(1, n_seq, -1)

    return TensorList([residual_data, residual_reg])
def extract(self, im, pos, scales, image_sz, debug_save_name=None):
    """Extract features.

    args:
        im: Image.
        pos: Center position for extraction.
        scales: Image scales to extract features from. A single int or float
            is treated as a one-element list.
        image_sz: Size to resize the image samples to before extraction.
        debug_save_name: If not None, the stacked patches are saved to this
            path with ``np.save`` (before the axis transpose).

    returns:
        Unrolled TensorList with the output of ``get_feature`` for every
        feature in ``self.features``.
    """
    if isinstance(scales, (int, float)):
        scales = [scales]

    # Get image patches, one per scale
    with fluid.dygraph.guard(fluid.CPUPlace()):
        im_patches = np.stack([
            sample_patch(im, pos, s * image_sz, image_sz) for s in scales
        ])

    if debug_save_name is not None:
        np.save(debug_save_name, im_patches)

    # Move the last axis to position 1 (presumably NHWC -> NCHW).
    im_patches = np.transpose(im_patches, (0, 3, 1, 2))

    # Compute features
    feature_map = TensorList(
        [f.get_feature(im_patches) for f in self.features]).unroll()

    return feature_map
def ltr_collate_stack1(batch):
    """Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch.

    ``batch`` is a list of samples with identical structure. Containers are
    handled recursively:

    * torch tensors and numpy arrays are stacked along dim 1;
    * numpy scalars, Python ints and floats become 1-D tensors;
    * strings and ``None`` entries are returned as the plain list;
    * ``TensorDict`` / mappings are collated key by key (keys of ``batch[0]``);
    * ``TensorList`` / other sequences are transposed and collated per position.

    raises:
        TypeError: for numpy arrays of strings/objects and unsupported types.
    """
    # Use the standard library directly. The ``collections.Mapping`` and
    # ``collections.Sequence`` aliases were removed in Python 3.10, and newer
    # PyTorch releases no longer provide ``re`` / ``numpy_type_map`` through
    # ``torch.utils.data.dataloader``.
    import collections.abc
    import re

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _check_use_shared_memory():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 1, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 1)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            # Same dtype as the numpy scalars (e.g. float32 -> torch.float32).
            return torch.tensor(list(map(py_type, batch)),
                                dtype=getattr(torch, elem.dtype.name))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], TensorDict):
        return TensorDict({
            key: ltr_collate_stack1([d[key] for d in batch])
            for key in batch[0]
        })
    elif isinstance(batch[0], collections.abc.Mapping):
        return {
            key: ltr_collate_stack1([d[key] for d in batch])
            for key in batch[0]
        }
    elif isinstance(batch[0], TensorList):
        transposed = zip(*batch)
        return TensorList(
            [ltr_collate_stack1(samples) for samples in transposed])
    elif isinstance(batch[0], collections.abc.Sequence):
        transposed = zip(*batch)
        return [ltr_collate_stack1(samples) for samples in transposed]
    elif batch[0] is None:
        return batch

    raise TypeError(error_msg.format(type(batch[0])))
def __init__(self, *args, **kwargs):
    """Store every keyword argument as an attribute of the same name.

    List values are wrapped in a ``TensorList``; all other values are stored
    unchanged.

    raises:
        ValueError: if any positional argument is given.
    """
    if len(args) > 0:
        raise ValueError(
            'Only keyword arguments are supported, got {} positional '
            'argument(s)'.format(len(args)))
    for name, val in kwargs.items():
        setattr(self, name, TensorList(val) if isinstance(val, list) else val)
def get_inputs(self, scope=''):
    """Return the input placeholders of the problem, created once per scope.

    args:
        scope: Name prefix of the created variables. Each distinct scope gets
            its own set of variables, cached in ``self.inputs_dict``.

    returns:
        Tuple ``(training_samples, y, sample_weights)`` of TensorLists.
    """
    if scope not in self.inputs_dict:
        # Every placeholder list is derived from the data it stands for.
        # The training-sample placeholders must follow self.training_samples,
        # just as y follows self.y and the weights follow self.sample_weights.
        training_samples_p = TensorList(
            create_var_list(scope + "training_samples",
                            self.training_samples, [None]))
        y_p = TensorList(create_var_list(scope + "y", self.y, [None]))
        sample_weights_p = TensorList(
            create_var_list(scope + "sample_weights", self.sample_weights,
                            [None, 1]))
        self.inputs_dict[scope] = (training_samples_p, y_p, sample_weights_p)
    return self.inputs_dict[scope]
def extract(self, im: np.ndarray, debug_save_name=None):
    """Run the backbone on ``im`` and return its feature maps as numpy arrays.

    The input is passed to ``n2p`` directly; no mean/std normalization is
    applied here.

    args:
        im: Input image array.
        debug_save_name: If given, the input is saved there with ``np.savez``.

    returns:
        TensorList with one numpy array per backbone output.
    """
    with fluid.dygraph.guard():
        if debug_save_name is not None:
            np.savez(debug_save_name, im)
        backbone_feats = self.net.extract_backbone_features(n2p(im))
        return TensorList([feat.numpy() for feat in backbone_feats])
def get_feature(self, im: np.ndarray): """Get the feature. Generally, call this function. args: im: image patch """ # Return empty tensor if it should not be used is_color = im.shape[1] == 3 if is_color and not self.use_for_color or not is_color and not self.use_for_gray: return np.array([]) feat_list = self.extract(im) output_sz = [None] * len( feat_list) if self.output_size is None else self.output_size # Pool/downsample with fluid.dygraph.guard(): feat_list = [n2p(f) for f in feat_list] for i, (sz, s) in enumerate(zip(output_sz, self.pool_stride)): if sz is not None: feat_list[i] = layers.adaptive_pool2d(feat_list[i], sz, pool_type='avg') elif s != 1: feat_list[i] = layers.pool2d(feat_list[i], s, pool_stride=s, pool_type='avg') # Normalize if self.normalize_power is not None: new_feat_list = [] for feat in feat_list: norm = (layers.reduce_sum(layers.reshape( layers.abs(feat), [feat.shape[0], 1, 1, -1])** self.normalize_power, dim=3, keep_dim=True) / (feat.shape[1] * feat.shape[2] * feat.shape[3]) + 1e-10)**(1 / self.normalize_power) feat = broadcast_op(feat, norm, 'div') new_feat_list.append(feat) feat_list = new_feat_list # To numpy feat_list = TensorList([f.numpy() for f in feat_list]) return feat_list
def extract(self, im, pos, scales, image_sz):
    """Sample patches at each scale and concatenate all features along axis 1.

    args:
        im: Image.
        pos: Center position for extraction.
        scales: A scale or a list of scales; a single int/float is wrapped
            in a list.
        image_sz: Size the sampled patches are resized to.

    returns:
        Features of all extractors concatenated along axis 1.
    """
    scale_list = [scales] if isinstance(scales, (int, float)) else scales

    patches = np.stack([
        sample_patch(im, pos, scale * image_sz, image_sz)
        for scale in scale_list
    ])
    # Move the last axis to position 1 (presumably channels-first layout).
    patches = np.transpose(patches, (0, 3, 1, 2))

    per_extractor = TensorList([f.get_feature(patches) for f in self.features])
    return layers.concat(per_extractor.unroll(), axis=1)
def __call__(self, x: TensorList, scope=''):
    """
    Compute residuals
    :param x: [filters]
    :param scope: name prefix used to look up the cached input variables
    :return: [data_terms, filter_regularizations]
    """
    samples, targets, weights = self.get_inputs(scope)

    # Data term: (activation(conv(samples, x)) - y) * sqrt(sample_weights)
    responses = operation.conv2d(samples, x, mode='same').apply(
        self.response_activation)
    data_terms = (responses - targets) * weights.sqrt()

    # Regularization term for the filters
    reg_terms = x.apply(static_identity) * self.filter_reg.apply(math.sqrt)
    data_terms.extend(reg_terms)
    return data_terms
def get_filter(self,
               feat,
               train_label,
               train_sw,
               num_objects=None,
               *args,
               **kwargs):
    """
    Get the initial target model parameters given the few-shot labels
    :param feat: training features (example shape: [1, 5, 512, 30, 52])
    :param train_label: few-shot labels (example shape: [1, 5, 16, 30, 52])
    :param train_sw: sample weights forwarded to the filter optimizer
    :param num_objects: if given, the initial filter is repeated this many
        times along dim 1
    :return: (weights, weights_iter, losses); losses is None when no filter
        optimizer is configured
    """
    initial = self.filter_initializer(feat, train_label)
    if num_objects is not None:
        initial = initial.repeat(1, num_objects, 1, 1, 1)

    if self.filter_optimizer is None:
        return initial, [initial], None

    optimized, iterates, losses = self.filter_optimizer(
        TensorList([initial]),
        *args,
        feat=feat,
        label=train_label,
        sample_weight=train_sw,
        **kwargs)
    return optimized[0], [w[0] for w in iterates], losses
class ConjugateGradient_Attn(ConjugateGradientBase):
    """Conjugate Gradient optimizer, performing single linearization of the residuals in the start.

    ``run`` linearizes the residuals once at the given variable and then runs
    ``self.run_CG`` (from the base class) with the operator implemented in
    ``A``, which applies ``(df/dx)^T (df/dx)`` via autograd.
    """

    def __init__(self,
                 problem: L2Problem,
                 variable: TensorList,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 debug=False,
                 plotting=False,
                 visdom=None):
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or plotting)

        self.problem = problem
        self.x = variable

        self.plotting = plotting
        # Figure numbers for the loss and CG-residual plots.
        self.fig_num = (10, 11)
        self.visdom = visdom

        self.cg_eps = cg_eps

        # Tensors of the current linearization (reset by clear_temp).
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

        # Debug histories, appended to only when debugging.
        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)

    def clear_temp(self):
        """Drop the references to the tensors of the last linearization."""
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

    def run(self, num_cg_iter, joint_var):
        """Run the optimizer with the provided number of iterations.

        args:
            num_cg_iter: Number of CG iterations. Nothing is done when 0.
            joint_var: Variable to optimize. It is stored as ``self.x`` and
                updated in place.
        """
        if num_cg_iter == 0:
            return

        self.x = joint_var

        lossvec = None
        if self.debug:
            # Loss before (index 0) and after (index -1) the update.
            lossvec = torch.zeros(2)

        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Create copy with graph detached
        self.g = self.f0.detach()

        if self.debug:
            lossvec[0] = self.problem.ip_output(self.g, self.g)

        # g must require grad so that dfdxt_g can be differentiated w.r.t. g
        # in A().
        self.g.requires_grad_(True)

        # Get df/dx^t @ f0
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

        # Get the right hand side
        self.b = -self.dfdxt_g.detach()

        # Run CG
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Detach in place first so the update is not tracked by autograd.
        self.x.detach_()
        self.x += delta_x

        if self.debug:
            self.f0 = self.problem(self.x)
            lossvec[-1] = self.problem.ip_output(self.f0, self.f0)
            self.residuals = torch.cat((self.residuals, res))
            self.losses = torch.cat((self.losses, lossvec))
            # print('Loss:', self.losses)
            if self.visdom is not None:
                self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
                self.visdom.register(self.residuals, 'lineplot', 3,
                                     'CG residuals')
            elif self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals, self.fig_num[1],
                           title='CG residuals')

        self.x.detach_()
        self.clear_temp()

    def A(self, x):
        # First pass: d(dfdxt_g)/dg applied to x gives (df/dx) x.
        # Second pass: VJP of f0 w.r.t. self.x with that vector gives
        # (df/dx)^T (df/dx) x. Graphs are retained for repeated application.
        dfdx_x = torch.autograd.grad(self.dfdxt_g, self.g, x,
                                     retain_graph=True)
        return TensorList(
            torch.autograd.grad(self.f0, self.x, dfdx_x, retain_graph=True))

    def ip(self, a, b):
        # Inner product, delegated to the problem's input-space inner product.
        return self.problem.ip_input(a, b)

    def M1(self, x):
        # Delegates to the problem (presumably a preconditioner).
        return self.problem.M1(x)

    def M2(self, x):
        # Delegates to the problem (presumably a preconditioner).
        return self.problem.M2(x)
def stride(self):
    """Total stride per output layer: pooling stride times the layer's own stride."""
    pairs = zip(self.output_layers, self.pool_stride)
    return TensorList(
        [pool_s * self.layer_stride[layer_name] for layer_name, pool_s in pairs])
def dim(self):
    """Feature dimensionality (``self.layer_dim``) of every output layer, in order."""
    dims = [self.layer_dim[layer_name] for layer_name in self.output_layers]
    return TensorList(dims)
def ip_input(self, a: TensorList, b: TensorList):
    """Inner product in the input space: flatten ``a`` and ``b`` and combine them with ``@``."""
    flat_a = a.reshape(-1)
    flat_b = b.reshape(-1)
    return flat_a @ flat_b
def A(self, x):
    """Apply the regularized Hessian of the loss to ``x``.

    Assumes ``self.g`` is the gradient of the loss w.r.t. ``self.x``,
    computed with ``create_graph=True``. Differentiating it again in
    direction ``x`` gives the Hessian-vector product, and
    ``self.hessian_reg * x`` is added to it.
    """
    hess_x = TensorList(
        torch.autograd.grad(self.g, self.x, x, retain_graph=True))
    return hess_x + self.hessian_reg * x
def forward(self,
            train_imgs,
            test_imgs,
            train_masks,
            test_masks,
            num_refinement_iter=2):
    """Predict masks for the test frames while updating the target model online.

    The target model is first learned from the training frames and masks.
    Each test frame is then segmented with the current target model, and its
    predicted mask is added to the few-shot training set. Except after the
    last frame, the target model is refined on the extended set, and the
    refined model is used for the next frame.

    args:
        train_imgs: Training images. Dim 0 is frames, dim 1 is sequences,
            and the last three dims are the image dims.
        test_imgs: Test images, same layout as ``train_imgs``.
        train_masks: Masks of the training frames.
        test_masks: Not used by this method.
        num_refinement_iter: Optimizer iterations per online update; 0
            disables the update.

    returns:
        Mask predictions of all test frames, concatenated along dim 0.
    """
    num_sequences = train_imgs.shape[1]
    num_train_frames = train_imgs.shape[0]
    num_test_frames = test_imgs.shape[0]

    # Extract backbone features
    train_feat_backbone = self.extract_backbone_features(
        train_imgs.view(-1, train_imgs.shape[-3], train_imgs.shape[-2],
                        train_imgs.shape[-1]))
    test_feat_backbone = self.extract_backbone_features(
        test_imgs.contiguous().view(-1, test_imgs.shape[-3],
                                    test_imgs.shape[-2], test_imgs.shape[-1]))

    # Extract features input to the target model
    train_feat_tm = self.extract_target_model_features(
        train_feat_backbone)  # seq*frames, channels, height, width
    train_feat_tm = train_feat_tm.view(num_train_frames, num_sequences,
                                       *train_feat_tm.shape[-3:])
    train_feat_tm_all = [train_feat_tm]

    # Few-shot learner label and spatial importance weights
    # (only the masks are needed, not the target model features).
    few_shot_label, few_shot_sw = self.label_encoder(train_masks)
    few_shot_label_all = [few_shot_label]
    few_shot_sw_all = None if few_shot_sw is None else [few_shot_sw]

    test_feat_tm = self.extract_target_model_features(
        test_feat_backbone)  # seq*frames, channels, height, width

    # Obtain the target module parameters using the few-shot learner
    target_filter, _, _ = self.target_model.get_filter(
        train_feat_tm, few_shot_label, few_shot_sw)

    # Reshape the test features to (frames, sequences, ...) once, outside
    # the frame loop.
    test_feat_tm = test_feat_tm.view(num_test_frames, num_sequences,
                                     *test_feat_tm.shape[-3:])
    test_feat_backbone = {
        k: v.view(num_test_frames, num_sequences, *v.shape[-3:])
        for k, v in test_feat_backbone.items()
    }

    mask_predictions_all = []

    # Iterate over the test sequence
    for i in range(num_test_frames):
        # Target model features of the current frame (frame dim kept)
        test_feat_tm_it = test_feat_tm[i:i + 1, ...]

        # Apply the current, possibly refined, target model. Using
        # target_filter here is what makes the online update below take
        # effect on later frames.
        mask_encoding_pred = self.target_model.apply_target_model(
            target_filter, test_feat_tm_it)

        test_feat_backbone_it = {
            k: v[i, ...]
            for k, v in test_feat_backbone.items()
        }

        # Run decoder to obtain the segmentation mask
        mask_pred, _ = self.decoder(mask_encoding_pred,
                                    test_feat_backbone_it,
                                    test_imgs.shape[-2:])
        mask_pred = mask_pred.view(1, num_sequences, *mask_pred.shape[-2:])
        mask_predictions_all.append(mask_pred)

        # Convert the segmentation scores to probability; detached so no
        # gradient flows back through the pseudo-label.
        mask_pred_prob = torch.sigmoid(mask_pred.detach())

        # Label encoding of the predicted mask
        few_shot_label, few_shot_sw = self.label_encoder(mask_pred_prob)

        # Extend the training data using the predicted mask
        few_shot_label_all.append(few_shot_label)
        if few_shot_sw_all is not None:
            few_shot_sw_all.append(few_shot_sw)
        train_feat_tm_all.append(test_feat_tm_it)

        # Update the target model using the extended training set
        if i < num_test_frames - 1 and num_refinement_iter > 0:
            train_feat_tm_it = torch.cat(train_feat_tm_all, dim=0)
            few_shot_label_it = torch.cat(few_shot_label_all, dim=0)
            few_shot_sw_it = (None if few_shot_sw_all is None else
                              torch.cat(few_shot_sw_all, dim=0))

            # Run few-shot learner to update the target model
            filter_updated, _, _ = self.target_model.filter_optimizer(
                TensorList([target_filter]),
                feat=train_feat_tm_it,
                label=few_shot_label_it,
                sample_weight=few_shot_sw_it,
                num_iter=num_refinement_iter)
            target_filter = filter_updated[0]  # filter_updated is a TensorList

    return torch.cat(mask_predictions_all, dim=0)
class NewtonCG(ConjugateGradientBase):
    """Newton with Conjugate Gradient. Handles general minimization problems.

    Each Newton iteration approximately solves
    ``(H + hessian_reg * I) dx = -grad`` with CG, where ``H x`` is computed
    by autograd. ``hessian_reg`` is multiplied by ``hessian_reg_factor``
    after every Newton iteration.
    """

    def __init__(self,
                 problem: MinimizationProblem,
                 variable: TensorList,
                 init_hessian_reg=0.0,
                 hessian_reg_factor=1.0,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 debug=False,
                 analyze=False,
                 plotting=False,
                 fig_num=(10, 11, 12)):
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or analyze or plotting)

        self.problem = problem
        self.x = variable

        self.analyze_convergence = analyze
        self.plotting = plotting
        self.fig_num = fig_num

        self.hessian_reg = init_hessian_reg
        self.hessian_reg_factor = hessian_reg_factor
        self.cg_eps = cg_eps

        # Tensors of the current Newton iteration (reset by clear_temp).
        self.f0 = None
        self.g = None

        # Debug / convergence-analysis histories.
        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)
        self.gradient_mags = torch.zeros(0)

    def clear_temp(self):
        """Drop the references to the tensors of the last Newton iteration."""
        self.f0 = None
        self.g = None

    def run(self, num_cg_iter, num_newton_iter=None):
        """Run the optimizer.

        args:
            num_cg_iter: CG iterations per Newton iteration. If a list, each
                entry gives the CG iterations of one Newton iteration.
            num_newton_iter: Number of Newton iterations when ``num_cg_iter``
                is an int (defaults to 1). Ignored when it is a list.

        returns:
            ``(losses, residuals)`` histories. They are returned on every
            path, including when there is nothing to run, so callers can
            always unpack the result.
        """
        if isinstance(num_cg_iter, int):
            if num_cg_iter == 0:
                return self.losses, self.residuals
            if num_newton_iter is None:
                num_newton_iter = 1
            num_cg_iter = [num_cg_iter] * num_newton_iter

        num_newton_iter = len(num_cg_iter)
        if num_newton_iter == 0:
            return self.losses, self.residuals

        if self.analyze_convergence:
            self.evaluate_CG_iteration(0)

        for cg_iter in num_cg_iter:
            self.run_newton_iter(cg_iter)
            # Rescale the Hessian regularization after every Newton step.
            self.hessian_reg *= self.hessian_reg_factor

        if self.debug:
            if not self.analyze_convergence:
                loss = self.problem(self.x)
                self.losses = torch.cat(
                    (self.losses, loss.detach().cpu().view(-1)))

            if self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals, self.fig_num[1],
                           title='CG residuals')
                if self.analyze_convergence:
                    plot_graph(self.gradient_mags, self.fig_num[2],
                               'Gradient magnitude')

        self.x.detach_()
        self.clear_temp()

        return self.losses, self.residuals

    def run_newton_iter(self, num_cg_iter):
        """Perform one Newton step, solving for the update with CG."""
        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        if self.debug and not self.analyze_convergence:
            self.losses = torch.cat(
                (self.losses, self.f0.detach().cpu().view(-1)))

        # Gradient of loss; create_graph=True so A() can differentiate it.
        self.g = TensorList(
            torch.autograd.grad(self.f0, self.x, create_graph=True))

        # Get the right hand side
        self.b = -self.g.detach()

        # Run CG
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Detach in place first so the update is not tracked by autograd.
        self.x.detach_()
        self.x += delta_x

        if self.debug:
            self.residuals = torch.cat((self.residuals, res))

    def A(self, x):
        """Regularized Hessian-vector product ``H x + hessian_reg * x``."""
        return TensorList(
            torch.autograd.grad(self.g, self.x, x,
                                retain_graph=True)) + self.hessian_reg * x

    def ip(self, a, b):
        # Implements the inner product
        return self.problem.ip_input(a, b)

    def M1(self, x):
        # Delegates to the problem (presumably a preconditioner).
        return self.problem.M1(x)

    def M2(self, x):
        # Delegates to the problem (presumably a preconditioner).
        return self.problem.M2(x)

    def evaluate_CG_iteration(self, delta_x):
        """Record loss and gradient norm at ``self.x + delta_x`` when analyzing convergence."""
        if self.analyze_convergence:
            x = (self.x + delta_x).detach()
            x.requires_grad_(True)

            # compute loss and gradient
            loss = self.problem(x)
            grad = TensorList(torch.autograd.grad(loss, x))

            # store in the vectors
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))
            self.gradient_mags = torch.cat(
                (self.gradient_mags,
                 sum(grad.view(-1) @ grad.view(-1)).cpu().sqrt().detach()
                 .view(-1)))
class GaussNewtonCG(ConjugateGradientBase):
    """Gauss-Newton with Conjugate Gradient optimizer.

    Each GN iteration linearizes the residuals at the current estimate and
    approximately solves the resulting system with CG. ``A`` applies
    ``(df/dx)^T (df/dx)`` via two autograd passes.
    """

    def __init__(self,
                 problem: L2Problem,
                 variable: TensorList,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 debug=False,
                 analyze=False,
                 plotting=False,
                 visdom=None):
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or analyze or plotting)

        self.problem = problem
        self.x = variable

        self.analyze_convergence = analyze
        self.plotting = plotting
        # Figure numbers for the loss, CG-residual and gradient plots.
        self.fig_num = (10, 11, 12)
        self.visdom = visdom

        self.cg_eps = cg_eps

        # Tensors of the current linearization (reset by clear_temp).
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

        # Debug / convergence-analysis histories.
        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)
        self.gradient_mags = torch.zeros(0)

    def clear_temp(self):
        """Drop the references to the tensors of the last linearization."""
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

    def run_GN(self, *args, **kwargs):
        """Alias of ``run``."""
        return self.run(*args, **kwargs)

    def run(self, num_cg_iter, num_gn_iter=None):
        """Run the optimizer.

        args:
            num_cg_iter: Number of CG iterations per GN iter. If list, then each entry specifies
                number of CG iterations and number of GN iterations is given by the length of the list.
            num_gn_iter: Number of GN iterations. Shall only be given if num_cg_iter is an integer.

        returns:
            ``(losses, residuals)`` histories. They are returned on every
            path, including when there is nothing to run, so callers can
            always unpack the result.

        raises:
            ValueError: if ``num_cg_iter`` is an int and ``num_gn_iter`` is None.
        """
        if isinstance(num_cg_iter, int):
            if num_gn_iter is None:
                raise ValueError(
                    'Must specify number of GN iter if CG iter is constant')
            num_cg_iter = [num_cg_iter] * num_gn_iter

        num_gn_iter = len(num_cg_iter)
        if num_gn_iter == 0:
            return self.losses, self.residuals

        if self.analyze_convergence:
            self.evaluate_CG_iteration(0)

        # Outer loop for running the GN iterations.
        for cg_iter in num_cg_iter:
            self.run_GN_iter(cg_iter)

        if self.debug:
            if not self.analyze_convergence:
                self.f0 = self.problem(self.x)
                loss = self.problem.ip_output(self.f0, self.f0)
                self.losses = torch.cat(
                    (self.losses, loss.detach().cpu().view(-1)))

            if self.visdom is not None:
                self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
                self.visdom.register(self.residuals, 'lineplot', 3,
                                     'CG residuals')

                if self.analyze_convergence:
                    self.visdom.register(self.gradient_mags, 'lineplot', 4,
                                         'Gradient magnitude')
            elif self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals, self.fig_num[1],
                           title='CG residuals')
                if self.analyze_convergence:
                    plot_graph(self.gradient_mags, self.fig_num[2],
                               'Gradient magnitude')

        self.x.detach_()
        self.clear_temp()

        return self.losses, self.residuals

    def run_GN_iter(self, num_cg_iter):
        """Runs a single GN iteration."""
        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Create copy with graph detached
        self.g = self.f0.detach()

        if self.debug and not self.analyze_convergence:
            loss = self.problem.ip_output(self.g, self.g)
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))

        # g must require grad so that dfdxt_g can be differentiated w.r.t. g
        # in A().
        self.g.requires_grad_(True)

        # Get df/dx^t @ f0
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

        # Get the right hand side
        self.b = -self.dfdxt_g.detach()

        # Run CG
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Detach in place first so the update is not tracked by autograd.
        self.x.detach_()
        self.x += delta_x

        if self.debug:
            self.residuals = torch.cat((self.residuals, res))

    def A(self, x):
        """Apply ``(df/dx)^T (df/dx)`` to ``x`` via two autograd passes."""
        dfdx_x = torch.autograd.grad(self.dfdxt_g, self.g, x,
                                     retain_graph=True)
        return TensorList(
            torch.autograd.grad(self.f0, self.x, dfdx_x, retain_graph=True))

    def ip(self, a, b):
        # Inner product, delegated to the problem's input-space inner product.
        return self.problem.ip_input(a, b)

    def M1(self, x):
        # Delegates to the problem (presumably a preconditioner).
        return self.problem.M1(x)

    def M2(self, x):
        # Delegates to the problem (presumably a preconditioner).
        return self.problem.M2(x)

    def evaluate_CG_iteration(self, delta_x):
        """Record loss and gradient norm at ``self.x + delta_x`` when analyzing convergence."""
        if self.analyze_convergence:
            x = (self.x + delta_x).detach()
            x.requires_grad_(True)

            # compute loss and gradient
            f = self.problem(x)
            loss = self.problem.ip_output(f, f)
            grad = TensorList(torch.autograd.grad(loss, x))

            # store in the vectors
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))
            self.gradient_mags = torch.cat(
                (self.gradient_mags,
                 sum(grad.view(-1) @ grad.view(-1)).cpu().sqrt().detach()
                 .view(-1)))