コード例 #1
0
    def set_network_gradient_op(self, gradients):
        """
        Build the parameter-update op(s) with
        ``self.optimiser.apply_gradients`` and store them in
        ``self.gradient_op``.

        Subclasses can override this to implement more elaborate update
        schemes, for instance a separate optimiser per sub-network.

        :param gradients: processed gradients from the gradient_collector;
            nested to depth 2 (a list over one network's variables) or to
            depth 3 (a list of networks, each a list over its variables)
        :return: None -- the result is assigned to ``self.gradient_op``
        :raises NotImplementedError: if ``gradients`` has any other
            nesting depth
        """
        nesting = util_common.list_depth_count(gradients)
        if nesting == 2:
            # Single network: apply the whole variable list in one call.
            self.gradient_op = self.optimiser.apply_gradients(gradients)
        elif nesting == 3:
            # Several networks: one apply_gradients call per network, so
            # self.gradient_op becomes a list with one entry per network.
            self.gradient_op = [
                self.optimiser.apply_gradients(net_grads)
                for net_grads in gradients
            ]
        else:
            raise NotImplementedError(
                'This app supports updating a network, or a list of networks.')
コード例 #2
0
 def set_network_update_op(self, gradients):
     """
     Create update op(s) via ``self.optimiser.apply_gradients`` and assign
     the result to ``self.gradient_op``.

     NOTE(review): despite the method name, the result is stored in
     ``self.gradient_op``.

     :param gradients: nested gradients list; depth 2 means a list over a
         single network's variables, depth 3 means a list of networks
     :return: None
     :raises NotImplementedError: if the nesting depth is neither 2 nor 3
     """
     depth = util_common.list_depth_count(gradients)
     if depth not in (2, 3):
         raise NotImplementedError(
             'This app supports updating a network, or list of networks')
     if depth == 3:
         # One apply_gradients call per network's gradient list.
         self.gradient_op = [
             self.optimiser.apply_gradients(per_net) for per_net in gradients
         ]
     else:
         # Depth 2: a single network, applied in one call.
         self.gradient_op = self.optimiser.apply_gradients(gradients)
コード例 #3
0
def _apply_gradients(optimiser, gradients):
    """
    Build update op(s) by calling ``optimiser.apply_gradients``.

    :param optimiser: the single optimiser applied to every gradient list
    :param gradients: processed gradients from the gradient_collector,
        nested to depth 2 (a list over one network's variables) or depth 3
        (a list of networks, each a list over its variables)
    :return: for depth 2, whatever ``optimiser.apply_gradients`` returns;
        for depth 3, a list of those results, one per network
    :raises NotImplementedError: if the nesting depth is neither 2 nor 3
    """
    depth = util_common.list_depth_count(gradients)
    if depth == 2:
        # A single network: apply its variable list in one call.
        return optimiser.apply_gradients(gradients)
    if depth != 3:
        raise NotImplementedError(
            'This app supports updating a network, or a list of networks.')
    # Depth 3: apply each network's gradient list separately.
    return [optimiser.apply_gradients(net_grads) for net_grads in gradients]
コード例 #4
0
ファイル: base_application.py プロジェクト: fepegar/NiftyNet
    def set_network_gradient_op(self, gradients):
        """
        Populate ``self.gradient_op`` using ``self.optimiser.apply_gradients``.

        Intended as an override point for more complex optimisation
        set-ups, such as using different optimisers for sub-networks.

        :param gradients: processed gradients from the gradient_collector.
            A nesting depth of 2 is treated as one network's variable list;
            a depth of 3 as a list of such lists, one per network.
        :return: None (the op, or list of ops, is stored on ``self``)
        :raises NotImplementedError: for any nesting depth other than 2 or 3
        """
        depth = util_common.list_depth_count(gradients)
        if depth not in (2, 3):
            raise NotImplementedError(
                'This app supports updating a network, or a list of networks.')
        if depth == 2:
            # One network: a single apply_gradients call covers it.
            self.gradient_op = self.optimiser.apply_gradients(gradients)
            return
        # Several networks: keep one update op per network, in order.
        self.gradient_op = [
            self.optimiser.apply_gradients(network_grads)
            for network_grads in gradients
        ]