def setup_pose_net(config, prepared, **kwargs):
    """
    Create a pose network

    The network class named by ``config.name`` is looked up in
    ``packnet_sfm.networks.pose`` and constructed from ``config`` merged with
    ``kwargs`` (``kwargs`` take precedence on key clashes).

    Parameters
    ----------
    config : CfgNode
        Network configuration
    prepared : bool
        True if the network has been prepared before; when True no
        checkpoint weights are loaded here
    kwargs : dict
        Extra parameters for the network

    Returns
    -------
    pose_net : nn.Module
        Created pose network
    """
    print0(pcolor('PoseNet: %s' % config.name, 'yellow'))
    pose_net = load_class_args_create(
        config.name,
        paths=['packnet_sfm.networks.pose'],
        args={**config, **kwargs},
    )
    # Load weights only when the network was not prepared already and a
    # checkpoint path was actually given. A truthiness test replaces the
    # former ``is not ''`` check: ``is`` compares object identity, not string
    # value (and triggers a SyntaxWarning on Python 3.8+). It also treats a
    # missing (None) path as "no checkpoint".
    if not prepared and config.checkpoint_path:
        # NOTE(review): the list is presumably the set of state-dict key
        # prefixes load_network should match — confirm in load_network.
        pose_net = load_network(pose_net, config.checkpoint_path,
                                ['pose_net', 'pose_network'])
    return pose_net
def setup_model(config, prepared, **kwargs):
    """
    Create a model

    The model class named by ``config.name`` is looked up in
    ``packnet_sfm.models`` and constructed from ``config.loss`` merged with
    ``kwargs``. The depth and/or pose networks listed in the model's
    ``network_requirements`` are then attached.

    Parameters
    ----------
    config : CfgNode
        Model configuration (cf. configs/default_config.py)
    prepared : bool
        True if the model has been prepared before; when True no checkpoint
        weights are loaded here
    kwargs : dict
        Extra parameters for the model

    Returns
    -------
    model : nn.Module
        Created model
    """
    from collections.abc import Mapping

    print0(pcolor('Model: %s' % config.name, 'yellow'))
    model = load_class(config.name, paths=['packnet_sfm.models'])(
        **{**config.loss, **kwargs})

    def _requires(name):
        # ``network_requirements`` may be either a mapping of name -> flag or
        # a plain collection of names. For a mapping, a bare ``in`` test
        # would wrongly treat an entry explicitly set to False as required,
        # so the flag's truthiness is used instead.
        requirements = model.network_requirements
        if isinstance(requirements, Mapping):
            return bool(requirements.get(name, False))
        return name in requirements

    # Add depth network if required
    if _requires('depth_net'):
        model.add_depth_net(setup_depth_net(config.depth_net, prepared))
    # Add pose network if required
    if _requires('pose_net'):
        model.add_pose_net(setup_pose_net(config.pose_net, prepared))
    # Load pretrained weights only when the model was not prepared already and
    # a checkpoint path was actually given. The truthiness test replaces the
    # former ``is not ''`` identity comparison, which did not compare string
    # values.
    if not prepared and config.checkpoint_path:
        model = load_network(model, config.checkpoint_path, 'model')
    return model
def setup_depth_net(config, prepared, **kwargs):
    """
    Create a depth network

    The network class named by ``config.name`` is looked up in
    ``packnet_sfm.networks.depth`` and constructed from ``config`` merged with
    ``kwargs`` (``kwargs`` take precedence on key clashes).

    Parameters
    ----------
    config : CfgNode
        Network configuration
    prepared : bool
        True if the network has been prepared before; when True no
        checkpoint weights are loaded here
    kwargs : dict
        Extra parameters for the network

    Returns
    -------
    depth_net : nn.Module
        Created depth network
    """
    print0(pcolor("DepthNet: %s" % config.name, "yellow"))
    depth_net = load_class_args_create(
        config.name,
        paths=["packnet_sfm.networks.depth"],
        args={**config, **kwargs},
    )
    # Load weights only when the network was not prepared already and a
    # checkpoint path was actually given. A truthiness test replaces the
    # former ``is not ""`` check: ``is`` compares object identity, not string
    # value (and triggers a SyntaxWarning on Python 3.8+). It also treats a
    # missing (None) path as "no checkpoint".
    if not prepared and config.checkpoint_path:
        # NOTE(review): the list is presumably the set of state-dict key
        # prefixes load_network should match — confirm in load_network.
        depth_net = load_network(depth_net, config.checkpoint_path,
                                 ["depth_net", "disp_network"])
    return depth_net
def prepare_model(self, resume=None):
    """Build ``self.model`` from the config and optionally restore a saved state.

    Parameters
    ----------
    resume : dict or None
        When truthy, must provide ``'file'`` (shown in the log banner) and
        ``'state_dict'`` (weights loaded into the model). An optional
        ``'epoch'`` entry is copied to ``self.current_epoch``.
    """
    print0(pcolor('### Preparing Model', 'green'))
    self.model = setup_model(self.config.model, self.config.prepared)

    # Nothing to restore: keep the freshly built model.
    if not resume:
        return

    banner = '### Resuming from {}'.format(resume['file'])
    print0(pcolor(banner, 'magenta', attrs=['bold']))
    self.model = load_network(self.model, resume['state_dict'], 'model')
    if 'epoch' in resume:
        self.current_epoch = resume['epoch']
def prepare_model(self, resume=None):
    """Construct ``self.model`` and, if a resume record is given, reload its state.

    Parameters
    ----------
    resume : dict or None
        Resume record. When truthy it must contain ``"file"`` (echoed in the
        log) and ``"state_dict"`` (weights for the model); an ``"epoch"``
        entry, if present, becomes ``self.current_epoch``.
    """
    print0(pcolor("### Preparing Model", "green"))
    self.model = setup_model(self.config.model, self.config.prepared)

    # Without a resume record the freshly built model is final.
    if not resume:
        return

    header = "### Resuming from {}".format(resume["file"])
    print0(pcolor(header, "magenta", attrs=["bold"]))
    self.model = load_network(self.model, resume["state_dict"], "model")
    if "epoch" in resume:
        self.current_epoch = resume["epoch"]
def setup_model(config, prepared, **kwargs):
    """
    Create a model

    The model class named by ``config.name`` is looked up in
    ``packnet_sfm.models`` and constructed from ``config.loss`` merged with
    ``kwargs``. The depth and/or pose networks flagged in the model's
    ``network_requirements`` are then attached.

    Parameters
    ----------
    config : CfgNode
        Model configuration (cf. configs/default_config.py)
    prepared : bool
        True if the model has been prepared before; when True no checkpoint
        weights are loaded here
    kwargs : dict
        Extra parameters for the model (also forwarded to the pose network)

    Returns
    -------
    model : nn.Module
        Created model
    """
    print0(pcolor('Model: %s' % config.name, 'yellow'))
    # Model class, e.g. SfmModel, SelfSupModel or VelSupModel
    model = load_class(config.name, paths=['packnet_sfm.models'])(
        **{**config.loss, **kwargs})

    # Add depth network if required. Its scale count, depth range and
    # upsampling flag come from the loss/params sections of the model config.
    if model.network_requirements['depth_net']:
        model.add_depth_net(setup_depth_net(
            config.depth_net, prepared,
            num_scales=config.loss.num_scales,
            min_depth=config.params.min_depth,
            max_depth=config.params.max_depth,
            upsample_depth_maps=config.loss.upsample_depth_maps,
        ))

    # Add pose network if required
    if model.network_requirements['pose_net']:
        # NOTE(review): the model kwargs are forwarded to the pose network but
        # not to the depth network — confirm this asymmetry is intended.
        model.add_pose_net(setup_pose_net(
            config.pose_net, prepared,
            rotation_mode=config.loss.rotation_mode,
            **kwargs))

    # Load pretrained weights only when the model was not prepared already and
    # a checkpoint path was actually given. The truthiness test replaces the
    # former ``is not ''`` identity comparison, which did not compare string
    # values (and triggers a SyntaxWarning on Python 3.8+).
    if not prepared and config.checkpoint_path:
        model = load_network(model, config.checkpoint_path, 'model')
    return model