Exemple #1
0
    def extract(self, im: np.ndarray, debug_save_name=None):
        """Run the network on ``im`` and return the requested output layers.

        Side effects: stores the raw features of ``self.iounet_feature_layers``
        in ``self.iounet_backbone_features`` and the result of
        ``self.iou_predictor.get_iou_feat`` in ``self.iounet_features``, both
        converted to numpy.

        args:
            im: Input image array. It is scaled by 1/255 and normalized with
                ``self.mean`` / ``self.std`` without modifying the caller's array.
            debug_save_name: Optional path. When given, the unprocessed input
                is written there with ``np.savez``.
        returns:
            TensorList holding one numpy array per entry of ``self.output_layers``.
        """
        with fluid.dygraph.guard():
            if debug_save_name is not None:
                np.savez(debug_save_name, im)

            # The out-of-place division allocates a new array, so the in-place
            # normalization below cannot alter the caller's input.
            normalized = im / 255.
            normalized -= self.mean
            normalized /= self.std
            tensor_in = n2p(normalized)

            features = self.net.extract_features(tensor_in, self.feature_layers)

            # Raw backbone features that are the input of the IoU-net.
            backbone_feats = TensorList(
                [features[name] for name in self.iounet_feature_layers])
            self.iounet_backbone_features = backbone_feats.numpy()

            # IoU-net features taken just before pooling.
            iou_feats = self.iou_predictor.get_iou_feat(backbone_feats)
            self.iounet_features = TensorList(
                [feat.numpy() for feat in iou_feats])

            return TensorList(
                [features[name].numpy() for name in self.output_layers])
Exemple #2
0
    def extract(self, im: np.ndarray, debug_save_name=None):
        """Run the network on ``im`` and return the requested output layers.

        Side effect: the features of ``self.estimator_feature_layers`` are
        stored, converted to numpy, in ``self.estimator_backbone_features``.

        args:
            im: Input image array. It is scaled by 1/255 and normalized with
                ``self.mean`` / ``self.std`` without modifying the caller's array.
            debug_save_name: Optional path. When given, the unprocessed input
                is written there with ``np.savez``.
        returns:
            TensorList holding one numpy array per entry of ``self.output_layers``.
        """
        with fluid.dygraph.guard():
            if debug_save_name is not None:
                np.savez(debug_save_name, im)

            # The out-of-place division allocates a new array, so the caller's
            # input is left untouched by the in-place steps that follow.
            normalized = im / 255.
            normalized -= self.mean
            normalized /= self.std
            tensor_in = n2p(normalized)

            features = self.net.extract_features(tensor_in, self.feature_layers)

            # Raw backbone features consumed by the estimator.
            estimator_feats = TensorList(
                [features[name] for name in self.estimator_feature_layers])
            self.estimator_backbone_features = estimator_feats.numpy()

            return TensorList(
                [features[name].numpy() for name in self.output_layers])
Exemple #3
0
    def run_newton_iter(self, num_cg_iter):
        """Perform one Newton step whose linear system is solved with CG.

        ``self.problem`` is evaluated at ``self.x`` with gradient tracking.
        Its gradient, kept differentiable, is stored in ``self.g``, and the
        negated detached gradient becomes the CG right-hand side ``self.b``.
        The CG solution is added to ``self.x`` in place.

        args:
            num_cg_iter: Number of conjugate-gradient iterations to run.
        """
        self.x.requires_grad_(True)

        # Objective value at the current estimate.
        self.f0 = self.problem(self.x)

        track_loss = self.debug and not self.analyze_convergence
        if track_loss:
            current_loss = self.f0.detach().cpu().view(-1)
            self.losses = torch.cat((self.losses, current_loss))

        # create_graph=True keeps the gradient itself differentiable, so it can
        # be differentiated again to form second-order products.
        first_order = torch.autograd.grad(self.f0, self.x, create_graph=True)
        self.g = TensorList(first_order)

        # Right-hand side of the Newton system.
        self.b = -self.g.detach()

        step, cg_residuals = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Cut the graph before updating the estimate in place.
        self.x.detach_()
        self.x += step

        if self.debug:
            self.residuals = torch.cat((self.residuals, cg_residuals))
Exemple #4
0
    def run_GN_iter(self, num_cg_iter):
        """Runs a single GN iteration.

        The residuals ``self.f0`` are evaluated at ``self.x``. A detached
        copy ``self.g`` of them is used as the vector in ``df/dx^T @ g``
        (kept differentiable with respect to ``g``). The negated product is
        the CG right-hand side, and the CG solution is added to ``self.x``
        in place.

        args:
            num_cg_iter: Number of conjugate-gradient iterations to run.
        """
        self.x.requires_grad_(True)

        # Residuals at the current estimate.
        self.f0 = self.problem(self.x)

        # Graph-free copy of the residuals. It becomes a grad-requiring leaf
        # below, so the product df/dx^T @ g stays differentiable w.r.t. it.
        self.g = self.f0.detach()

        if self.debug and not self.analyze_convergence:
            loss_value = self.problem.ip_output(self.g, self.g)
            self.losses = torch.cat(
                (self.losses, loss_value.detach().cpu().view(-1)))

        self.g.requires_grad_(True)

        # df/dx^t @ f0 as a vector-Jacobian product, keeping the graph.
        vjp = torch.autograd.grad(self.f0, self.x, self.g, create_graph=True)
        self.dfdxt_g = TensorList(vjp)

        # Right-hand side of the normal equations.
        self.b = -self.dfdxt_g.detach()

        step, cg_residuals = self.run_CG(num_cg_iter, eps=self.cg_eps)

        self.x.detach_()
        self.x += step

        if self.debug:
            self.residuals = torch.cat((self.residuals, cg_residuals))
Exemple #5
0
 def size(self, im_sz):
     """Return the output feature size of every layer for input size ``im_sz``.

     Each size is ``floordiv(im_sz, s)``, where ``s`` is the matching entry
     of ``self.stride()``. When ``self.output_size`` is set, each of its
     non-``None`` entries ``sz`` replaces that value with
     ``np.array([sz[0], sz[1]])``.

     returns:
         TensorList with one size per output layer.
     """
     strides = self.stride()
     if self.output_size is None:
         return TensorList([floordiv(im_sz, s) for s in strides])
     # Accept every im_sz type that floordiv handles (as the branch above
     # does), not only PTensor, so a non-PTensor input no longer falls off
     # the end of the function and yields an implicit None.
     return TensorList([
         floordiv(im_sz, s) if sz is None else np.array([sz[0], sz[1]])
         for sz, s in zip(self.output_size, strides)
     ])
Exemple #6
0
    def run(self, num_cg_iter, joint_var):
        """Run the optimizer with the provided number of CG iterations.

        The residuals are linearized once at ``joint_var`` (stored as
        ``self.x``). The linear system is solved with ``num_cg_iter`` CG
        iterations and the step is added to ``self.x`` in place. The method
        returns immediately when ``num_cg_iter`` is 0.

        In debug mode the loss before and after the step and the CG residuals
        are appended to ``self.losses`` / ``self.residuals``. They are shown
        through visdom if available, otherwise plotted when ``self.plotting``
        is set.
        """
        if num_cg_iter == 0:
            return

        self.x = joint_var
        # Loss before ([0]) and after ([-1]) the step; debug mode only.
        lossvec = torch.zeros(2) if self.debug else None

        self.x.requires_grad_(True)
        self.f0 = self.problem(self.x)

        # Graph-free copy of the residuals that becomes a leaf variable below.
        self.g = self.f0.detach()
        if self.debug:
            lossvec[0] = self.problem.ip_output(self.g, self.g)
        self.g.requires_grad_(True)

        # df/dx^t @ f0, kept differentiable with respect to g.
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

        # Right-hand side of the normal equations.
        self.b = -self.dfdxt_g.detach()

        step, cg_residuals = self.run_CG(num_cg_iter, eps=self.cg_eps)

        self.x.detach_()
        self.x += step

        if self.debug:
            self.f0 = self.problem(self.x)
            lossvec[-1] = self.problem.ip_output(self.f0, self.f0)
            self.residuals = torch.cat((self.residuals, cg_residuals))
            self.losses = torch.cat((self.losses, lossvec))
            if self.visdom is not None:
                self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
                self.visdom.register(self.residuals, 'lineplot', 3,
                                     'CG residuals')
            elif self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals, self.fig_num[1],
                           title='CG residuals')

        self.x.detach_()
        self.clear_temp()
Exemple #7
0
 def get_attribute(self, name: str, ignore_missing: bool=False):
     """Collect attribute ``name`` from every feature selected by ``_return_feature``.

     With ``ignore_missing`` set, features that lack the attribute are
     skipped. Otherwise they contribute ``None``.

     returns:
         TensorList of the collected attribute values.
     """
     selected = [feat for feat in self.features if self._return_feature(feat)]
     if not ignore_missing:
         return TensorList([getattr(feat, name, None) for feat in selected])
     return TensorList(
         [getattr(feat, name) for feat in selected if hasattr(feat, name)])
Exemple #8
0
    def evaluate_CG_iteration(self, delta_x):
        """Record the loss and gradient norm at ``self.x + delta_x``.

        This does nothing unless ``self.analyze_convergence`` is set. The loss
        is appended to ``self.losses`` and the square root of the summed
        squared gradient entries to ``self.gradient_mags``.
        """
        if not self.analyze_convergence:
            return

        # Evaluate at a detached copy so the current estimate is not touched.
        probe = (self.x + delta_x).detach()
        probe.requires_grad_(True)

        loss = self.problem(probe)
        grad = TensorList(torch.autograd.grad(loss, probe))

        grad_sq = sum(grad.view(-1) @ grad.view(-1))
        grad_norm = grad_sq.cpu().sqrt().detach().view(-1)

        self.losses = torch.cat((self.losses, loss.detach().cpu().view(-1)))
        self.gradient_mags = torch.cat((self.gradient_mags, grad_norm))
Exemple #9
0
    def run(self, num_iter, dummy=None):
        """Run ``num_iter`` steps of gradient descent with momentum.

        Each step evaluates ``self.problem`` at ``self.x``, takes the gradient
        of ``ip_output(f0, f0)``, and sets the direction to
        ``grad + self.momentum * previous_direction``. It then moves ``self.x``
        in place by ``-self.step_legnth * direction``.

        In debug mode the loss and gradient norm of every step, plus those at
        the final estimate, are appended to ``self.losses`` /
        ``self.gradient_mags``. They are shown through visdom if available,
        otherwise plotted when ``self.plotting`` is set.

        args:
            num_iter: Number of descent steps. 0 returns immediately.
            dummy: Not used by this method (presumably kept for call
                compatibility with other optimizers — TODO confirm).
        """
        if num_iter == 0:
            return

        lossvec = None
        grad_mags = None
        if self.debug:
            # One slot per step plus one for the final estimate.
            lossvec = torch.zeros(num_iter + 1)
            grad_mags = torch.zeros(num_iter + 1)

        for i in range(num_iter):
            self.x.requires_grad_(True)

            # Evaluate function at current estimate
            self.f0 = self.problem(self.x)
            loss = self.problem.ip_output(self.f0, self.f0)
            grad = TensorList(torch.autograd.grad(loss, self.x))

            # Momentum update of the search direction.
            if self.dir is None:
                self.dir = grad
            else:
                self.dir = grad + self.momentum * self.dir

            self.x.detach_()
            # NOTE: the attribute is spelled ``step_legnth``; presumably it is
            # set under that exact name elsewhere, so it is left as is.
            self.x -= self.step_legnth * self.dir

            if self.debug:
                lossvec[i] = loss.item()
                grad_mags[i] = sum(grad.view(-1) @ grad.view(-1)).sqrt().item()

        if self.debug:
            # Loss and gradient at the final estimate.
            self.x.requires_grad_(True)
            self.f0 = self.problem(self.x)
            loss = self.problem.ip_output(self.f0, self.f0)
            grad = TensorList(torch.autograd.grad(loss, self.x))
            # Reuse the loss just computed instead of evaluating ip_output again.
            lossvec[-1] = loss.item()
            grad_mags[-1] = sum(
                grad.view(-1) @ grad.view(-1)).cpu().sqrt().item()
            self.losses = torch.cat((self.losses, lossvec))
            self.gradient_mags = torch.cat((self.gradient_mags, grad_mags))

            if self.visdom is not None:
                self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
                self.visdom.register(self.gradient_mags, 'lineplot', 4,
                                     'Gradient magnitude')
            elif self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.gradient_mags,
                           self.fig_num[1],
                           title='Gradient magnitude')

        self.x.detach_()
        self.clear_temp()
Exemple #10
0
 def A(self, x):
     """Apply ``df/dx^T (df/dx x)`` to ``x`` with two autograd passes.

     ``self.dfdxt_g`` is ``df/dx^T @ g``, which is linear in ``self.g``.
     Differentiating it w.r.t. ``g`` along ``x`` therefore gives ``df/dx @ x``.
     Back-propagating that through ``self.f0`` gives ``df/dx^T df/dx x``.
     Graphs are retained so the product can be applied repeatedly.
     """
     jac_x = torch.autograd.grad(self.dfdxt_g, self.g, x, retain_graph=True)
     product = torch.autograd.grad(self.f0, self.x, jac_x, retain_graph=True)
     return TensorList(product)
Exemple #11
0
 def get_fparams(self, name: str=None):
     """Return the feature parameters of every feature selected by ``_return_feature``.

     Without ``name``, the ``fparams`` objects themselves are returned as a
     plain list. With ``name``, the attribute ``name`` of each ``fparams`` is
     collected into an unrolled TensorList.
     """
     active = [feat for feat in self.features if self._return_feature(feat)]
     if name is None:
         return [feat.fparams for feat in active]
     return TensorList([getattr(feat.fparams, name) for feat in active]).unroll()
Exemple #12
0
    def extract_transformed(self,
                            im,
                            pos,
                            scale,
                            image_sz,
                            transforms,
                            debug_save_name=None):
        """Extract features from a set of transformed image samples.
        args:
            im: Image.
            pos: Center position for extraction.
            scale: Image scale to extract features from.
            image_sz: Size to resize the image samples to before extraction.
            transforms: A set of image transforms to apply.
            debug_save_name: Optional path. When given, the stacked
                transformed patches are written there with ``np.save``.
        returns:
            Unrolled TensorList of the ``get_feature`` outputs of every entry
            in ``self.features``.
        """
        # A single patch of size scale * image_sz, resampled to image_sz.
        patch = sample_patch(im, pos, scale * image_sz, image_sz)

        # One transformed copy per transform, stacked along a new axis 0.
        with fluid.dygraph.guard(fluid.CPUPlace()):
            patches = np.stack([transform(patch) for transform in transforms])

        if debug_save_name is not None:
            np.save(debug_save_name, patches)

        # Move the last axis to position 1 (presumably HWC -> CHW per sample).
        patches = np.transpose(patches, (0, 3, 1, 2))

        per_feature = [feature.get_feature(patches) for feature in self.features]
        return TensorList(per_feature).unroll()
Exemple #13
0
    def forward(self,
                meta_parameter: TensorList,
                feat,
                label,
                sample_weight=None):
        """Compute the data and regularization residuals for a filter.

        args:
            meta_parameter: TensorList whose first entry is the filter.
            feat: Features the filter is applied to. A 5-D ``feat`` has its
                second dimension treated as the number of sequences.
            label: Target scores, reshaped to the shape of the filter output.
            sample_weight: ``None`` (uniform weight ``sqrt(1 / feat.shape[0])``),
                a tensor with as many elements as the scores, a 1-D tensor
                (broadcast as ``(-1, 1, 1, 1, 1)``), or anything else, which
                is used as-is.
        returns:
            TensorList ``[sample_weight * (scores - label), filter_reg * filter]``,
            with the regularization term reshaped to ``(1, num_sequences, -1)``.
        """
        # Assumes multiple filters, i.e.  (sequences, filters, feat_dim, fH, fW)
        weights = meta_parameter[0]

        n_images = feat.shape[0]
        n_sequences = 1 if feat.dim() != 5 else feat.shape[1]

        scores = filter_layer.apply_filter(
            feat, weights, dilation_factors=self.filter_dilation_factors)

        if sample_weight is None:
            # Uniform weighting (presumably so squared residuals average over images).
            sample_weight = math.sqrt(1.0 / n_images)
        elif isinstance(sample_weight, torch.Tensor):
            if sample_weight.numel() == scores.numel():
                sample_weight = sample_weight.view(scores.shape)
            elif sample_weight.dim() == 1:
                sample_weight = sample_weight.view(-1, 1, 1, 1, 1)

        data_residual = sample_weight * (scores - label.view(scores.shape))

        # Regularization residual, with the batch in the second dimension.
        reg_residual = self.filter_reg * weights.view(1, n_sequences, -1)

        return TensorList([data_residual, reg_residual])
Exemple #14
0
    def extract(self, im, pos, scales, image_sz, debug_save_name=None):
        """Extract features.
        args:
            im: Image.
            pos: Center position for extraction.
            scales: Image scales to extract features from (a single int or
                float is treated as a one-element list).
            image_sz: Size to resize the image samples to before extraction.
            debug_save_name: Optional path. When given, the stacked patches
                are written there with ``np.save``.
        returns:
            Unrolled TensorList of the ``get_feature`` outputs of every entry
            in ``self.features``.
        """
        scale_list = [scales] if isinstance(scales, (int, float)) else scales

        # One patch per scale, stacked along a new axis 0.
        with fluid.dygraph.guard(fluid.CPUPlace()):
            patches = np.stack([
                sample_patch(im, pos, scale * image_sz, image_sz)
                for scale in scale_list
            ])

        if debug_save_name is not None:
            np.save(debug_save_name, patches)

        # Move the last axis to position 1 (presumably HWC -> CHW per sample).
        patches = np.transpose(patches, (0, 3, 1, 2))

        per_feature = [feature.get_feature(patches) for feature in self.features]
        return TensorList(per_feature).unroll()
Exemple #15
0
def ltr_collate_stack1(batch):
    """Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch

    Tensors and numpy arrays are stacked along dim 1. Numpy scalars, ints
    and floats become 1-D tensors, and strings and ``None`` are returned as
    the list itself. TensorDicts, mappings, TensorLists and other sequences
    are collated recursively per key / position.

    raises:
        TypeError: for numpy arrays of strings/objects or unsupported types.
    """
    # collections.Mapping/Sequence were removed in Python 3.10, and
    # torch.utils.data.dataloader no longer exposes ``re`` or
    # ``numpy_type_map``, so use the standard modules directly.
    import re
    from collections import abc as container_abcs

    import numpy as np

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if _check_use_shared_memory():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 1, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # Reject arrays of string classes and objects.
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 1)
        if elem.shape == ():  # scalars
            # Build the tensor with the first element's dtype.
            return torch.from_numpy(np.asarray(batch, dtype=elem.dtype))
    elif isinstance(elem, int_classes):
        return torch.LongTensor(batch)
    elif isinstance(elem, float):
        return torch.DoubleTensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, TensorDict):
        return TensorDict({
            key: ltr_collate_stack1([d[key] for d in batch])
            for key in elem
        })
    elif isinstance(elem, container_abcs.Mapping):
        return {
            key: ltr_collate_stack1([d[key] for d in batch])
            for key in elem
        }
    elif isinstance(elem, TensorList):
        transposed = zip(*batch)
        return TensorList(
            [ltr_collate_stack1(samples) for samples in transposed])
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [ltr_collate_stack1(samples) for samples in transposed]
    elif elem is None:
        return batch

    raise TypeError((error_msg.format(type(elem))))
Exemple #16
0
    def __init__(self, *args, **kwargs):
        """Store every keyword argument as an attribute of the same name.

        List values are wrapped in ``TensorList``. Any other value is stored
        unchanged.

        raises:
            ValueError: if any positional argument is given.
        """
        if len(args) > 0:
            raise ValueError(
                'only keyword arguments are accepted, got {} positional '
                'argument(s)'.format(len(args)))

        for name, val in kwargs.items():
            if isinstance(val, list):
                setattr(self, name, TensorList(val))
            else:
                setattr(self, name, val)
Exemple #17
0
    def get_inputs(self, scope=''):
        """Return, creating on first use, the input variables for ``scope``.

        Three variable lists are built with ``create_var_list`` and cached in
        ``self.inputs_dict[scope]``:
        - training samples, from ``self.training_samples`` with shape argument ``[None]``;
        - labels, from ``self.y`` with ``[None]``;
        - sample weights, from ``self.sample_weights`` with ``[None, 1]``.

        returns:
            ``(training_samples, y, sample_weights)`` as TensorLists.
        """
        if scope not in self.inputs_dict:
            # Template each variable list on its own data. The training-sample
            # variables were previously built from ``self.sample_weights``,
            # which is a copy-paste slip (assumes ``self.training_samples`` is
            # set by the constructor — TODO confirm).
            name = scope + "training_samples"
            vars = create_var_list(name, self.training_samples, [None])
            training_samples_p = TensorList(vars)

            name = scope + "y"
            vars = create_var_list(name, self.y, [None])
            y_p = TensorList(vars)

            name = scope + "sample_weights"
            vars = create_var_list(name, self.sample_weights, [None, 1])
            sample_weights_p = TensorList(vars)

            self.inputs_dict[scope] = (training_samples_p, y_p,
                                       sample_weights_p)

        return self.inputs_dict[scope]
Exemple #18
0
    def extract(self, im: np.ndarray, debug_save_name=None):
        """Run ``self.net.extract_backbone_features`` on ``im`` and return numpy outputs.

        Unlike the normalizing variants, ``im`` is passed through ``n2p``
        unchanged.

        args:
            im: Input array.
            debug_save_name: Optional path. When given, the input is written
                there with ``np.savez``.
        returns:
            TensorList with one numpy array per backbone output.
        """
        with fluid.dygraph.guard():
            if debug_save_name is not None:
                np.savez(debug_save_name, im)

            backbone_out = self.net.extract_backbone_features(n2p(im))
            return TensorList([feat.numpy() for feat in backbone_out])
Exemple #19
0
    def get_feature(self, im: np.ndarray):
        """Get the feature. Generally, call this function.

        Runs ``self.extract`` on ``im`` and then post-processes each returned
        feature map:
        - With a non-``None`` entry of ``self.output_size``, the map is
          resized with adaptive average pooling to that size.
        - Otherwise, when the matching ``self.pool_stride`` is not 1, the map
          is average-pooled with that window and stride.
        - If ``self.normalize_power`` is set, each map is then normalized.

        args:
            im: image patch. ``im.shape[1] == 3`` is treated as a colour input
                (presumably an NCHW batch — TODO confirm).
        returns:
            TensorList of numpy arrays, one per feature map. An empty
            ``np.array([])`` is returned instead when this feature is disabled
            for the input's colour mode.
        """

        # Return empty tensor if it should not be used, i.e. colour input with
        # use_for_color off, or non-colour input with use_for_gray off
        # (``and`` binds tighter than ``or``).
        is_color = im.shape[1] == 3
        if is_color and not self.use_for_color or not is_color and not self.use_for_gray:
            return np.array([])

        feat_list = self.extract(im)

        # One target size per feature map. None means "no fixed output size".
        output_sz = [None] * len(
            feat_list) if self.output_size is None else self.output_size

        # Pool/downsample
        with fluid.dygraph.guard():
            feat_list = [n2p(f) for f in feat_list]

            for i, (sz, s) in enumerate(zip(output_sz, self.pool_stride)):
                if sz is not None:
                    # A fixed output size takes precedence over the stride.
                    feat_list[i] = layers.adaptive_pool2d(feat_list[i],
                                                          sz,
                                                          pool_type='avg')
                elif s != 1:
                    feat_list[i] = layers.pool2d(feat_list[i],
                                                 s,
                                                 pool_stride=s,
                                                 pool_type='avg')

            # Normalize
            # Each map is divided by
            #   (sum(|feat| ** p) / (feat.shape[1] * feat.shape[2] * feat.shape[3]) + 1e-10) ** (1 / p)
            # with p = self.normalize_power. The sum runs over everything
            # except axis 0 (flattened by the reshape), giving one value per
            # axis-0 entry that broadcast_op divides by.
            if self.normalize_power is not None:
                new_feat_list = []
                for feat in feat_list:
                    norm = (layers.reduce_sum(layers.reshape(
                        layers.abs(feat), [feat.shape[0], 1, 1, -1])**
                                              self.normalize_power,
                                              dim=3,
                                              keep_dim=True) /
                            (feat.shape[1] * feat.shape[2] * feat.shape[3]) +
                            1e-10)**(1 / self.normalize_power)
                    feat = broadcast_op(feat, norm, 'div')
                    new_feat_list.append(feat)
                feat_list = new_feat_list

            # To numpy
            feat_list = TensorList([f.numpy() for f in feat_list])
        return feat_list
Exemple #20
0
    def extract(self, im, pos, scales, image_sz):
        """Extract features at one or more scales and concatenate them along axis 1.

        args:
            im: Image.
            pos: Center position for extraction.
            scales: Image scale(s). A single int or float is treated as a
                one-element list.
            image_sz: Size the patches are resampled to.
        returns:
            ``layers.concat`` along axis 1 of the unrolled ``get_feature``
            outputs of every entry in ``self.features``.
        """
        scale_list = [scales] if isinstance(scales, (int, float)) else scales

        patches = np.stack(
            [sample_patch(im, pos, scale * image_sz, image_sz) for scale in scale_list])
        # Move the last axis to position 1 (presumably HWC -> CHW per sample).
        patches = np.transpose(patches, (0, 3, 1, 2))

        unrolled = TensorList(
            [feature.get_feature(patches) for feature in self.features]).unroll()
        return layers.concat(unrolled, axis=1)
Exemple #21
0
    def __call__(self, x: TensorList, scope=''):
        """
        Compute residuals
        :param x: [filters]
        :param scope: key forwarded to ``get_inputs`` to select the input variables
        :return: [data_terms, filter_regularizations]
        """
        samples, labels, weights = self.get_inputs(scope)

        # Same-padded filter responses passed through the response activation,
        # minus the labels and scaled by the square root of the sample weights.
        responses = operation.conv2d(samples, x, mode='same').apply(
            self.response_activation)
        residuals = (responses - labels) * weights.sqrt()

        # Append the filter regularization terms: x * sqrt(filter_reg).
        residuals.extend(
            x.apply(static_identity) * self.filter_reg.apply(math.sqrt))

        return residuals
Exemple #22
0
    def get_filter(self,
                   feat,
                   train_label,
                   train_sw,
                   num_objects=None,
                   *args,
                   **kwargs):
        """
        Get the initial target model parameters given the few-shot labels,
        then refine them with ``self.filter_optimizer`` if one is configured.

        :param feat: training features, e.g. [1, 5, 512, 30, 52]
        :param train_label: few-shot labels, e.g. [1, 5, 16, 30, 52]
        :param train_sw: sample weights forwarded to the filter optimizer
        :param num_objects: when given, the initial weights are repeated
            ``num_objects`` times along dim 1
        :return: ``(weights, weights_iter, losses)``. Without an optimizer,
            ``weights_iter`` is ``[weights]`` and ``losses`` is ``None``.
            Otherwise both are taken from the optimizer's result.
        """
        # Both branches previously called the initializer identically.
        weights = self.filter_initializer(feat, train_label)
        if num_objects is not None:
            weights = weights.repeat(1, num_objects, 1, 1, 1)

        if self.filter_optimizer is None:
            return weights, [weights], None

        weights, weights_iter, losses = self.filter_optimizer(
            TensorList([weights]),
            feat=feat,
            label=train_label,
            sample_weight=train_sw,
            *args,
            **kwargs)
        # The optimizer works on TensorLists; unwrap the single filter entry.
        return weights[0], [w[0] for w in weights_iter], losses
Exemple #23
0
class ConjugateGradient_Attn(ConjugateGradientBase):
    """Conjugate Gradient optimizer, performing single linearization of the residuals in the start.

    ``run`` linearizes ``problem`` once at the given variable and solves the
    resulting linear system with CG. ``A`` supplies the matrix-vector product
    ``df/dx^T (df/dx x)`` and ``ip``/``M1``/``M2`` delegate to the problem.
    """

    def __init__(self,
                 problem: L2Problem,
                 variable: TensorList,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 debug=False,
                 plotting=False,
                 visdom=None):
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or plotting)

        self.problem = problem
        self.x = variable

        # Debug output: visdom, if given, takes precedence over plotting.
        self.plotting = plotting
        self.fig_num = (10, 11)
        self.visdom = visdom

        self.cg_eps = cg_eps
        # Temporaries that only live during ``run``.
        self.f0 = self.g = self.dfdxt_g = None

        # Histories appended to in debug mode.
        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)

    def clear_temp(self):
        """Drop the references to the linearization temporaries."""
        self.f0 = self.g = self.dfdxt_g = None

    def run(self, num_cg_iter, joint_var):
        """Run the optimizer with the provided number of CG iterations.

        The residuals are linearized at ``joint_var`` (stored as ``self.x``).
        The system is solved with ``num_cg_iter`` CG iterations and the step
        is added to ``self.x`` in place. The method returns immediately when
        ``num_cg_iter`` is 0.
        """
        if num_cg_iter == 0:
            return

        self.x = joint_var
        # Loss before ([0]) and after ([-1]) the step; debug mode only.
        lossvec = torch.zeros(2) if self.debug else None

        self.x.requires_grad_(True)
        self.f0 = self.problem(self.x)

        # Graph-free copy of the residuals that becomes a leaf variable below.
        self.g = self.f0.detach()
        if self.debug:
            lossvec[0] = self.problem.ip_output(self.g, self.g)
        self.g.requires_grad_(True)

        # df/dx^t @ f0, kept differentiable with respect to g.
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

        # Right-hand side of the normal equations.
        self.b = -self.dfdxt_g.detach()

        step, cg_residuals = self.run_CG(num_cg_iter, eps=self.cg_eps)

        self.x.detach_()
        self.x += step

        if self.debug:
            self.f0 = self.problem(self.x)
            lossvec[-1] = self.problem.ip_output(self.f0, self.f0)
            self.residuals = torch.cat((self.residuals, cg_residuals))
            self.losses = torch.cat((self.losses, lossvec))
            if self.visdom is not None:
                self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
                self.visdom.register(self.residuals, 'lineplot', 3,
                                     'CG residuals')
            elif self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals, self.fig_num[1],
                           title='CG residuals')

        self.x.detach_()
        self.clear_temp()

    def A(self, x):
        """Apply ``df/dx^T (df/dx x)`` to ``x`` with two autograd passes.

        ``self.dfdxt_g`` is linear in ``self.g``, so differentiating it w.r.t.
        ``g`` along ``x`` gives ``df/dx @ x``. Back-propagating that through
        ``self.f0`` gives the final product.
        """
        jac_x = torch.autograd.grad(self.dfdxt_g, self.g, x,
                                    retain_graph=True)
        product = torch.autograd.grad(self.f0, self.x, jac_x,
                                      retain_graph=True)
        return TensorList(product)

    def ip(self, a, b):
        """Inner product, delegated to ``self.problem.ip_input``."""
        return self.problem.ip_input(a, b)

    def M1(self, x):
        """Delegate ``M1`` to the problem."""
        return self.problem.M1(x)

    def M2(self, x):
        """Delegate ``M2`` to the problem."""
        return self.problem.M2(x)
Exemple #24
0
 def stride(self):
     """Effective stride of each output layer: pool stride times layer stride."""
     pairs = zip(self.output_layers, self.pool_stride)
     effective = [pool * self.layer_stride[layer] for layer, pool in pairs]
     return TensorList(effective)
Exemple #25
0
 def dim(self):
     """Return the ``self.layer_dim`` entry of each output layer as a TensorList."""
     dims = [self.layer_dim[layer] for layer in self.output_layers]
     return TensorList(dims)
Exemple #26
0
 def ip_input(self, a: TensorList, b: TensorList):
     return a.reshape(-1) @ b.reshape(-1)
Exemple #27
0
 def A(self, x):
     """Return ``d(self.g)/d(self.x)`` applied to ``x`` plus ``self.hessian_reg * x``.

     Assuming ``self.g`` holds a gradient created with ``create_graph=True``,
     this is a regularized Hessian-vector product. The graph is retained so
     the product can be applied repeatedly.
     """
     hess_x = TensorList(
         torch.autograd.grad(self.g, self.x, x, retain_graph=True))
     return hess_x + self.hessian_reg * x
Exemple #28
0
    def forward(self, train_imgs, test_imgs, train_masks, test_masks, num_refinement_iter=2):
        """Predict a mask for every test frame, refining the target model online.

        The target model is first learned from the training frames and masks.
        Each test frame is then segmented with the current target model and
        its predicted mask is added to the training set. Except after the last
        frame, the target model is refined with ``num_refinement_iter``
        optimizer iterations, and the refined model is used for the next frame.

        args:
            train_imgs: Training images. Dim 0 indexes frames and dim 1
                sequences (see the ``shape`` reads below).
            test_imgs: Test images, in the same layout as ``train_imgs``.
            train_masks: Masks of the training frames, fed to ``label_encoder``.
            test_masks: Not used by this method.
            num_refinement_iter: Optimizer iterations per online update.
                0 disables the update.
        returns:
            Per-frame predictions, each viewed as
            ``(1, num_sequences, H, W)``, concatenated along dim 0.
        """
        num_sequences = train_imgs.shape[1]
        num_train_frames = train_imgs.shape[0]
        num_test_frames = test_imgs.shape[0]

        # Extract backbone features, with frames and sequences flattened into one batch dim.
        train_feat_backbone = self.extract_backbone_features(
            train_imgs.view(-1, *train_imgs.shape[-3:]))
        test_feat_backbone = self.extract_backbone_features(
            test_imgs.contiguous().view(-1, *test_imgs.shape[-3:]))

        # Features input to the target model: seq*frames, channels, height, width
        train_feat_tm = self.extract_target_model_features(train_feat_backbone)
        train_feat_tm = train_feat_tm.view(num_train_frames, num_sequences, *train_feat_tm.shape[-3:])
        train_feat_tm_all = [train_feat_tm]

        # Few-shot learner label and spatial importance weights (from train_masks only).
        few_shot_label, few_shot_sw = self.label_encoder(train_masks)
        few_shot_label_all = [few_shot_label]
        few_shot_sw_all = None if few_shot_sw is None else [few_shot_sw]

        test_feat_tm = self.extract_target_model_features(test_feat_backbone)

        # Reshape the test features once to (frames, sequences, ...) instead of once per frame.
        test_feat_tm = test_feat_tm.view(num_test_frames, num_sequences, *test_feat_tm.shape[-3:])
        test_feat_backbone_seq = {k: v.view(num_test_frames, num_sequences, *v.shape[-3:])
                                  for k, v in test_feat_backbone.items()}

        # Obtain the target module parameters using the few-shot learner
        target_filter, filter_iter, _ = self.target_model.get_filter(train_feat_tm, few_shot_label, few_shot_sw)

        # Only the final optimizer iterate was ever used for prediction, so
        # only that one is applied. After each online update it is replaced by
        # the refined filter; previously the refined filter was never used for
        # prediction, which made the update ineffective.
        current_filter = filter_iter[-1]

        mask_predictions_all = []

        # Iterate over the test sequence
        for i in range(num_test_frames):
            # Features for the current frame
            test_feat_tm_it = test_feat_tm[i:i + 1, ...]

            # Apply the target model to obtain the mask encoding.
            mask_encoding_pred = self.target_model.apply_target_model(current_filter, test_feat_tm_it)

            test_feat_backbone_it = {k: v[i, ...] for k, v in test_feat_backbone_seq.items()}

            # Run decoder to obtain the segmentation mask
            mask_pred, _ = self.decoder(mask_encoding_pred, test_feat_backbone_it, test_imgs.shape[-2:])
            mask_pred = mask_pred.view(1, num_sequences, *mask_pred.shape[-2:])
            mask_predictions_all.append(mask_pred)

            # Convert the (gradient-detached) segmentation scores to probabilities
            mask_pred_prob = torch.sigmoid(mask_pred.detach())

            # Label encoding of the predicted mask for this frame
            few_shot_label, few_shot_sw = self.label_encoder(mask_pred_prob)

            # Extend the training data using the predicted mask
            few_shot_label_all.append(few_shot_label)
            if few_shot_sw_all is not None:
                few_shot_sw_all.append(few_shot_sw)
            train_feat_tm_all.append(test_feat_tm_it)

            # Update the target model using the extended training set
            if i < num_test_frames - 1 and num_refinement_iter > 0:
                train_feat_tm_it = torch.cat(train_feat_tm_all, dim=0)
                few_shot_label_it = torch.cat(few_shot_label_all, dim=0)
                few_shot_sw_it = None if few_shot_sw_all is None else torch.cat(few_shot_sw_all, dim=0)

                # Run few-shot learner to update the target model
                filter_updated, _, _ = self.target_model.filter_optimizer(TensorList([target_filter]),
                                                                          feat=train_feat_tm_it,
                                                                          label=few_shot_label_it,
                                                                          sample_weight=few_shot_sw_it,
                                                                          num_iter=num_refinement_iter)

                target_filter = filter_updated[0]  # filter_updated is a TensorList
                current_filter = target_filter

        return torch.cat(mask_predictions_all, dim=0)
Exemple #29
0
class NewtonCG(ConjugateGradientBase):
    """Newton's method with Conjugate Gradient. Handles general minimization problems.

    Each Newton iteration evaluates the scalar objective ``problem(x)``, computes its
    gradient ``g`` with ``create_graph=True`` and then approximately solves the damped
    Newton system ``(H + hessian_reg * I) dx = -g`` with the inherited conjugate-gradient
    solver ``run_CG``. Hessian-vector products are obtained by differentiating ``g`` a
    second time (see :meth:`A`), so the Hessian is never formed explicitly.

    args:
        problem: Objective to minimize, called as ``problem(x)``. It must also provide
                 ``ip_input``, ``M1`` and ``M2``, which :meth:`ip`, :meth:`M1` and
                 :meth:`M2` delegate to.
        variable: TensorList of optimization variables; it is updated in place.
        init_hessian_reg: Initial scalar added to the Hessian (the ``hessian_reg * x``
                          term in :meth:`A`).
        hessian_reg_factor: ``hessian_reg`` is multiplied by this after every Newton iteration.
        cg_eps: Passed as ``eps`` to ``run_CG``.
        fletcher_reeves, standard_alpha, direction_forget_factor: Forwarded to the base class.
        debug: Record losses and CG residuals.
        analyze: Additionally record loss and gradient magnitude via
                 :meth:`evaluate_CG_iteration`.
        plotting: Plot the recorded curves at the end of :meth:`run`.
        fig_num: Figure numbers for the loss, CG-residual and gradient-magnitude plots.
    """
    def __init__(self,
                 problem: MinimizationProblem,
                 variable: TensorList,
                 init_hessian_reg=0.0,
                 hessian_reg_factor=1.0,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 debug=False,
                 analyze=False,
                 plotting=False,
                 fig_num=(10, 11, 12)):
        # Analysis and plotting both need the recorded history, so either of them
        # enables the base-class debug flag (presumably exposed as ``self.debug``).
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or analyze or plotting)

        self.problem = problem
        self.x = variable

        self.analyze_convergence = analyze
        self.plotting = plotting
        self.fig_num = fig_num

        self.hessian_reg = init_hessian_reg
        self.hessian_reg_factor = hessian_reg_factor
        self.cg_eps = cg_eps
        # Per-iteration autograd state (loss and its differentiable gradient),
        # dropped again by clear_temp().
        self.f0 = None
        self.g = None

        # Histories; appended to only in debug/analysis mode and accumulated
        # across calls to run().
        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)
        self.gradient_mags = torch.zeros(0)

    def clear_temp(self):
        """Release the loss and gradient kept from the last Newton iteration."""
        self.f0 = None
        self.g = None

    def run(self, num_cg_iter, num_newton_iter=None):
        """Run the optimizer.

        args:
            num_cg_iter: Number of CG iterations per Newton iteration. If a list, each entry
                         specifies the CG iterations of one Newton iteration and the number
                         of Newton iterations is given by the length of the list.
            num_newton_iter: Number of Newton iterations when ``num_cg_iter`` is an int
                             (defaults to 1). Ignored when ``num_cg_iter`` is a list.
        returns:
            Tuple ``(losses, residuals)`` with the recorded loss values and CG residuals.
            This is also returned when there is nothing to do.
        """
        if isinstance(num_cg_iter, int):
            if num_cg_iter == 0:
                # Nothing to optimize; still honour the documented return value.
                return self.losses, self.residuals
            if num_newton_iter is None:
                num_newton_iter = 1
            num_cg_iter = [num_cg_iter] * num_newton_iter

        if len(num_cg_iter) == 0:
            return self.losses, self.residuals

        if self.analyze_convergence:
            # Record the starting point before any update.
            self.evaluate_CG_iteration(0)

        for cg_iter in num_cg_iter:
            self.run_newton_iter(cg_iter)
            # Schedule the Hessian regularization between Newton iterations.
            self.hessian_reg *= self.hessian_reg_factor

        if self.debug:
            if not self.analyze_convergence:
                # Record the loss at the final estimate (analysis mode already did).
                loss = self.problem(self.x)
                self.losses = torch.cat(
                    (self.losses, loss.detach().cpu().view(-1)))

            if self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals,
                           self.fig_num[1],
                           title='CG residuals')
                if self.analyze_convergence:
                    plot_graph(self.gradient_mags, self.fig_num[2],
                               'Gradient magnitude')

        self.x.detach_()
        self.clear_temp()

        return self.losses, self.residuals

    def run_newton_iter(self, num_cg_iter):
        """Run a single Newton iteration using ``num_cg_iter`` CG iterations."""

        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        if self.debug and not self.analyze_convergence:
            self.losses = torch.cat(
                (self.losses, self.f0.detach().cpu().view(-1)))

        # Gradient of loss. create_graph=True keeps it differentiable so that A()
        # can compute Hessian-vector products by differentiating it again.
        self.g = TensorList(
            torch.autograd.grad(self.f0, self.x, create_graph=True))

        # Get the right hand side
        self.b = -self.g.detach()

        # Run CG (inherited from the base class)
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Update the estimate outside of autograd.
        self.x.detach_()
        self.x += delta_x

        if self.debug:
            self.residuals = torch.cat((self.residuals, res))

    def A(self, x):
        """Apply the regularized Hessian to ``x``: ``H x + hessian_reg * x``.

        ``H x`` is the vector-Jacobian product of the gradient ``self.g`` with ``x``.
        The graph is retained because CG calls this repeatedly.
        """
        return TensorList(
            torch.autograd.grad(self.g, self.x, x,
                                retain_graph=True)) + self.hessian_reg * x

    def ip(self, a, b):
        """Inner product in the variable space, delegated to ``problem.ip_input``."""
        return self.problem.ip_input(a, b)

    def M1(self, x):
        """Delegate to ``problem.M1`` (presumably a CG preconditioner hook)."""
        return self.problem.M1(x)

    def M2(self, x):
        """Delegate to ``problem.M2`` (presumably a CG preconditioner hook)."""
        return self.problem.M2(x)

    def evaluate_CG_iteration(self, delta_x):
        """Record loss and gradient norm at ``x + delta_x`` when ``analyze`` is enabled.

        Presumably also invoked by the CG base class during ``run_CG``.
        """
        if self.analyze_convergence:
            x = (self.x + delta_x).detach()
            x.requires_grad_(True)

            # compute loss and gradient
            loss = self.problem(x)
            grad = TensorList(torch.autograd.grad(loss, x))

            # store in the vectors; the gradient magnitude is the square root of
            # the summed per-tensor squared norms.
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))
            self.gradient_mags = torch.cat(
                (self.gradient_mags,
                 sum(grad.view(-1)
                     @ grad.view(-1)).cpu().sqrt().detach().view(-1)))
Exemple #30
0
class GaussNewtonCG(ConjugateGradientBase):
    """Gauss-Newton with Conjugate Gradient optimizer.

    ``problem(x)`` returns residuals ``f`` and the minimized loss is
    ``problem.ip_output(f, f)``. Each Gauss-Newton iteration approximately solves
    ``J^T J dx = -J^T f`` with the inherited conjugate-gradient solver ``run_CG``.
    ``J`` is the Jacobian of ``f`` w.r.t. ``x``. Products with ``J^T J`` are computed
    with autograd (see :meth:`run_GN_iter` and :meth:`A`) without forming ``J``.

    args:
        problem: Residual function, called as ``problem(x)``. It must also provide
                 ``ip_output``, ``ip_input``, ``M1`` and ``M2``.
        variable: TensorList of optimization variables; it is updated in place.
        cg_eps: Passed as ``eps`` to ``run_CG``.
        fletcher_reeves, standard_alpha, direction_forget_factor: Forwarded to the base class.
        debug: Record losses and CG residuals.
        analyze: Additionally record loss and gradient magnitude via
                 :meth:`evaluate_CG_iteration`.
        plotting: Plot the recorded curves at the end of :meth:`run` (when no visdom is given).
        visdom: Optional object with a ``register`` method that receives the curves instead.
    """
    def __init__(self,
                 problem: L2Problem,
                 variable: TensorList,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 debug=False,
                 analyze=False,
                 plotting=False,
                 visdom=None):
        # Analysis and plotting both need the recorded history, so either of them
        # enables the base-class debug flag (presumably exposed as ``self.debug``).
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or analyze or plotting)

        self.problem = problem
        self.x = variable

        self.analyze_convergence = analyze
        self.plotting = plotting
        self.fig_num = (10, 11, 12)
        self.visdom = visdom

        self.cg_eps = cg_eps
        # Per-iteration autograd state, dropped again by clear_temp().
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

        # Histories; appended to only in debug/analysis mode and accumulated
        # across calls to run().
        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)
        self.gradient_mags = torch.zeros(0)

    def clear_temp(self):
        """Release the autograd state kept from the last Gauss-Newton iteration."""
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

    def run_GN(self, *args, **kwargs):
        """Alias of :meth:`run`."""
        return self.run(*args, **kwargs)

    def run(self, num_cg_iter, num_gn_iter=None):
        """Run the optimizer.
        args:
            num_cg_iter: Number of CG iterations per GN iter. If list, then each entry specifies number of CG iterations
                         and number of GN iterations is given by the length of the list.
            num_gn_iter: Number of GN iterations. Shall only be given if num_cg_iter is an integer.
        returns:
            Tuple ``(losses, residuals)`` with the recorded loss values and CG residuals.
            This is also returned when the iteration schedule is empty.
        raises:
            ValueError: if ``num_cg_iter`` is an int and ``num_gn_iter`` is not given.
        """

        if isinstance(num_cg_iter, int):
            if num_gn_iter is None:
                raise ValueError(
                    'Must specify number of GN iter if CG iter is constant')
            num_cg_iter = [num_cg_iter] * num_gn_iter

        if len(num_cg_iter) == 0:
            # Nothing to optimize; still honour the documented return value.
            return self.losses, self.residuals

        if self.analyze_convergence:
            # Record the starting point before any update.
            self.evaluate_CG_iteration(0)

        # Outer loop for running the GN iterations.
        for cg_iter in num_cg_iter:
            self.run_GN_iter(cg_iter)

        if self.debug:
            if not self.analyze_convergence:
                # Record the loss at the final estimate (analysis mode already did).
                self.f0 = self.problem(self.x)
                loss = self.problem.ip_output(self.f0, self.f0)
                self.losses = torch.cat(
                    (self.losses, loss.detach().cpu().view(-1)))

            if self.visdom is not None:
                self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
                self.visdom.register(self.residuals, 'lineplot', 3,
                                     'CG residuals')

                if self.analyze_convergence:
                    self.visdom.register(self.gradient_mags, 'lineplot', 4,
                                         'Gradient magnitude')
            elif self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals,
                           self.fig_num[1],
                           title='CG residuals')
                if self.analyze_convergence:
                    plot_graph(self.gradient_mags, self.fig_num[2],
                               'Gradient magnitude')

        self.x.detach_()
        self.clear_temp()
        return self.losses, self.residuals

    def run_GN_iter(self, num_cg_iter):
        """Runs a single GN iteration."""

        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Create copy with graph detached
        self.g = self.f0.detach()

        if self.debug and not self.analyze_convergence:
            loss = self.problem.ip_output(self.g, self.g)
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))

        # Making g a leaf that requires grad makes J^T g (below) differentiable
        # w.r.t. g. Differentiating it along a vector v then gives J v, which A() uses.
        self.g.requires_grad_(True)

        # Get df/dx^t @ f0
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

        # Get the right hand side
        self.b = -self.dfdxt_g.detach()

        # Run CG (inherited from the base class)
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Update the estimate outside of autograd.
        self.x.detach_()
        self.x += delta_x

        if self.debug:
            self.residuals = torch.cat((self.residuals, res))

    def A(self, x):
        """Apply the Gauss-Newton matrix ``J^T J`` to ``x``.

        First computes ``J x`` by differentiating ``J^T g`` w.r.t. ``g``, then
        ``J^T (J x)`` as a vector-Jacobian product of ``f0``. The graphs are retained
        because CG calls this repeatedly.
        """
        dfdx_x = torch.autograd.grad(self.dfdxt_g,
                                     self.g,
                                     x,
                                     retain_graph=True)
        return TensorList(
            torch.autograd.grad(self.f0, self.x, dfdx_x, retain_graph=True))

    def ip(self, a, b):
        """Inner product in the variable space, delegated to ``problem.ip_input``."""
        return self.problem.ip_input(a, b)

    def M1(self, x):
        """Delegate to ``problem.M1`` (presumably a CG preconditioner hook)."""
        return self.problem.M1(x)

    def M2(self, x):
        """Delegate to ``problem.M2`` (presumably a CG preconditioner hook)."""
        return self.problem.M2(x)

    def evaluate_CG_iteration(self, delta_x):
        """Record loss and gradient norm at ``x + delta_x`` when ``analyze`` is enabled.

        Presumably also invoked by the CG base class during ``run_CG``.
        """
        if self.analyze_convergence:
            x = (self.x + delta_x).detach()
            x.requires_grad_(True)

            # compute loss and gradient
            f = self.problem(x)
            loss = self.problem.ip_output(f, f)
            grad = TensorList(torch.autograd.grad(loss, x))

            # store in the vectors; the gradient magnitude is the square root of
            # the summed per-tensor squared norms.
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))
            self.gradient_mags = torch.cat(
                (self.gradient_mags,
                 sum(grad.view(-1)
                     @ grad.view(-1)).cpu().sqrt().detach().view(-1)))