def create_app():
    """
    Build and configure the Flask web application.

    Reference: http://flask.pocoo.org/docs/1.0/quickstart/

    :return: the configured Flask application instance
    """
    flask_app = Flask(__name__)

    with flask_app.app_context():
        # Project-internal configuration: logger and MongoDB handle are
        # stored on the app config under fixed keys.
        api_logger = get_logger("Liuli API")
        mongo_base = MongodbManager.get_mongo_base(mongodb_config=Config.MONGODB_CONFIG)
        flask_app.config.update(
            app_config=Config,
            app_logger=api_logger,
            mongodb_base=mongo_base,
        )
        api_logger.info(f"server({Config.API_VERSION}) started successfully :)")

    # Register the blueprints, in the same order as before.
    for blueprint in (bp_api_v1, bp_rss, bp_backup):
        flask_app.register_blueprint(blueprint)

    # Initialise JWT support; the manager instance itself is not kept.
    flask_app.config["JWT_SECRET_KEY"] = Config.JWT_SECRET_KEY
    JWTManager(flask_app)

    return flask_app
def __init__(self, desc):
    """
    Load the configuration for the given role and set up logging, the
    server list and the finite-field parameters.

    :param desc: role identifier; must be 'party', 'client' or 'default'
    """
    assert desc in ('party', 'client', 'default')
    self.desc = desc

    # A dictionary of config parameters for this role
    self.config = get_default_argument(desc=self.desc)

    # Verifiable Secret Sharing
    self.vss = self.config['vss']

    # An empty project_dir in the config means "use BASE_DIR".
    if self.config['project_dir'] != "":
        self.project_dir = self.config['project_dir']
    else:
        self.project_dir = str(BASE_DIR)

    # Fall back to <parent of project_dir>/logs/log.txt when the configured
    # log path does not exist, and make sure its directory is there.
    self.project_log = self.config["project_log"]
    if not exists(self.project_log):
        self.project_log = join(os.path.dirname(self.project_dir), 'logs', 'log.txt')
        self.create_dir(os.path.dirname(self.project_log))

    # logger interface
    self.isDebug = self.config['debug']
    self.logger = get_logger(self.desc, self.project_log, self.isDebug)

    # Server addresses come from the YAML config file when one is given,
    # otherwise ten localhost ports 8000..8009 are used.
    if self.config['config'] is None:
        self.partyServers = [("localhost", port) for port in range(8000, 8010)]
    else:
        with open(self.config['config']) as config_file:
            import yaml
            config_content = yaml.safe_load(config_file)
        self.partyServers = [(server['host'], server['port'])
                             for server in config_content['servers']]

    self.max_nums_server = len(self.partyServers)
    if self.desc == 'client':
        self.nums_party = self.config['nums_party']
        assert self.nums_party <= self.max_nums_server
    else:
        self.nums_server = self.config['nums_server']
        assert self.nums_server <= self.max_nums_server

    self.init_rng(seed=0)
    warnings.filterwarnings('ignore')

    # NOTE(review): random_prime / GF look like SageMath helpers -- confirm.
    self.q = 2**12
    self.p = random_prime(self.q, proof=False, lbound=self.q - 3)
    self.g = self.largest_prime_factor(self.p - 1)
    self.zp = GF(self.p)  # Finite Field
def __init__(self, seed: int, logdir: pathlib.Path, datadir: pathlib.Path, set_name: str,
             device: torch.device, remote: bool = False, **kwargs):
    """
    Store the experiment settings and build the task path, the learning
    objects and the results plotter.

    :param seed: seed handed on to the task path and the learning objects
    :param logdir: directory for logs; the results HDF5 file is placed in it
    :param datadir: data directory handed on to the task path
    :param set_name: name stored on the instance as ``set_name``
    :param device: torch device handed on to the learning objects
    :param remote: flag handed on to the learning objects and the plotter
    :param kwargs: experiment configuration; must contain "task_path" and
        may contain "visualization"
    """
    self.config = kwargs
    self.seed = seed
    self.logdir: pathlib.Path = logdir
    self.datadir: pathlib.Path = datadir
    self.results_hdf5_path: pathlib.Path = self.logdir / "learning_results.hdf5"

    self.logger = get_logger("experiment")
    self.set_name = set_name

    # The configuration hash is computed and then saved via save_hash().
    self.hash = hash_dict(kwargs)
    self.save_hash()

    self.device = device

    self.task_path: TaskPath = TaskPath(self.seed, self.logdir, self.datadir,
                                        **kwargs["task_path"])

    # Both learning objects are constructed from the same keyword arguments.
    learner_kwargs = dict(
        seed=self.seed,
        logdir=self.logdir,
        device=self.device,
        results_hdf5_path=self.results_hdf5_path,
        remote=remote,
    )
    self.supervised_learning = SupervisedLearning(**learner_kwargs)
    self.inference = Inference(**learner_kwargs)

    # Optional plotting settings; an empty dict when none are configured.
    visualization_config = self.config.get("visualization", {})
    self.results_plotter = ResultsPlotter(
        results_hdf5_path=self.results_hdf5_path,
        logdir=self.logdir,
        remote=remote,
        **visualization_config)
def __init__(self, desc="default"):
    """
    Load the configuration for the given mode and set up logging, the
    data setting and the random number generator.

    :param desc: mode identifier; must be 'data', 'train', 'test',
        'generate' or 'default'
    """
    assert desc in ('data', 'train', 'test', 'generate', 'default')
    self.desc = desc

    # A dictionary of config parameters for this mode
    self.config = get_default_argument(desc=self.desc)

    # An empty project_dir in the config means "use BASE_DIR".
    if self.config['project_dir'] != "":
        self.project_dir = self.config['project_dir']
    else:
        self.project_dir = str(BASE_DIR)

    # Fall back to <parent of project_dir>/logs/log.txt when the configured
    # log path does not exist, and make sure its directory is there.
    self.project_log = self.config["project_log"]
    if not exists(self.project_log):
        self.project_log = join(os.path.dirname(self.project_dir), 'logs', 'log.txt')
        create_dir(os.path.dirname(self.project_log))

    # logger interface
    self.isDebug = self.config['debug']
    self.logger = get_logger(self.desc, self.project_log, self.isDebug)

    # The optional YAML config file is parsed, but the parsed content is not
    # used by the visible code.
    if self.config['config'] is not None:
        with open(self.config['config']) as config_file:
            import yaml
            config_content = yaml.safe_load(config_file)

    self.data = self.config["data"]

    init_rng(seed=0)
    warnings.filterwarnings('ignore')
import numpy as np
import pathlib
from PIL import Image
from scipy.spatial.transform import Rotation
import tifffile
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset as TorchDataset, Subset
import torchvision
from typing import *
import warnings

from src.enums.channel_enum import ChannelEnum
from src.utils.log import get_logger

# Module-level logger for this file.
logger = get_logger("base_dataset")


class BaseDataset(ABC):
    """
    Abstract base class for datasets.

    NOTE(review): only the beginning of the initializer is visible in this
    chunk; the rest of the class is not shown here.
    """

    def __init__(self, config: dict, dataset_path: pathlib.Path, purpose: Optional[str] = None,
                 transform: Optional[Callable] = None):
        """
        Store the dataset settings on the instance.

        :param config: dataset configuration dictionary
        :param dataset_path: path of the dataset
        :param purpose: optional purpose label; None when not given
        :param transform: optional callable; None when not given
        """
        self.config = config
        self.purpose = purpose
        self.dataset_path = dataset_path
        self.transform = transform

        # Unknown at construction time, so it starts out as None.
        # NOTE(review): presumably a minimum data value filled in later -- confirm.
        self.min: Optional[float] = None
from src.learning.learning_classes.base_learning import BaseLearning
from src.learning.tasks import Task
from src.utils.log import get_logger

# Module-level logger handed to the base class.
logger = get_logger("inference")


class Inference(BaseLearning):
    """Learning class that runs inference with the model of a task."""

    def __init__(self, **kwargs):
        """
        Initialise the base class with this module's logger.

        :param kwargs: forwarded unchanged to ``BaseLearning.__init__``
        """
        super().__init__(logger=logger, **kwargs)

    def run(self, task: Task):
        """
        Run inference for the given task.

        :param task: task whose ``model_to_infer`` is used
        :return: whatever ``self.infer()`` returns
        """
        self.task = task
        # The model is set without picking an optimizer.
        model_to_infer = task.model_to_infer
        self.set_model(model_to_infer, pick_optimizer=False)
        return self.infer()
import numpy as np
import pathlib
from pytorch_msssim import ssim
from typing import *
import warnings

import torch
from torch.nn import functional as F
from torch.utils.data import Dataset
import torchvision

from src.dataloaders.dataloader_meta_info import DataloaderMetaInfo
from src.enums import *
from src.utils.log import get_logger

# Module-level logger for this file.
logger = get_logger("loss")


class Loss(ABC):
    """
    Abstract base class that keeps per-batch and per-epoch loss bookkeeping.

    NOTE(review): only the beginning of the initializer is visible in this
    chunk; the rest of the class is not shown here.
    """

    def __init__(self, logdir: pathlib.Path, **kwargs):
        """
        Store the settings and create the empty bookkeeping containers.

        :param logdir: log directory stored on the instance
        :param kwargs: loss configuration; must contain "report_frequency"
        :raises KeyError: if "report_frequency" is missing from ``kwargs``
        """
        self.config = kwargs
        self.logdir = logdir

        self.report_frequency: int = self.config["report_frequency"]

        self.batch_results = []  # must be reset at the beginning of each epoch
        self.batch_sizes = []
        # One (initially empty) dictionary per purpose.
        self.epoch_losses = {"train": {}, "val": {}, "test": {}}

        # Not known at construction time, so both start out as None.
        self.purpose: Optional[str] = None
        self.epoch: Optional[int] = None
from progress.bar import Bar
import torch

from src.dataloaders.dataloader_meta_info import DataloaderMetaInfo
from src.enums import *
from src.learning.learning_classes.base_learning import BaseLearning
from src.learning.tasks import Task
from src.utils.log import get_logger

# Module-level logger handed to the base class.
logger = get_logger("supervised_learning")


class SupervisedLearning(BaseLearning):
    """Learning class that trains the model of a task on its labeled data."""

    def __init__(self, **kwargs):
        """
        Initialise the base class with this module's logger.

        :param kwargs: forwarded unchanged to ``BaseLearning.__init__``
        """
        super().__init__(logger=logger, **kwargs)

    def train(self, task: Task):
        """
        Train on the given task.

        :param task: task whose ``model_to_train`` is used
        :return: whatever ``self.train_epochs()`` returns
        """
        self.set_task(task)
        self.set_model(task.model_to_train)
        return self.train_epochs()

    def train_epoch(self, epoch) -> None:
        """
        Run one pass over the task's 'train' dataloader.

        NOTE(review): the body of the batch loop continues beyond what is
        visible in this chunk.

        :param epoch: epoch number, used for the loss context and the
            progress bar title
        """
        self.model.train()

        dataloader = self.task.labeled_dataloader.dataloaders['train']
        dataloader_meta_info = DataloaderMetaInfo(dataloader)
        # The loss object's new_epoch(...) context wraps the whole epoch.
        with self.task.loss.new_epoch(epoch, "train", dataloader_meta_info=dataloader_meta_info):
            progress_bar = Bar(f"Train epoch {epoch} of task {self.task.uid}", max=len(dataloader))
            for batch_idx, data in enumerate(dataloader):
                self.optimizer.zero_grad()
import pathlib
import csv
import copy
from typing import Dict

import torch

from src.utils.log import get_logger

# Module-level logger for this file.
logger = get_logger("controller")


class Controller:
    def __init__(self, **kwargs):
        """
        The Controller manages when training should be terminated.

        Its states contain model dicts and validation losses for each epoch.
        It will stop the iteration if the training has converged - this occurs
        for example when the validation loss has not improved for the last
        'max_num_better_results'. It will also stop the iteration if the
        maximum number of epochs has been reached. Finally it can return the
        best state after the iteration to reset the model to the state in
        which it achieved the best validation loss.

        We include a boolean option 'epoch_stop' that bases the learning
        stopping only on the number of epochs and returns the model of the
        last trained epoch instead of the model with the best validation loss.

        :param kwargs: may contain "max_num_epochs", "max_num_better_results"
            and "epoch_stop" (defaults to False)
        """

        # If "get()" does not find keyword, the value is None
        self.max_num_epochs = kwargs.get("max_num_epochs")
        self.max_num_better_results = kwargs.get("max_num_better_results")
        self.epoch_stop: bool = kwargs.get("epoch_stop", False)

        # At least one of these configurations must hold:
        #   - epoch_stop is True, max_num_epochs is given and
        #     max_num_better_results is not given
        #   - epoch_stop is False and max_num_better_results is given
        #   - epoch_stop is False and max_num_epochs is given
        assert (self.epoch_stop is True and self.max_num_epochs is not None and self.max_num_better_results is None) \
               or (self.max_num_better_results is not None and self.epoch_stop is False) \
               or (self.max_num_epochs is not None and self.epoch_stop is False)
import pandas as pd
import pathlib
from progress.bar import Bar
import seaborn as sns
import torch
from typing import *
import warnings

from .sample_plotter import draw_error_uncertainty_plot, draw_solutions_plot, draw_traversability_plot, \
    draw_qualitative_comparison_plot, draw_occ_mask_plot
from src.enums import *
from src.learning.loss.loss import masked_loss_fct, mse_loss_fct, l1_loss_fct, psnr_loss_fct
from src.utils.log import get_logger
from src.visualization.live_inference_plotter import plot_live_inference

# Module-level logger for this file.
logger = get_logger("results_plotter")

# Applied once at import time: seaborn "whitegrid" style for all plots.
sns.set(style="whitegrid")


class ResultsPlotter:
    """
    Plotter configured with a results HDF5 path and a log directory.

    NOTE(review): only the beginning of the initializer is visible in this
    chunk; the rest of the class is not shown here.
    """

    def __init__(self, results_hdf5_path: pathlib.Path, logdir: pathlib.Path, remote: bool = False, **kwargs):
        """
        Store the plotter settings on the instance.

        :param results_hdf5_path: path of the results HDF5 file
        :param logdir: log directory
        :param remote: flag stored as ``self.remote``; defaults to False
        :param kwargs: additional settings, stored as ``self.config``
        """
        self.config = kwargs
        self.logdir = logdir
        self.results_hdf5_path = results_hdf5_path
        self.remote = remote
import torch.autograd.profiler as profiler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR, ReduceLROnPlateau, MultiplicativeLR

from src.dataloaders.dataloader_meta_info import DataloaderMetaInfo
from src.enums import *
from src.learning.controller.controller import Controller
from src.learning.loss.loss import Loss
from src.learning.models import pick_model
from src.learning.models.baseline.base_baseline_model import BaseBaselineModel
from src.learning.models.baseline.lsq_plane_fit_baseline import LsqPlaneFitBaseline
from src.learning.models.unet.unet_parts import VGG16FeatureExtractor
from src.learning.tasks import Task
from src.traversability.traversability_assessment import TraversabilityAssessment
from src.utils.log import get_logger

# Module-level logger for this file.
logger = get_logger("base_learning")


class BaseLearning(ABC):
    """
    Abstract base class for the learning classes.

    NOTE(review): only the beginning of the initializer is visible in this
    chunk; the rest of the class is not shown here.
    """

    def __init__(self, seed: int, logdir: pathlib.Path, device: torch.device, logger: logging.Logger,
                 results_hdf5_path: pathlib.Path, remote: bool = False, **kwargs):
        """
        Store the learning settings on the instance.

        Note that the ``logger`` parameter shadows the module-level
        ``logger`` inside this method.

        :param seed: seed stored as ``self.seed``
        :param logdir: log directory stored as ``self.logdir``
        :param device: torch device
        :param logger: logger instance supplied by the caller
        :param results_hdf5_path: path of the results HDF5 file
        :param remote: defaults to False
        :param kwargs: additional settings
        """
        super().__init__()

        self.seed = seed
        self.logdir = logdir
import json
import pathlib
from typing import *

import torch

from src.enums.task_type_enum import TaskTypeEnum
from src.dataloaders.dataloader import Dataloader
from src.learning.loss.loss import Loss
from src.utils.log import get_logger

# Module-level logger for this file.
logger = get_logger("task")


class Task:
    """
    A single task, identified by a uid and described by its configuration.

    NOTE(review): only the beginning of the initializer is visible in this
    chunk; the rest of the class is not shown here.
    """

    def __init__(self, uid: int, logdir: pathlib.Path, **kwargs):
        """
        Store the task settings and create the empty loss/dataloader slots.

        :param uid: identifier of the task
        :param logdir: log directory stored on the instance
        :param kwargs: task configuration; must contain "task_type"
        :raises KeyError: if "task_type" is missing from ``kwargs``
        """
        self.uid: int = uid
        self.type = TaskTypeEnum(kwargs["task_type"])
        # The configuration is stored once (the original assigned it twice).
        self.config = kwargs
        self.logdir: pathlib.Path = logdir

        self.loss: Optional[Loss] = None

        # The task name is the JSON dump of its configuration.
        self.name = json.dumps(self.config)

        # Dataloaders are not available at construction time.
        self.labeled_dataloader: Optional[Dataloader] = None
        self.unlabeled_dataloader: Optional[Dataloader] = None
        self.inference_dataloader: Optional[Dataloader] = None
from copy import deepcopy
import pathlib
from typing import Dict, List, Optional

import torch

from src.enums.task_type_enum import TaskTypeEnum
from src.dataloaders.dataloader import Dataloader
from .task import Task
from src.utils.log import get_logger

# Module-level logger for this file.
logger = get_logger("task_path")


class TaskPath:
    """
    Task paths are iterables that return a task and an increasing uid at each
    iteration. The TaskPath also measures the runtime of each of its tasks by
    logging the time between calls to itself (the task iterator).

    NOTE(review): only the beginning of the initializer is visible in this
    chunk; the iteration logic is not shown here.
    """

    def __init__(self, seed: int, logdir: pathlib.Path, datadir: pathlib.Path, **kwargs):
        """
        Store the path settings and reset the iteration index.

        :param seed: seed stored as ``self.seed``
        :param logdir: log directory
        :param datadir: data directory
        :param kwargs: task path configuration; must contain "tasks" and
            "defaults"
        :raises KeyError: if "tasks" or "defaults" is missing from ``kwargs``
        """
        self.config = kwargs
        self.seed = seed
        self.task_configs = kwargs["tasks"]
        self.default_values = kwargs["defaults"]

        # Iteration index; starts at zero.
        self.idx: int = 0

        self.logdir: pathlib.Path = logdir
        self.datadir: pathlib.Path = datadir
from dash import html import h5py import matplotlib.pyplot as plt import numpy as np import pathlib import plotly import plotly.graph_objects as go import random import time import torch from typing import * from src.enums import * from src.utils.log import get_logger logger = get_logger("live_inference_plotter") def plot_live_inference(purpose_hdf5_group: h5py.Group): data_hdf5_group = purpose_hdf5_group["data"] occ_dem_dataset = data_hdf5_group[ChannelEnum.OCC_DEM.value] comp_dem_dataset = data_hdf5_group[ChannelEnum.COMP_DEM.value] occ_dems = np.array(occ_dem_dataset) comp_dems = np.array(comp_dem_dataset) app = dash.Dash(__name__) app.layout = html.Div([ dcc.Graph(id='live-graph', animate=True), dcc.Interval(id='graph-update', interval=500,