Ejemplo n.º 1
0
    def __init__(self, prototxt):
        """Construct a Solver from a ``.prototxt`` solver definition.

        Parses the text-format solver parameters, applies the update
        parameters, resets iteration state, then initializes the train
        net and test nets and builds them.

        Parameters
        ----------
        prototxt : str
            The path of ``.prototxt`` file.

        Examples
        --------
        >>> solver = Solver('solver.prototxt')

        """
        self._param = pb.SolverParameter()
        # Read inside ``with`` so the file handle is closed deterministically
        # rather than whenever the temporary file object is collected.
        with open(prototxt, 'r') as f:
            text = f.read()
        Parse(text, self._param)
        self.ParseUpdateParam()
        self._net = None
        self._test_nets = []
        self._layer_blobs = []
        self._iter = self._current_step = 0
        self._optimizer = None
        # A scalar summary writer is only created when root_solver() is true.
        self.scalar_writer = sw.ScalarSummary() if root_solver() else None

        self.InitTrainNet()
        self.InitTestNets()
        self.BuildNets()
Ejemplo n.º 2
0
    def __init__(self, proto_txt):
        """Construct a Solver from a ``.prototxt`` solver definition.

        Parses the text-format solver parameters, rejects gradient
        accumulation, resets iteration state, initializes and builds the
        train/test nets, and finally parses the optimizer arguments.

        Parameters
        ----------
        proto_txt : str
            The path of ``.prototxt`` file.

        Raises
        ------
        NotImplementedError
            If the solver parameters request ``iter_size > 1``.

        Examples
        --------
        >>> solver = Solver('solver.prototxt')

        """
        self._param = _proto_def.SolverParameter()
        # Read inside ``with`` so the file handle is closed deterministically.
        with open(proto_txt, 'r') as f:
            text = f.read()
        _parse_text_proto(text, self._param)
        # Gradient accumulation across iterations is not supported.
        if self._param.iter_size > 1:
            raise NotImplementedError('Gradients accumulating is deprecated.')
        self._net = None
        self._test_nets = []
        self._layer_blobs = []
        self._iter = self._current_step = 0
        self.optimizer = None
        self.InitTrainNet()
        self.InitTestNets()
        self.BuildNets()
        # Runs after the nets are built; keep this ordering.
        self.ParseOptimizerArguments()
Ejemplo n.º 3
0
 def __init__(self, prototxt):
     """Construct a Solver from a ``.prototxt`` solver definition.

     Parses the text-format solver parameters, resets iteration state,
     initializes the train net and test nets, then checks the update
     parameters.

     Parameters
     ----------
     prototxt : str
         The path of ``.prototxt`` file.

     """
     self._param = pb.SolverParameter()
     # Read inside ``with`` so the file handle is closed deterministically.
     with open(prototxt, 'r') as f:
         text = f.read()
     Parse(text, self._param)
     self._net = None
     self._test_nets = []
     self._iter = self._current_step = 0
     # NOTE(review): these are presumably populated later (e.g. by
     # InitTrainNet / InitTestNets / CheckUpdateParam) — confirm.
     self.train = self.tests = self.update = self._updater = None
     # A scalar summary writer is only created when root_solver() is true.
     self.scalar_writer = sw.ScalarSummary() if root_solver() else None
     self._lr_blobs = []
     self.InitTrainNet()
     self.InitTestNets()
     self.CheckUpdateParam()
Ejemplo n.º 4
0
    def __init__(self, solver_prototxt, output_dir, pretrained_model=None):
        """Wrap a Caffe SGD solver and load its solver parameters.

        Parameters
        ----------
        solver_prototxt : str
            Path of the solver ``.prototxt`` file; used both to create the
            ``caffe.SGDSolver`` and to populate ``self.solver_param``.
        output_dir : str
            Stored on ``self.output_dir``.
        pretrained_model : str, optional
            When given, weights are copied from this path into
            ``self.solver.net``.

        """
        self.output_dir = output_dir

        self.solver = caffe.SGDSolver(solver_prototxt)

        if pretrained_model is not None:
            # Format the message before printing. The previous form applied
            # ``.format`` to the return value of ``print(...)`` (None), which
            # raises AttributeError under Python 3.
            print('Loading pretrained model weights from {:s}'.format(
                pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # Parse the same prototxt into a SolverParameter message so the solver
        # settings are readable from ``self.solver_param``.
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)