Example #1
0
    def __init__(self, train_params: TrainParameters, workspace_dir=None, checkpoint_path=None, comment_msg=None, load_optimizer=True):
        """Set up a training run: device, network, loss, optimizer and logger.

        Args:
            train_params: training configuration; this constructor reads
                ``VERBOSE_MODE`` and ``DEV_ID`` from it.
            workspace_dir: directory for temporary data, checkpoints etc.
                Created (including missing parent directories) when it does
                not exist yet; ``None`` skips creation.
            checkpoint_path: optional checkpoint file. When it is ``None`` or
                does not exist on disk, nothing is loaded.
            comment_msg: optional comment/tag passed to the logger and printed
                when verbose.
            load_optimizer: when ``True``, restore the optimizer state from the
                checkpoint's ``'optimizer'`` entry if present.
        """
        self.verbose_mode = train_params.VERBOSE_MODE
        self.train_params = train_params
        self.load_optimizer = load_optimizer

        # set workspace for temp data, checkpoints etc.
        self.workspace_dir = workspace_dir
        if workspace_dir is not None and not os.path.exists(workspace_dir):
            # makedirs also creates missing parents, and exist_ok tolerates a
            # concurrent creator between the exists() check and this call.
            os.makedirs(workspace_dir, exist_ok=True)

        # set the device to run training process
        self._set_dev_id(train_params.DEV_ID)

        # load the checkpoint only if a path was given and the file exists
        self.pre_checkpoint = None if (checkpoint_path is None or not os.path.exists(checkpoint_path)) \
            else dl_util.load_checkpoints(checkpoint_path)

        # set network; the _set_network hook presumably assigns self.model
        self._set_network()
        if self.model is not None and self.pre_checkpoint is not None and 'net_instance' in self.pre_checkpoint:
            self.model.load_state(self.pre_checkpoint['net_instance'])
            if self.verbose_mode:
                print('[Init. Network] Load Net States from checkpoint: ' + checkpoint_path)
        if self.model is not None:
            self.model.cuda()

        # set the loss function and move it to the GPU as well
        self._set_loss_func()
        if self.criterion is not None:
            self.criterion.cuda()

        # set the optimizer, optionally restoring its state from the checkpoint
        self._set_optimizer()
        if self.load_optimizer is True:
            if self.optimizer is not None and self.pre_checkpoint is not None and 'optimizer' in self.pre_checkpoint:
                self.optimizer.load_state_dict(self.pre_checkpoint['optimizer'])
                if self.verbose_mode:
                    print('[Init. Optimizer] Load Optimizer from checkpoint: ' + checkpoint_path)

        # set the logger
        self._set_logger(workspace_dir, comment_msg)

        # print comment or tag message
        if self.verbose_mode and comment_msg is not None:
            print('[Comment] -----------------------------------------------------------------------------------------')
            print(comment_msg)

        # save net definition
        self._save_net_def()

        # report the training init
        self.report()

        # presumably -1 means "training not started yet"; confirm with the train loop
        self.train_start_time = -1
import banet_track.ba_tracknet_mirror_a as mirror_a
import banet_track.ba_tracknet_mirror_b as mirror_b

# PySophus Directory
sys.path.extend(['/opt/eigency', '/opt/PySophus'])

# Configure ------------------------------------------------------------------------------------------------------------

# Set workspace for temp data, checkpoints etc.
workspace_dir = '/mnt/Tango/banet_track_train/'
if not os.path.exists(workspace_dir):
    # makedirs also creates missing parents; exist_ok tolerates a concurrent creator
    os.makedirs(workspace_dir, exist_ok=True)


def _load_checkpoint_if_exists(path):
    """Load a checkpoint via ``dl_util.load_checkpoints``.

    Returns ``None`` when *path* is ``None`` or does not exist on disk. A
    missing file also prints a warning so a stale hard-coded path is not
    silently ignored.
    """
    if path is None:
        return None
    if not os.path.exists(path):
        print('[Warning] Checkpoint not found, skipping: ' + path)
        return None
    return dl_util.load_checkpoints(path)


# Load Checkpoints if needed
checkpoint_path_a = '/mnt/Tango/banet_track_train/logs/Sep23_23-08-07_cs-gruvi-24-cmpt-sfu-ca_r1.25t0.16_itr1_f96_0.01reg_onlylevel2_b4_no_outside_points_quatloss/checkpoints/iter_003691.pth.tar'
checkpoint_a = _load_checkpoint_if_exists(checkpoint_path_a)

checkpoint_path_b = '/mnt/Tango/banet_track_train/logs/Sep23_23-04-22_cs-gruvi-24-cmpt-sfu-ca_r1.25t0.16_itr1_f96_onlylevel2_b4_no_outside_points_quatloss/checkpoints/iter_003691.pth.tar'
checkpoint_b = _load_checkpoint_if_exists(checkpoint_path_b)

# Prepare the dataset --------------------------------------------------------------------------------------------------
# Command-line options: a run name and the root directory of the dataset.
parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, default='ba_tracknet', help='name of the dataset')
parser.add_argument('--base_dir', type=str, default='/mnt/Tango/datasets/sun3d_demon', help='path to the data directory')
Example #3
0
    def __init__(self, train_params: TrainParameters, workspace_dir=None, checkpoint_path=None, comment_msg=None, load_optimizer=True):
        """Set up a training run: devices, network, loss, optimizer, logger and run metadata.

        Args:
            train_params: training configuration. This constructor reads
                ``VERBOSE_MODE``, ``DEV_IDS``, ``START_LR`` and ``DESCRIPTION``
                and calls ``to_json()`` and ``save()`` on it.
            workspace_dir: directory for temporary data, checkpoints etc.
                Created (including missing parent directories) when it does
                not exist yet; ``None`` skips creation.
            checkpoint_path: optional checkpoint file; ignored when ``None`` or
                missing on disk.
            comment_msg: optional tag message, printed when verbose and passed
                to the logger.
            load_optimizer: when ``True``, restore the optimizer state from the
                checkpoint's ``'optimizer'`` entry if present.
        """
        self.verbose_mode = train_params.VERBOSE_MODE
        self.train_params = train_params
        self.load_optimizer = load_optimizer
        self.checkpoint_path = checkpoint_path

        # set workspace for temp dataset, checkpoints etc.
        self.workspace_dir = workspace_dir
        if self.verbose_mode and workspace_dir is not None:
            print('[Log] Directory:' + workspace_dir)
        if workspace_dir is not None and not os.path.exists(workspace_dir):
            # makedirs also creates missing parents; exist_ok tolerates a
            # concurrent creator between the exists() check and this call.
            os.makedirs(workspace_dir, exist_ok=True)

        # set the device(s) to run the training process
        self._set_dev_id(train_params.DEV_IDS)

        # load the checkpoint only if a path was given and the file exists
        self.pre_checkpoint = None if (checkpoint_path is None or not os.path.exists(checkpoint_path)) \
            else dl_util.load_checkpoints(checkpoint_path)

        # set network
        self._set_network()
        if self.pre_checkpoint is not None:
            # Try the subclass hook first. It presumably sets
            # self.loaded_network, which is read on the next line. Fall back
            # to the generic 'net_instance' entry when nothing was loaded.
            self._load_network_from_ckpt(self.pre_checkpoint)
            if self.loaded_network is False and self.model is not None and 'net_instance' in self.pre_checkpoint:
                self.model.load_state(self.pre_checkpoint['net_instance'])
                self.loaded_network = True
                if self.verbose_mode:
                    print('[Init. Network] Load net States from checkpoint: ' + checkpoint_path)

        # set the loss function
        self._set_loss_func()

        # set the optimizer, optionally restoring its state from the checkpoint
        self._set_optimizer()
        if self.load_optimizer is True:
            if self.optimizer is not None and self.pre_checkpoint is not None and 'optimizer' in self.pre_checkpoint:
                self.optimizer.load_state_dict(self.pre_checkpoint['optimizer'])
                if self.verbose_mode:
                    print('[Init. Optimizer] Load Optimizer from checkpoint: ' + checkpoint_path)

        # print comment or tag message
        if self.verbose_mode and comment_msg is not None:
            print('[Tag] -------------------------------------------------------------------------------------------')
            print(comment_msg)

        # set the logger
        self._set_logger(workspace_dir, comment_msg)
        if self.logger is not None:
            # Run metadata is persisted regardless of console verbosity,
            # consistent with train_param.json / net_def saved below. Only the
            # console line depends on verbose_mode.
            self.logger.meta_dict['dev_id'] = train_params.DEV_IDS
            self.logger.meta_dict['start_learning_rate'] = train_params.START_LR
            if self.verbose_mode:
                print('[Logger] Dir: %s' % self.logger.log_base_dir)

            writer = self.logger.get_tensorboard_writer()
            if writer is not None:
                writer.add_text(tag='Description', text_string=self.train_params.DESCRIPTION, global_step=0)
                writer.add_text(tag='Directory', text_string=self.logger.log_base_dir, global_step=0)
                writer.add_text(tag='Parameters', text_string=self.train_params.to_json(), global_step=0)

        # save the training parameters and net definition next to the logs
        self.model_def_dir = None
        if self.logger is not None:
            self.model_def_dir = os.path.join(self.logger.log_base_dir, 'net_def')
            if not os.path.exists(self.model_def_dir):
                os.makedirs(self.model_def_dir, exist_ok=True)
            self.train_params.save(os.path.join(self.model_def_dir, 'train_param.json'))
            self._save_net_def(self.model_def_dir)

        # report the training init
        self.report()

        # presumably -1 means "training not started yet"; confirm with the train loop
        self.train_start_time = -1

        # state variables
        self.is_training = False
Example #4
0
 def load_from_file(self, file_name, instance_key='net_instance'):
     """Load this object's state from a checkpoint file.

     Args:
         file_name: path passed to ``load_checkpoints``.
         instance_key: entry of the loaded checkpoint whose value is passed
             to ``self.load_state`` (default ``'net_instance'``).

     Raises:
         KeyError: if the checkpoint has no entry named ``instance_key``. The
             message names the file and the entries that are present.
     """
     # NOTE(review): if load_checkpoints unpickles (e.g. torch.load), only call
     # this on trusted files.
     ckpt = load_checkpoints(file_name)
     try:
         state = ckpt[instance_key]
     except KeyError as err:
         raise KeyError('checkpoint %r has no entry %r (available: %s)'
                        % (file_name, instance_key, ', '.join(map(str, ckpt)))) from err
     self.load_state(state)