示例#1
0
    def run(self, model_path, model_indice, log_file, log_file_link):
        """Evaluate one or more checkpoints and append the results to a log.

        There are four evaluation modes, selected by ``model_indice``:
            1.only eval a .pth model: -e *.pth
            2.only eval a certain epoch: -e epoch
            3.eval all epochs in a given section (inclusive):
              -e start_epoch-end_epoch
            4.eval all epochs from a certain started epoch: -e start_epoch-

        In modes 3 and 4 the checkpoints are evaluated in ascending epoch
        order. Directory entries whose names do not parse as
        ``<prefix>-<epoch>.<ext>`` (for example ``epoch-last.pth``) are
        skipped.

        Args:
            model_path: directory holding the ``epoch-<N>.pth`` checkpoints
                (unused in mode 1).
            model_indice: the mode selector described above.
            log_file: path of the text log that results are appended to.
            log_file_link: second argument passed to ``link_file`` together
                with ``log_file`` (presumably a link pointing at the log).

        Raises:
            ValueError: if an epoch range has ``start_epoch > end_epoch``, or
                an epoch bound is not an integer.
        """
        models = self._select_models(model_path, model_indice)

        # ``with`` guarantees the log is closed even if evaluation raises.
        with open(log_file, 'a') as results:
            link_file(log_file, log_file_link)

            for model in models:
                logger.info("Load Model: %s" % model)
                self.val_func = load_model(self.network, model)

                result_line = self.multi_process_evaluation()

                results.write('Model: ' + model + '\n')
                results.write(result_line)
                results.write('\n')
                # Flush per model so finished results survive a later crash.
                results.flush()

    @staticmethod
    def _select_models(model_path, model_indice):
        """Translate ``model_indice`` into the list of checkpoint paths to evaluate."""
        # Mode 1: an explicit checkpoint file, used as given.
        if '.pth' in model_indice:
            return [model_indice]

        # Mode 2: a single epoch identifier.
        if '-' not in model_indice:
            return [os.path.join(model_path, 'epoch-%s.pth' % model_indice)]

        # Modes 3 and 4: "start-end" (inclusive) or open-ended "start-".
        start_text, end_text = model_indice.split('-')[:2]
        start_epoch = int(start_text)
        end_epoch = int(end_text) if end_text else None
        if end_epoch is not None and start_epoch > end_epoch:
            raise ValueError(
                'start epoch %d must not exceed end epoch %d'
                % (start_epoch, end_epoch))

        selected = []
        for name in os.listdir(model_path):
            # Parse "<prefix>-<epoch>.<ext>"; names that don't parse (such as
            # the "epoch-last.pth" link) are skipped instead of crashing.
            try:
                epoch = int(name.split('.')[0].split('-')[1])
            except (IndexError, ValueError):
                continue
            if epoch < start_epoch:
                continue
            if end_epoch is not None and epoch > end_epoch:
                continue
            selected.append((epoch, name))

        # os.listdir order is arbitrary; evaluate in ascending epoch order.
        selected.sort()
        return [os.path.join(model_path, name) for _, name in selected]
示例#2
0
 def save_and_link_checkpoint(self, snapshot_dir, log_dir, log_dir_link):
     """Save a checkpoint for the current epoch and re-point ``epoch-last.pth``.

     ``snapshot_dir`` is created via ``ensure_dir`` if needed. ``log_dir`` is
     linked to ``log_dir_link`` only when ``log_dir_link`` does not already
     exist. The checkpoint is written to
     ``<snapshot_dir>/epoch-<self.state.epoch>.pth`` and then linked to
     ``<snapshot_dir>/epoch-last.pth``.
     """
     ensure_dir(snapshot_dir)

     # Only create the log-directory link the first time through.
     link_missing = not osp.exists(log_dir_link)
     if link_missing:
         link_file(log_dir, log_dir_link)

     epoch_name = 'epoch-{}.pth'.format(self.state.epoch)
     epoch_ckpt = osp.join(snapshot_dir, epoch_name)
     self.save_checkpoint(epoch_ckpt)

     latest_ckpt = osp.join(snapshot_dir, 'epoch-last.pth')
     link_file(epoch_ckpt, latest_ckpt)
    def run(self, model_path, model_indice, log_file, log_file_link):
        """Evaluate a checkpoint and append the result to a log file.

        Args:
            model_path: directory containing ``epoch-<N>.pth`` checkpoints;
                only used when ``model_indice`` is not a ``.pth`` path.
            model_indice: either a path containing ``.pth`` (evaluated as
                given) or an epoch identifier, resolved to
                ``<model_path>/epoch-<model_indice>.pth``.
            log_file: path of the text log that results are appended to.
            log_file_link: second argument passed to ``link_file`` together
                with ``log_file`` (presumably a link pointing at the log).
        """
        if '.pth' in model_indice:
            models = [model_indice, ]
        else:
            models = [os.path.join(model_path,
                                   'epoch-%s.pth' % model_indice), ]

        # ``with`` closes the log even if loading or evaluation raises.
        with open(log_file, 'a') as results:
            link_file(log_file, log_file_link)

            for model in models:
                logger.info("Load Model: %s" % model)
                self.val_func = load_model(self.network, model)
                result_line = self.multi_process_evaluation()

                results.write('Model: ' + model + '\n')
                results.write(result_line)
                results.write('\n')
                # Flush after each model so results survive a later failure.
                results.flush()
示例#4
0
 def link_tb(self, source, target):
     """Ensure ``source`` and ``target`` exist, then link them.

     Each path is passed to ``ensure_dir`` (``source`` first, then
     ``target``), after which ``link_file(source, target)`` is called.
     """
     # NOTE(review): ``target`` already exists when ``link_file`` runs; whether
     # the link replaces it or is created inside it depends on ``link_file``.
     for path in (source, target):
         ensure_dir(path)
     link_file(source, target)