Example #1
    def act(self, callback: AbstractCallback, log_fun: Callable, trainer: AbstractTrainerFunctions,
            pl_module: AbstractMethod, pl_ema_module: ModelEMA = None):
        """
        lazy way of handling the response, so that the related code is in one place
        """
        # saving model weights
        if isinstance(self.save_path, str):
            log_fun("saving to: %s, prefer ema=%s" % (self.save_path, str(self.save_ema)))
            CheckpointCallback.save(self.save_path,
                                    callback.get_method(pl_module, pl_ema_module, prefer_ema=self.save_ema))
        # loading model weights
        if isinstance(self.load_path, str):
            log_fun("loading from: %s" % self.load_path)
            CheckpointCallback.wait_load(self.load_path)

        # learning rates
        if len(self.optimizer_lrs) > 0:
            optimizers = trainer.get_optimizers()
            for optimizer_id, v in self.optimizer_lrs.items():
                log_fun("setting learning rate of optimizer %d to %f" % (optimizer_id, v))
                WrappedOptimizer.set_optimizer_lr_by_index(optimizers, optimizer_id, lr=v, is_multiplier=False)

        # regularizer values
        for k, v in self.regularizer_values.items():
            for regularizer in trainer.get_regularizers():
                if regularizer.__class__.__name__ == k:
                    log_fun("setting %s value to %s" % (k, str(v)))
                    regularizer.set_value(v)
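
Note: the learning-rate update in Example #1 presumably boils down to overwriting the 'lr' entry of each param group of the selected optimizer. A minimal plain-PyTorch sketch of that technique; the helper name and the optimizer list are illustrative, not part of UninaS:

import torch


def set_lr_by_index(optimizers, index, lr):
    """set the learning rate of all param groups of optimizers[index]"""
    for group in optimizers[index].param_groups:
        group['lr'] = lr


params = [torch.nn.Parameter(torch.zeros(1))]
optimizers = [torch.optim.SGD(params, lr=0.1), torch.optim.Adam(params, lr=1e-3)]
set_lr_by_index(optimizers, 0, lr=0.01)
assert optimizers[0].param_groups[0]['lr'] == 0.01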
Example #2
    def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """ build the network, count params, log, maybe load pretrained weights """
        assert isinstance(s_out, Shape), "Attempting to build a network with an output that is not a Shape!"
        s_out_copy = s_out.copy(copy_id=True)
        self.shape_in = s_in.copy(copy_id=True)
        s_out_net = self._build2(s_in, s_out)
        LoggerManager().get_logger().info('Network built, it has %d parameters!' % self.get_num_parameters())

        # validate output shape sizes
        assert isinstance(s_out_net, ShapeList), "The network must output a list of Shapes, one shape per head! (ShapeList)"
        for shape in s_out_net.shapes:
            if s_out_copy != shape:
                text = "One or more output shapes mismatch: %s, expected: %s" % (s_out_net, s_out_copy)
                if self.assert_output_match:
                    raise ValueError(text)
                else:
                    LoggerManager().get_logger().warning(text)
                    break

        # load weights?
        if len(self.checkpoint_path) > 0:
            path = CheckpointCallback.find_pretrained_weights_path(self.checkpoint_path, self.model_name,
                                                                   raise_missing=True)
            num_replacements = 1 if self.is_external() else 999
            self.loaded_weights(CheckpointCallback.load_network(path, self.get_network(), num_replacements))

        self.shape_out = s_out_net.shapes[0].copy(copy_id=True)
        self.shape_in_list = self.shape_in.shape
        self.shape_out_list = self.shape_out.shape
        return s_out_net
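
The validate-or-warn pattern in Example #2 can be isolated as follows; a minimal sketch with plain lists standing in for Shape/ShapeList, all names illustrative rather than UninaS API:

import logging


def validate_output_shapes(expected, produced, assert_output_match=True):
    # a mismatch either raises or only warns, depending on the flag
    for shape in produced:
        if shape != expected:
            text = "One or more output shapes mismatch: %s, expected: %s" % (produced, expected)
            if assert_output_match:
                raise ValueError(text)
            logging.getLogger(__name__).warning(text)
            break


validate_output_shapes([10], [[10], [1000]], assert_output_match=False)  # warns only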
Example #3
    def train_epochs(self, epochs=1, run_eval=True, run_test=True):
        """ train 'epochs' epochs, includes eval/test for the last n epochs """
        assert len(self.callbacks) > 0, "DDP requires a checkpoint callback to recover weights later"
        self.mover.empty_cache()
        args = (self.mover.get_num_devices(), self.method, self.save_dir,
                self.mover, self.callbacks, self.exp_logger, epochs,
                self.eval_last, self.test_last, self.is_test_run,
                self.ema_decay, self.ema_device, self.use_sync_bn,
                self._state_to_load)
        # spawn one process per device, wait for all of them, always clean up
        context = mp.spawn(SimpleDDPTrainerTrainEpochsImpl,
                           args=args,
                           nprocs=self.mover.get_num_devices(),
                           join=False)
        try:
            # join() returns True once every spawned process has exited
            while not context.join():
                pass
            self.resource_logger.wakeup()
        finally:
            self.mover.empty_cache()
            self.resource_logger.stop()
        CheckpointCallback.load_last_checkpoint(self.save_dir, self.method)
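
A minimal runnable sketch of the spawn/join pattern from Example #3: torch.multiprocessing.spawn with join=False returns a process context whose join() returns True once all worker processes have exited. The worker is a trivial stand-in for SimpleDDPTrainerTrainEpochsImpl:

import torch.multiprocessing as mp


def _worker(rank):
    print("worker %d running" % rank)


if __name__ == '__main__':
    context = mp.spawn(_worker, args=(), nprocs=2, join=False)
    try:
        while not context.join():  # True once all workers have exited
            pass
    finally:
        print("all workers done, cleaning up")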
Example #4
def get_network(config_path: str, input_shape: Shape, output_shape: Shape, weights_path: str = None) -> AbstractUninasNetwork:
    """
    create a network (model) from a config file, optionally load weights
    """
    builder = Builder()

    # get a new network
    network = builder.load_from_config(Builder.find_net_config_path(config_path))
    network = AbstractUninasNetwork(model_name="standalone", net=network, checkpoint_path="", assert_output_match=True)
    network.build(s_in=input_shape, s_out=output_shape)

    # load network weights; they are saved from a method, so the keys have to be mapped accordingly
    if isinstance(weights_path, str):
        CheckpointCallback.load_network(weights_path, network, num_replacements=1)

    return network
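
Hypothetical usage of get_network from Example #4; the paths are placeholders and the Shape argument format is an assumption inferred from the surrounding examples:

net = get_network("path/to/network_config.json", Shape([3, 224, 224]), Shape([10]),
                  weights_path="path/to/weights.pt")
net.eval()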
Example #5
    def _load_ddp(self,
                  module: AbstractMethod,
                  save_dir: str,
                  prefer_ema=True):
        # not used in each process, so there is no need to distinguish the regular and the EMA model here
        file_names = ['checkpoint.tmp.pt', 'checkpoint.ema.tmp.pt']
        file_names = reversed(file_names) if prefer_ema else file_names
        # try the preferred checkpoint first, fall back to the other one
        for fn in file_names:
            file = SimpleDDPTrainer.checkpoint_file(save_dir, fn)
            if CheckpointCallback.load(file_path=file, pl_module=module):
                break
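
A sketch of the preference-order idea in Example #5: candidate checkpoint files are tried in order and the first usable one wins. Here os.path.isfile is only a stand-in for CheckpointCallback.load, which reports whether loading succeeded:

import os


def first_existing_checkpoint(save_dir, prefer_ema=True):
    file_names = ['checkpoint.tmp.pt', 'checkpoint.ema.tmp.pt']
    for fn in (reversed(file_names) if prefer_ema else file_names):
        path = os.path.join(save_dir, fn)
        if os.path.isfile(path):
            return path
    return None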
Example #6
    def load(self, file: str) -> bool:
        """ load training state from file """
        checkpoint = CheckpointCallback.load_last_checkpoint(save_dir=file, pl_module=self.method)
        self._load_state_dict(checkpoint.get('trainer_state', {}))
        return len(checkpoint) > 0
Example #7
    def save(self, file: str):
        """ save training state to file """
        CheckpointCallback.save(file_path=file,
                                pl_module=self.method,
                                update_dict=self.get_checkpoint_update_dict())
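
A plain-PyTorch sketch of the round trip behind Examples #6 and #7: the model weights and extra trainer state are stored in one checkpoint dict and restored separately. The file name and dict keys are illustrative:

import torch

model = torch.nn.Linear(4, 2)
checkpoint = {'state_dict': model.state_dict(),
              'trainer_state': {'epoch': 3, 'best_metric': 0.91}}
torch.save(checkpoint, 'checkpoint.pt')

restored = torch.load('checkpoint.pt')
model.load_state_dict(restored['state_dict'])
trainer_state = restored.get('trainer_state', {})
assert trainer_state['epoch'] == 3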
Example #8
    def _initialize_weights(self, net: AbstractModule, logger: logging.Logger):
        assert isinstance(net, AbstractUninasNetwork), "This initializer will not work with external networks!"
        search_config = Builder.find_net_config_path(self.path, pattern='search')

        checkpoint = CheckpointCallback.load_last_checkpoint(self.path)
        state_dict = checkpoint.get('state_dict')

        # figure out correct weights in super-network checkpoint
        if len(self.gene) > 0:
            log_headline(logger, "tmp network to track used params", target_len=80)
            sm = StrategyManager()
            tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
            assert len(sm.get_strategies_list()) == 0, "can not load when there already is a search network"
            sm.add_strategy(tmp_s)
            sm.set_fixed_strategy_name('__tmp__')

            search_net = Builder().load_from_config(search_config)
            assert isinstance(search_net, SearchUninasNetwork)
            s_in, s_out = net.get_shape_in(), net.get_shape_out()
            search_net.build(s_in, s_out[0])
            search_net.set_forward_strategy(False)
            search_net.forward_strategy(fixed_arc=self.gene)
            tracker = search_net.track_used_params(s_in.random_tensor(batch_size=2))

            logger.info(' > loading weights of gene %s from checkpoint "%s"' %
                        (str(self.gene), self.path))
            target_dict = net.state_dict()
            target_names = list(target_dict.keys())
            new_dict = {}

            # add all stem and head weights, they are at the front of the dict and have pretty much the same name
            log_columns = [('shape in checkpoint', 'name in checkpoint',
                            'name in network', 'shape in network')]
            for k, v in state_dict.items():
                if '.stem.' in k or '.heads.' in k:
                    tn = target_names.pop(0)
                    ts = target_dict[tn].shape
                    log_columns.append(
                        (str(list(v.shape)), k, tn, str(list(ts))))
                    n = k.replace('net.', '', 1)
                    assert n == tn
                    new_dict[n] = v

            # add all cell weights, can generally not compare names, only shapes
            for i, tracker_cell_entry in enumerate(tracker.get_cells()):
                for entry in tracker_cell_entry.get_pareto_best():
                    tn = target_names.pop(0)
                    ts = target_dict[tn].shape
                    log_columns.append((str(list(entry.shape)), entry.name, tn,
                                        str(list(ts))))
                    assert entry.shape == ts,\
                        'Mismatching shapes for "%s" and "%s", is the gene correct?' % (entry.name, tn)
                    new_dict[tn] = state_dict[entry.name]

            # log matches, load
            log_in_columns(logger, log_columns, add_bullets=True)
            net.load_state_dict(new_dict, strict=self.strict)

            # clean up
            del search_net
            sm.delete_strategy('__tmp__')
            del sm

        # simply load
        else:
            logger.info(' > simply loading state_dict')
            net.load_state_dict(state_dict, strict=self.strict)
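
A minimal plain-PyTorch sketch of the remapping idea in Example #8: source weights are assigned to the target's parameter names in order, after checking that the shapes agree, then loaded strictly. Both models are illustrative stand-ins for the super-network checkpoint and the sub-network:

import torch

source = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
target = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))

target_dict = target.state_dict()
target_names = list(target_dict.keys())
new_dict = {}
for name, tensor in source.state_dict().items():
    tn = target_names.pop(0)  # next target name, in checkpoint order
    assert tensor.shape == target_dict[tn].shape, \
        'Mismatching shapes for "%s" and "%s"' % (name, tn)
    new_dict[tn] = tensor
target.load_state_dict(new_dict, strict=True)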