def update_project(self, session, name, data: dict):
    """Apply attribute updates to an existing project and persist it.

    Args:
        session: Database session passed to ``get_project`` and ``_upsert``.
        name: Name of the project to update.
        data: Mapping of attribute name to new value. A ``"users"`` entry is
            not treated as an attribute (see the note below). The caller's
            dict is never mutated.

    Raises:
        DBError: If no project called ``name`` exists, or if ``data`` names
            an attribute the project does not have. In the latter case the
            project object is left untouched.
    """
    prj = self.get_project(session, name)
    if not prj:
        raise DBError(f"unknown project - {name}")

    # Work on a filtered copy so the caller's dict stays intact and "users"
    # is not mistaken for a plain attribute.
    attrs = {key: value for key, value in data.items() if key != "users"}

    # Validate every key up front: a bad key must not leave the project
    # half-updated in the session.
    for key in attrs:
        if not hasattr(prj, key):
            raise DBError(f"unknown project attribute - {key}")
    for key, value in attrs.items():
        setattr(prj, key, value)

    # NOTE: resolving the requested user names (via _find_or_create_users)
    # is currently disabled, so the project's user list is always reset to
    # empty regardless of what ``data["users"]`` contained.
    prj.users.clear()
    self._upsert(session, prj, ignore=True)
def _upsert(session, obj, ignore=False):
    """Add ``obj`` to ``session`` and commit, rolling back on failure.

    Args:
        session: Session providing ``add``, ``commit`` and ``rollback``.
        obj: Object to add to the session.
        ignore: When true, a failed commit is only logged; otherwise it is
            re-raised as ``DBError``.

    Raises:
        DBError: If adding/committing raises ``SQLAlchemyError`` and
            ``ignore`` is false.
    """
    try:
        session.add(obj)
        session.commit()
    except SQLAlchemyError as err:
        # Undo the failed transaction before reporting it.
        session.rollback()
        cls = obj.__class__.__name__
        logger.warning(f"conflict adding {cls}, {err}")
        if ignore:
            return
        raise DBError(f"duplicate {cls} - {err}") from err
def _find_or_create_users(self, session, user_names):
    """Return ``User`` rows for ``user_names``, creating the missing ones.

    Args:
        session: Database session used for the lookup and for adding and
            committing new users.
        user_names: Names to resolve. Duplicates are collapsed; an empty or
            ``None`` value yields an empty list without touching the
            database.

    Returns:
        A list with the users found by the query followed by the newly
        created ones, the latter in first-seen order of ``user_names``.

    Raises:
        DBError: If committing the new users raises ``SQLAlchemyError``.
    """
    # Nothing to resolve: skip the pointless empty IN () query.
    if not user_names:
        return []

    users = list(self._query(session, User).filter(User.name.in_(user_names)))
    known = {user.name for user in users}
    # dict.fromkeys de-duplicates while preserving the caller's order, so
    # new users are created deterministically.
    missing = [name for name in dict.fromkeys(user_names) if name not in known]
    if not missing:
        return users

    for name in missing:
        user = User(name=name)
        session.add(user)
        users.append(user)
    try:
        session.commit()
    except SQLAlchemyError as err:
        session.rollback()
        raise DBError(f"add user: {err}") from err
    return users
def read_artifact(self, session, key, tag="", iter=None, project=""):
    """Return the stored struct of a single artifact.

    Args:
        session: Database session used for the queries.
        key: Artifact key. When ``iter`` is truthy the lookup key becomes
            ``"<iter>-<key>"``.
        tag: Tag passed to ``_resolve_tag`` to obtain a uid.
        iter: Optional iteration prefix for the key.
        project: Project name; falls back to ``config.default_project``
            when falsy.

    Returns:
        The ``struct`` attribute of the matching artifact.

    Raises:
        DBError: If no artifact matches.
    """
    if not project:
        project = config.default_project
    uid = self._resolve_tag(session, Artifact, project, tag)
    if iter:
        key = f"{iter}-{key}"

    if uid:
        criterion = Artifact.uid == uid
    else:
        # No uid for the tag: pick the most recently updated artifact
        # for this project/key instead.
        latest_update = session.query(func.max(Artifact.updated)).filter(
            Artifact.project == project, Artifact.key == key
        )
        criterion = Artifact.updated.in_(latest_update)

    base_query = self._query(session, Artifact, key=key, project=project)
    art = base_query.filter(criterion).one_or_none()
    if art:
        return art.struct
    raise DBError(f"Artifact {key}:{tag}:{project} not found")
def update_run(self, session, updates: dict, uid, project="", iter=0):
    """Apply ``updates`` to a stored run and refresh its derived columns.

    Args:
        session: Database session used to load, merge and commit the run.
        updates: Mapping applied to the run struct, one ``update_in`` call
            per item (keys are presumably nested paths - confirm against
            ``update_in``).
        uid: Run uid.
        project: Project name; falls back to ``config.default_project``
            when falsy.
        iter: Run iteration passed to ``_get_run``.

    Raises:
        DBError: If the run does not exist.
    """
    if not project:
        project = config.default_project
    run = self._get_run(session, uid, project, iter)
    if not run:
        raise DBError(f"run {uid}:{project} not found")

    body = run.struct
    for path, new_value in updates.items():
        update_in(body, path, new_value)
    run.struct = body

    # Only overwrite state / start time when the struct yields a value.
    state = run_state(body)
    if state:
        run.state = state
    started = run_start_time(body)
    if started:
        run.start_time = started

    # Rebuild the label rows from scratch out of the updated struct.
    run.labels.clear()
    for label_name, label_value in run_labels(body).items():
        run.labels.append(
            Run.Label(name=label_name, value=label_value, parent=run.id)
        )

    session.merge(run)
    session.commit()
    self._delete_empty_labels(session, Run.Label)
def _transform_run_db_error(func, *args, **kwargs):
    """Call ``func`` and translate ``RunDBError`` into ``DBError``.

    Args:
        func: Callable to invoke.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        DBError: If ``func`` raises ``RunDBError``. The new error carries
            the same args and has the original as its ``__cause__``.
    """
    try:
        return func(*args, **kwargs)
    except RunDBError as exc:
        # Unpack the args so the message is not rendered as a tuple repr,
        # and chain explicitly so the original traceback is preserved.
        raise DBError(*exc.args) from exc
def read_run(self, session, uid, project=None, iter=0):
    """Return the stored struct of a run.

    Args:
        session: Database session passed to ``_get_run``.
        uid: Run uid.
        project: Project name; falls back to ``config.default_project``
            when falsy.
        iter: Run iteration passed to ``_get_run``.

    Returns:
        The ``struct`` attribute of the matching run.

    Raises:
        DBError: If ``_get_run`` finds no run.
    """
    if not project:
        project = config.default_project
    found = self._get_run(session, uid, project, iter)
    if found:
        return found.struct
    raise DBError(f"Run {uid}:{project} not found")