示例#1
0
 def run_solver_loop(self, solver_def='solver.prototxt', iterations=50):
     """Build a Caffe solver from a solver prototxt and step it repeatedly.

     :param solver_def:  path to a text-format ``SolverParameter`` prototxt;
                         a relative path is opened relative to the current
                         working directory.
     :param iterations:  number of ``solver.step(1)`` calls to perform.
     """
     solver_param = caffe_pb2.SolverParameter()
     # Use a context manager so the prototxt handle is closed deterministically.
     with open(solver_def) as solver_file:
         text_format.Merge(solver_file.read(), solver_param)
     solver = caffe.get_solver_from_string(solver_param.SerializeToString())
     # One iteration per call, exactly as before (not solver.step(iterations)).
     for _ in range(iterations):
         solver.step(1)
示例#2
0
 def run_solver_loop(self, solver_def='solver.prototxt', iterations=50):
     """Build a Caffe solver from a solver prototxt and step it repeatedly.

     :param solver_def:  path to a text-format ``SolverParameter`` prototxt;
                         a relative path is opened relative to the current
                         working directory.
     :param iterations:  number of ``solver.step(1)`` calls to perform.
     """
     solver_param = caffe_pb2.SolverParameter()
     # Use a context manager so the prototxt handle is closed deterministically.
     with open(solver_def) as solver_file:
         text_format.Merge(solver_file.read(), solver_param)
     solver = caffe.get_solver_from_string(solver_param.SerializeToString())
     # One iteration per call, exactly as before (not solver.step(iterations)).
     for _ in range(iterations):
         solver.step(1)
    def __init__(self,
                 prototxt,
                 final_file=None,
                 snap_file=None,
                 solver='Adam',
                 log_file=None,
                 **kwargs):
        """Create a Caffe solver for ``prototxt`` from keyword settings.

        A text-format ``SolverParameter`` string is assembled from the
        arguments and handed to ``caffe.get_solver_from_string``; the
        ``on_start``/``on_gradient`` methods are registered as callbacks.

        :param prototxt:    net definition, written as the ``train_net`` field.
        :param final_file:  stored on ``self.final_file``; a warning is printed
                            when it is None.
        :param snap_file:   stored on ``self.snap_file``.
        :param solver:      solver type, matched case-insensitively against
                            Adam, AdaDelta, AdaGrad, Nesterov, RMSProp and SGD
                            (other values are passed through unchanged).
                            None omits the ``type`` field. Defaults added when
                            not given: RMSProp -> rms_decay 0.9; SGD ->
                            momentum 0.9; Adam -> momentum 0.9, momentum2 0.99.
        :param log_file:    if given, opened for writing and stored as ``self.log``.
        :param kwargs:      additional ``SolverParameter`` fields (``base_lr``
                            defaults to 0.001, ``lr_policy`` to 'fixed').
                            Strings are quoted/escaped, bools become true/false,
                            integers are written as integers, lists/tuples emit
                            one line per element (repeated fields), and anything
                            else is converted with ``float`` and written with
                            ``repr`` so small values keep full precision.
        """
        import numbers

        self.running = False
        self.force_snapshot = False
        self.last_interrupt = 0
        self.final_file, self.snap_file = final_file, snap_file
        if final_file is None:
            print("Are you sure you dont want to save the model?")

        def quote(text):
            # Escape backslashes and double quotes so that e.g. Windows paths
            # are not misread as escape sequences by the text-format parser.
            text = str(text)
            return '"%s"' % text.replace('\\', '\\\\').replace('"', '\\"')

        def field_lines(key, value):
            # Render one solver field as text-format line(s).
            if isinstance(value, (list, tuple)):
                return ''.join(field_lines(key, item) for item in value)
            if isinstance(value, str):
                return '%s: %s\n' % (key, quote(value))
            if isinstance(value, bool):  # bool is an Integral; handle it first
                return '%s: %s\n' % (key, 'true' if value else 'false')
            if isinstance(value, numbers.Integral):
                return '%s: %d\n' % (key, value)
            # repr() round-trips the float; '%f' rounded e.g. 1e-7 to 0.000000.
            return '%s: %r\n' % (key, float(value))

        solver_str = 'train_net: %s\n' % quote(prototxt)
        if solver is not None:
            # Get the cases right
            canonical = {
                'ADAM': 'Adam',
                'ADADELTA': 'AdaDelta',
                'ADAGRAD': 'AdaGrad',
                'NESTEROV': 'Nesterov',
                'RMSPROP': 'RMSProp',
                'SGD': 'SGD',
            }
            solver = canonical.get(solver.upper(), solver)
            solver_str += 'type: %s\n' % quote(solver)
            if solver == "RMSProp":
                kwargs.setdefault('rms_decay', 0.9)
            if solver == "SGD":
                kwargs.setdefault('momentum', 0.9)
            if solver == "Adam":
                kwargs.setdefault('momentum', 0.9)
                kwargs.setdefault('momentum2', 0.99)
        kwargs.setdefault('base_lr', 0.001)
        kwargs.setdefault('lr_policy', 'fixed')
        solver_str += ''.join(field_lines(key, value)
                              for key, value in kwargs.items())
        self.solver = caffe.get_solver_from_string(solver_str)
        self.solver.add_callback(self.on_start, self.on_gradient)

        self.log = None
        if log_file:
            self.log = open(log_file, 'w')
示例#4
0
    def __init__(self,
                 solver_def,
                 solver_state=None,
                 weights=None,
                 gpus=[0],
                 log_dir='log',
                 log_db_prefix='log_db'):
        """
        Acts as a driver for training with smart logging.
        Logs will be stored to a database (see ``log_db_prefix``).
        All paths can be set relative to the location of the solver prototxt.

        Note: this changes the process working directory to the directory
        that contains the solver prototxt.

        :param solver_def:      prototxt that defines the solver. A relative path is
                                tried in the current directory first; if it is not
                                found there, the process changes to the parent
                                directory and resolves it from there.
        :param solver_state:    optional: a .solverstate file from which to resume
                                training. Looked up in '..', then as given, then
                                relative to the solver directory.
                                NEVER set together with ``weights``.
        :param weights:         optional: a .caffemodel file from which to begin
                                finetuning. NEVER set together with ``solver_state``.
        :param gpus:            optional: a list of GPU IDs to use for (multi-)GPU
                                training; if None or empty, no GPU setup is done
                                (caffe is expected to run in CPU mode)
        :param log_dir:         optional: will log into this directory under the
                                solver directory (created if missing)
        :param log_db_prefix:   prefix for both SQLite db names and table names
        :raises Exception:      if both ``solver_state`` and ``weights`` are given,
                                or if the net definition or the solver state file
                                cannot be found

        The following parameters should be set in the solver prototxt file:
        log_interval            log per this number of iterations (simple log) [default = 20]
        viz_interval            log visualization per this number of iterations (net blobs snapshot) [default = 100]
        test_iter:              The number of iterations for each test net.
        """
        # Fail fast, before any side effects (chdir, mkdir, solver creation):
        # resuming from a solverstate and finetuning from weights are exclusive.
        if solver_state and weights:
            raise Exception(
                'should not specify both solverstate and caffemodel!')

        if not os.path.isabs(solver_def):
            # Fall back to the parent directory if the prototxt is not here.
            if not os.path.isfile(os.path.join(os.getcwd(), solver_def)):
                os.chdir('..')
            solver_def = os.path.join(os.getcwd(), solver_def)

        # os.path.dirname is separator-aware, unlike slicing at rfind('/').
        self.solver_dir = os.path.dirname(solver_def)
        os.chdir(self.solver_dir)
        self.log_dir = os.path.join(self.solver_dir, log_dir)
        if not os.path.exists(self.log_dir):
            os.mkdir(self.log_dir)
        self.logprint("Logging to {}".format(self.log_dir))
        self.log_db_prefix = log_db_prefix  # used for db name and as a prefix for tables
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_def) as solver_file:
            text_format.Merge(solver_file.read(), self.solver_param)

        # read params from solver definition
        # (log/viz/se_* fields presumably come from a customized caffe.proto)
        self.iterations = self.solver_param.max_iter
        self.log_interval = self.solver_param.log_interval
        self.viz_interval = self.solver_param.viz_interval
        self.test_interval = self.solver_param.test_interval
        self.se_index_blob = self.solver_param.se_index_blob
        self.se_error_blob = self.solver_param.se_error_blob
        self.se_interval = self.solver_param.se_interval
        self.se_indices = []
        self.se_errors = []

        # make solver param for net fail safe: resolve a relative net path
        # against the solver directory
        if not os.path.isabs(self.solver_param.net):
            self.solver_param.net = os.path.join(self.solver_dir,
                                                 self.solver_param.net)
            if not os.path.isfile(self.solver_param.net):
                raise Exception(
                    'could not find net definition from solver prototxt!')

        # Copy so the instance never aliases the shared default list.
        self.gpus = list(gpus) if gpus else gpus
        if self.gpus:
            self.solver_param.device_id = self.gpus[0]
            caffe.set_device(self.gpus[0])
            caffe.set_mode_gpu()
            caffe.set_solver_count(len(self.gpus))

        self.solver = caffe.get_solver_from_string(
            self.solver_param.SerializeToString())

        if solver_state:
            # look for the file in the parent directory, then as given, then
            # relative to the solver directory
            if os.path.isfile(os.path.join('..', solver_state)):
                solver_state = os.path.join('..', solver_state)
            elif not os.path.isfile(solver_state):
                solver_state = os.path.join(self.solver_dir, solver_state)
                if not os.path.isfile(solver_state):
                    raise Exception(
                        'could not find solver state specified!')
            self.solver.restore(solver_state)
        elif weights:
            self.solver.net.copy_from(weights)

        self.sync = None
        self.viz_thread = None
        self.log_thread = None
        self.test_input = None
        self.test_out_blobs = None
        self.test_start_layer = None
        self.iteration = 0

        # disable sample-error updates if the blobs they need are missing
        if self.se_interval and self.se_index_blob not in self.solver.net.blobs:
            self.logprint(
                "WARNING: index_blob not found in net! Won't send errors to Net."
            )
            self.se_interval = 0

        if self.se_interval and self.se_error_blob not in self.solver.net.blobs:
            self.logprint(
                "WARNING: error_blob not found in net! Won't send errors to Net."
            )
            self.se_interval = 0
示例#5
0
    def __init__(self, solver_def, solver_state=None, weights=None, gpus=[0], log_dir='log', log_db_prefix='log_db'):
        """
        Acts as a driver for training with smart logging.
        Logs will be stored to a database (see ``log_db_prefix``).
        All paths can be set relative to the location of the solver prototxt.

        Note: this changes the process working directory to the directory that contains the solver prototxt.

        :param solver_def:      prototxt that defines the solver. A relative path is tried in the current
                                directory first; if it is not found there, the process changes to the parent
                                directory and resolves it from there.
        :param solver_state:    optional: a .solverstate file from which to resume training. Looked up in '..',
                                then as given, then relative to the solver directory.
                                NEVER set together with ``weights``.
        :param weights:         optional: a .caffemodel file from which to begin finetuning.
                                NEVER set together with ``solver_state``.
        :param gpus:            optional: a list of GPU IDs to use for (multi-)GPU training;
                                if None or empty, no GPU setup is done (caffe is expected to run in CPU mode)
        :param log_dir:         optional: will log into this directory under the solver directory (created if missing)
        :param log_db_prefix:   prefix for both SQLite db names and table names
        :raises Exception:      if both ``solver_state`` and ``weights`` are given, or if the net definition or
                                the solver state file cannot be found

        The following parameters should be set in the solver prototxt file:
        log_interval            log per this number of iterations (simple log) [default = 20]
        viz_interval            log visualization per this number of iterations (net blobs snapshot) [default = 100]
        test_iter:              The number of iterations for each test net.
        """
        # Fail fast, before any side effects (chdir, mkdir, solver creation):
        # resuming from a solverstate and finetuning from weights are exclusive.
        if solver_state and weights:
            raise Exception('should not specify both solverstate and caffemodel!')

        if not os.path.isabs(solver_def):
            # Fall back to the parent directory if the prototxt is not here.
            if not os.path.isfile(os.path.join(os.getcwd(), solver_def)):
                os.chdir('..')
            solver_def = os.path.join(os.getcwd(), solver_def)

        # os.path.dirname is separator-aware, unlike slicing at rfind('/').
        self.solver_dir = os.path.dirname(solver_def)
        os.chdir(self.solver_dir)
        self.log_dir = os.path.join(self.solver_dir, log_dir)
        if not os.path.exists(self.log_dir):
            os.mkdir(self.log_dir)
        self.logprint("Logging to {}".format(self.log_dir))
        self.log_db_prefix = log_db_prefix  # used for db name and as a prefix for tables
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_def) as solver_file:
            text_format.Merge(solver_file.read(), self.solver_param)

        # read params from solver definition
        # (log/viz/se_* fields presumably come from a customized caffe.proto)
        self.iterations = self.solver_param.max_iter
        self.log_interval = self.solver_param.log_interval
        self.viz_interval = self.solver_param.viz_interval
        self.test_interval = self.solver_param.test_interval
        self.se_index_blob = self.solver_param.se_index_blob
        self.se_error_blob = self.solver_param.se_error_blob
        self.se_interval = self.solver_param.se_interval
        self.se_indices = []
        self.se_errors = []

        # make solver param for net fail safe: resolve a relative net path against the solver directory
        if not os.path.isabs(self.solver_param.net):
            self.solver_param.net = os.path.join(self.solver_dir, self.solver_param.net)
            if not os.path.isfile(self.solver_param.net):
                raise Exception('could not find net definition from solver prototxt!')

        # Copy so the instance never aliases the shared default list.
        self.gpus = list(gpus) if gpus else gpus
        if self.gpus:
            self.solver_param.device_id = self.gpus[0]
            caffe.set_device(self.gpus[0])
            caffe.set_mode_gpu()
            caffe.set_solver_count(len(self.gpus))

        self.solver = caffe.get_solver_from_string(self.solver_param.SerializeToString())

        if solver_state:
            # look for the file in the parent directory, then as given, then relative to the solver directory
            if os.path.isfile(os.path.join('..', solver_state)):
                solver_state = os.path.join('..', solver_state)
            elif not os.path.isfile(solver_state):
                solver_state = os.path.join(self.solver_dir, solver_state)
                if not os.path.isfile(solver_state):
                    raise Exception('could not find solver state specified!')
            self.solver.restore(solver_state)
        elif weights:
            self.solver.net.copy_from(weights)

        self.sync = None
        self.viz_thread = None
        self.log_thread = None
        self.test_input = None
        self.test_out_blobs = None
        self.test_start_layer = None
        self.iteration = 0

        # disable sample-error updates if the blobs they need are missing
        if self.se_interval and self.se_index_blob not in self.solver.net.blobs:
            self.logprint("WARNING: index_blob not found in net! Won't send errors to Net.")
            self.se_interval = 0

        if self.se_interval and self.se_error_blob not in self.solver.net.blobs:
            self.logprint("WARNING: error_blob not found in net! Won't send errors to Net.")
            self.se_interval = 0