Example #1
def __init__(self, config, model, data_feature):
    self.config = config
    # Fall back to CPU when the config does not specify a device.
    self.device = self.config.get('device', torch.device('cpu'))
    self.model = model.to(self.device)
    self.evaluator = get_evaluator(config)
    self.exp_id = self.config.get('exp_id', None)
    # Per-experiment directories for model checkpoints and evaluation results.
    self.cache_dir = './libcity/cache/{}/model_cache'.format(self.exp_id)
    self.evaluate_res_dir = './libcity/cache/{}/evaluate_cache'.format(self.exp_id)
    self.tmp_path = './libcity/tmp/checkpoint/'
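
Every excerpt in this section calls `get_evaluator(config)`, whose body the excerpts do not include. As a minimal sketch, such a factory typically resolves the evaluator class named in the config and instantiates it (the module path and error message below are assumptions, not taken from the excerpts):

import importlib

def get_evaluator(config):
    # Sketch of a config-driven factory: look up the evaluator class by
    # name and construct it with the config.
    try:
        return getattr(importlib.import_module('libcity.evaluator'),
                       config['evaluator'])(config)
    except AttributeError:
        raise AttributeError('evaluator class is not found')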
Example #2
def __init__(self, config, model, data_feature):
    self.model = model
    self.config = config
    self.evaluator = get_evaluator(config)
    self.exp_id = self.config.get('exp_id', None)
    self.cache_dir = './libcity/cache/{}/model_cache'.format(self.exp_id)
    self.evaluate_res_dir = './libcity/cache/{}/evaluate_cache'.format(self.exp_id)
    self._logger = getLogger()
def __init__(self, config, model, data_feature):
    self.evaluator = get_evaluator(config)
    self.metrics = 'Recall@{}'.format(config['topk'])
    self.config = config
    self.model = model.to(self.config['device'])
    self.tmp_path = './libcity/tmp/checkpoint/'
    self.exp_id = self.config.get('exp_id', None)
    self.cache_dir = './libcity/cache/{}/model_cache'.format(self.exp_id)
    self.evaluate_res_dir = './libcity/cache/{}/evaluate_cache'.format(self.exp_id)
    self.loss_func = None  # TODO: select a specific loss function from the config file; not yet implemented
    self._logger = getLogger()
    self.optimizer = self._build_optimizer()
    self.scheduler = self._build_scheduler()
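
The `_build_optimizer` and `_build_scheduler` helpers are not shown in the excerpt. A minimal sketch of the optimizer side, assuming Adam as the default learner and the `learner`/`learning_rate` config keys used elsewhere in this section:

import torch

def _build_optimizer(self):
    # Sketch only: pick the optimizer named in the config; the real
    # helper may support more learners and per-optimizer options.
    learner = self.config.get('learner', 'adam').lower()
    lr = self.config.get('learning_rate', 0.01)
    if learner == 'adam':
        return torch.optim.Adam(self.model.parameters(), lr=lr)
    elif learner == 'sgd':
        return torch.optim.SGD(self.model.parameters(), lr=lr)
    raise ValueError('unsupported learner: {}'.format(learner))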
def __init__(self, config, model, data_feature):
    self.evaluator = get_evaluator(config)
    self.config = config
    self.data_feature = data_feature
    self.device = self.config.get('device', torch.device('cpu'))
    self.model = model
    self.exp_id = self.config.get('exp_id', None)

    self.cache_dir = './libcity/cache/{}/model_cache'.format(self.exp_id)
    self.evaluate_res_dir = './libcity/cache/{}/evaluate_cache'.format(self.exp_id)

    ensure_dir(self.cache_dir)
    ensure_dir(self.evaluate_res_dir)

    self._logger = getLogger()
    # Scaler used to denormalize model outputs back to the original data range.
    self._scaler = self.data_feature.get('scaler')

    self.output_dim = self.config.get('output_dim', 1)
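
`ensure_dir` is a small utility the excerpts rely on but do not define; functionally it just creates a directory when it is missing, along these lines:

import os

def ensure_dir(dir_path):
    # Create the directory (including parents) if it does not exist yet.
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)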
def __init__(self, config, model, data_feature):
    self.evaluator = get_evaluator(config)
    self.config = config
    self.data_feature = data_feature
    self.device = self.config.get('device', torch.device('cpu'))
    self.model = model.to(self.device)
    self.exp_id = self.config.get('exp_id', None)

    self.cache_dir = './libcity/cache/{}/model_cache'.format(self.exp_id)
    self.evaluate_res_dir = './libcity/cache/{}/evaluate_cache'.format(self.exp_id)
    self.summary_writer_dir = './libcity/cache/{}/'.format(self.exp_id)
    ensure_dir(self.cache_dir)
    ensure_dir(self.evaluate_res_dir)
    ensure_dir(self.summary_writer_dir)

    self._writer = SummaryWriter(self.summary_writer_dir)
    self._logger = getLogger()
    self._scaler = self.data_feature.get('scaler')

    # Log the model structure, every parameter tensor, and the total
    # parameter count.
    self._logger.info(self.model)
    for name, param in self.model.named_parameters():
        self._logger.info(
            str(name) + '\t' + str(param.shape) + '\t' +
            str(param.device) + '\t' + str(param.requires_grad))
    total_num = sum([param.nelement() for param in self.model.parameters()])
    self._logger.info('Total parameter numbers: {}'.format(total_num))

    # Optimizer hyper-parameters.
    self.epochs = self.config.get('max_epoch', 100)
    self.train_loss = self.config.get('train_loss', 'none')
    self.learner = self.config.get('learner', 'adam')
    self.learning_rate = self.config.get('learning_rate', 0.01)
    self.weight_decay = self.config.get('weight_decay', 0)
    self.lr_beta1 = self.config.get('lr_beta1', 0.9)
    self.lr_beta2 = self.config.get('lr_beta2', 0.999)
    self.lr_betas = (self.lr_beta1, self.lr_beta2)
    self.lr_alpha = self.config.get('lr_alpha', 0.99)
    self.lr_epsilon = self.config.get('lr_epsilon', 1e-8)
    self.lr_momentum = self.config.get('lr_momentum', 0)

    # Learning-rate scheduler settings.
    self.lr_decay = self.config.get('lr_decay', False)
    self.lr_scheduler_type = self.config.get('lr_scheduler', 'multisteplr')
    self.lr_decay_ratio = self.config.get('lr_decay_ratio', 0.1)
    self.milestones = self.config.get('steps', [])
    self.step_size = self.config.get('step_size', 10)
    self.lr_lambda = self.config.get('lr_lambda', lambda x: x)
    self.lr_T_max = self.config.get('lr_T_max', 30)
    self.lr_eta_min = self.config.get('lr_eta_min', 0)
    self.lr_patience = self.config.get('lr_patience', 10)
    self.lr_threshold = self.config.get('lr_threshold', 1e-4)

    # Gradient clipping, early stopping, and checkpointing flags.
    self.clip_grad_norm = self.config.get('clip_grad_norm', False)
    self.max_grad_norm = self.config.get('max_grad_norm', 1.)
    self.use_early_stop = self.config.get('use_early_stop', False)
    self.patience = self.config.get('patience', 50)
    self.log_every = self.config.get('log_every', 1)
    self.saved = self.config.get('saved_model', True)
    self.load_best_epoch = self.config.get('load_best_epoch', True)
    self.hyper_tune = self.config.get('hyper_tune', False)

    self.output_dim = self.config.get('output_dim', 1)
    self.optimizer = self._build_optimizer()
    self.lr_scheduler = self._build_lr_scheduler()
    # Resume from a saved checkpoint when a starting epoch is configured.
    self._epoch_num = self.config.get('epoch', 0)
    if self._epoch_num > 0:
        self.load_model_with_epoch(self._epoch_num)
    self.loss_func = self._build_train_loss()
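
The scheduler-related keys read above (`lr_decay`, `lr_scheduler`, `lr_decay_ratio`, `steps`, `step_size`, `lr_T_max`, `lr_eta_min`, `lr_patience`, `lr_threshold`) feed `_build_lr_scheduler`, whose body the excerpt omits. A condensed sketch of how such a helper would map them onto `torch.optim.lr_scheduler` (the scheduler names and the None fallback are assumptions):

import torch

def _build_lr_scheduler(self):
    # Sketch: return None when decay is disabled, otherwise pick the
    # torch scheduler matching the configured name.
    if not self.lr_decay:
        return None
    name = self.lr_scheduler_type.lower()
    if name == 'multisteplr':
        return torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.milestones, gamma=self.lr_decay_ratio)
    elif name == 'steplr':
        return torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.step_size, gamma=self.lr_decay_ratio)
    elif name == 'cosineannealinglr':
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.lr_T_max, eta_min=self.lr_eta_min)
    elif name == 'reducelronplateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=self.lr_patience,
            factor=self.lr_decay_ratio, threshold=self.lr_threshold)
    return None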