Example #1
    def __init__(self, github_obj, repo_full_name, cache_filename):
        """
        Initialize cache and reference github repository issues
        """

        self.github = github_obj
        self.repo_full_name = repo_full_name
        self.shelf = shelve.open(filename=cache_filename,
                                 protocol=self.protocol,
                                 writeback=True)

        # Avoid exceeding rate-limit per hour
        requests = self.github.rate_limiting[1]  # requests per hour
        period = 60.0 * 60.0  # one hour in seconds
        sleeptime = period / requests
        self.pre_fetch_partial = Partial(time.sleep, sleeptime)
        # self.pre_fetch_partial = None # cheat-mode enable (no delays)

        repo_cache_key = 'repo_%s' % self.repo_full_name
        # get_repo called same way throughout instance life
        cache_get_partial = Partial(self.shelf.__getitem__, repo_cache_key)
        cache_set_partial = Partial(self.shelf.__setitem__, repo_cache_key)
        cache_del_partial = Partial(self.shelf.__delitem__, repo_cache_key)
        fetch_partial = Partial(self.github.get_repo, self.repo_full_name)
        # Callable instance retrieves cached or fetched value for key
        self.get_repo = self.cache_class(self.github, cache_get_partial,
                                         cache_set_partial, cache_del_partial,
                                         self.pre_fetch_partial, fetch_partial)
        super(GithubIssuesBase, self).__init__()
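# The cache_class callable used above is not shown in this example; what
# follows is a minimal sketch of the pattern, assuming `Partial` is an alias
# for functools.partial and that cache_class behaves roughly like this
# hypothetical CachedFetch (a sketch, not the project's implementation):
from functools import partial as Partial

class CachedFetch(object):
    """Callable that returns a cached value, fetching it on a miss."""

    def __init__(self, cache_get, cache_set, pre_fetch, fetch):
        self.cache_get = cache_get   # e.g. Partial(shelf.__getitem__, key)
        self.cache_set = cache_set   # e.g. Partial(shelf.__setitem__, key)
        self.pre_fetch = pre_fetch   # e.g. Partial(time.sleep, sleeptime)
        self.fetch = fetch           # e.g. Partial(github.get_repo, name)

    def __call__(self):
        try:
            return self.cache_get()          # cache hit
        except KeyError:
            if self.pre_fetch is not None:
                self.pre_fetch()             # honor the rate limit
            value = self.fetch()             # cache miss: query the API
            self.cache_set(value)
            return value

# Usage mirroring the get_repo wiring above, with a plain dict as the cache:
cache = {}
get_answer = CachedFetch(Partial(cache.__getitem__, 'k'),
                         Partial(cache.__setitem__, 'k'),
                         None,
                         Partial(int, '42'))
assert get_answer() == 42   # fetched
assert get_answer() == 42   # served from cache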
Example #2
    def gh_pr_commit_authors(self, gh_issue):
        """
        Return the set of commit author e-mails (or logins) for a pull-request
        """

        if GithubIssues.gh_issue_is_pull(gh_issue):
            num = gh_issue.number
            repo = self.get_repo()
            cache_key = 'repo_%s_pull_%s' % (self.repo_full_name, str(num))
            fetch_partial = Partial(repo.get_pull, num)
            pull = self.get_gh_obj(cache_key, fetch_partial)
            if pull.commits is None or pull.commits < 1:
                return None  # No commits == no commit authors

            cache_key = 'repo_%s_pull_%s_commits' % (self.repo_full_name,
                                                     str(num))
            fetch_partial = Partial(pull.get_commits)
            authors = set()
            for commit in self.get_gh_obj(cache_key, fetch_partial):
                # Referencing commit author requires a request, cache it.
                author_cache_key = cache_key + '_%s_author' % str(commit.sha)
                author_fetch_partial = Partial(getattr, commit, 'author')
                try:
                    author_obj = self.get_gh_obj(author_cache_key,
                                                 author_fetch_partial)
                except:
                    # clean up commit list cache entry also
                    self.clean_cache_entry(cache_key)
                    raise  # original exception
                # Retrieve e-mail from git commit object
                if author_obj is None:
                    # Referencing git commit requires a request, cache it
                    gitcommit_cache_key = (cache_key +
                                           '_%s_gitcommit' % str(commit.sha))
                    gitcommit_fetch_partial = Partial(getattr, commit,
                                                      'commit')  # git commit
                    try:
                        gitcommit = self.get_gh_obj(gitcommit_cache_key,
                                                    gitcommit_fetch_partial)
                    except:
                        # Need to clean commit and gitcommit entries
                        self.clean_cache_entry(cache_key)
                        self.clean_cache_entry(gitcommit_cache_key)
                        raise
                    authors.add(gitcommit.author.email)
                else:  # Author is a github user
                    authors.add(author_obj.login)
            return authors
        return None  # not a pull request
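# `Partial(getattr, commit, 'author')` above defers an attribute access that,
# in PyGithub, can trigger an HTTP request on first use; wrapping it as a
# callable lets the caching layer decide whether to pay that cost. A
# self-contained sketch of the idea with a hypothetical lazy object:
from functools import partial as Partial

class LazyCommit(object):
    """Hypothetical object whose attribute access is expensive."""

    @property
    def author(self):
        print('expensive request happens here')
        return 'octocat'

author_fetch_partial = Partial(getattr, LazyCommit(), 'author')
author = author_fetch_partial()   # the request happens only when called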
Example #3
 def get_gh_obj(self, cache_key, fetch_partial):
     """
     Helper to get object possibly from cache and update counters
     """
     cache_get_partial = Partial(self.shelf.__getitem__, cache_key)
     cache_set_partial = Partial(self.shelf.__setitem__, cache_key)
     cache_del_partial = Partial(self.shelf.__delitem__, cache_key)
     # Callable instance could change every time
     get_obj = GithubCache(self.github, cache_get_partial,
                           cache_set_partial, cache_del_partial,
                           self.pre_fetch_partial, fetch_partial)
     result = get_obj()
     self._cache_hits += get_obj.cache_hits
     self._cache_misses += get_obj.cache_misses
     return result  # DOES NOT SYNC DATA!
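# get_gh_obj() builds a fresh caching callable per call ("could change every
# time") and folds its hit/miss counters into the instance totals. A hedged
# sketch of that accounting, with a hypothetical stand-in for GithubCache:
class CountingCache(object):
    """Hypothetical cache callable exposing hit/miss counters."""

    def __init__(self, store, key, fetch):
        self.store, self.key, self.fetch = store, key, fetch
        self.cache_hits = 0
        self.cache_misses = 0

    def __call__(self):
        try:
            value = self.store[self.key]
            self.cache_hits += 1
        except KeyError:
            value = self.store[self.key] = self.fetch()
            self.cache_misses += 1
        return value

store, hits, misses = {}, 0, 0
for _ in range(2):
    get_obj = CountingCache(store, 'answer', lambda: 42)
    result = get_obj()
    hits += get_obj.cache_hits
    misses += get_obj.cache_misses
assert (hits, misses) == (1, 1)   # one miss (fetch), then one hit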
Example #4
 def get_gh_user(self, login):
     cache_key = 'github_user_%s' % login
     fetch_partial = Partial(self.github.get_user, login)
     try:
         return self.get_gh_obj(cache_key, fetch_partial)
     except KeyError:
         raise ValueError('login %s is not a valid github user' % login)
Example #5
 def get_gh_label(self, name):
     repo = self.get_repo()
     cache_key = str('repo_%s_label_%s' % (self.repo_full_name, name))
     fetch_partial = Partial(repo.get_label, name)
     try:
         return self.get_gh_obj(cache_key, fetch_partial)
     except KeyError:
         raise ValueError('label %s is not valid for repo %s' %
                          (name, self.repo_full_name))
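# Examples #4 and #5 share a pattern: the KeyError coming out of the cache
# layer (presumably raised when neither the cache nor the fetch can produce
# the object) is translated into a domain-level ValueError with a readable
# message. The same pattern in miniature, with a hypothetical lookup table:
def lookup(table, key, what):
    """Translate a low-level KeyError into a caller-friendly ValueError."""
    try:
        return table[key]
    except KeyError:
        raise ValueError('%s %s is not valid' % (what, key))

labels = {'bug': 1}
assert lookup(labels, 'bug', 'label') == 1
# lookup(labels, 'nope', 'label') -> ValueError: label nope is not valid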
Example #6
    def Refresh(self, extraArg=None):
        '''
        Refreshes the animations UI
        '''
        # Process select all checkbox value
        selected = False
        if Cmds.checkBox(self.selectAll, q=True, exists=True):
            selected = Cmds.checkBox(self.selectAll, q=True, value=True)

        if selected:
            for animation in self.sequencer.Animations.values():
                if not animation.Selected:
                    selected = False
                    break

        # Clear all UI items
        for animationId in self.AnimationUIs:
            self.AnimationUIs[animationId].Destroy()
        self.AnimationUIs = {}

        # Animations Layout
        Cmds.setParent(self.windowLayout)
        if Cmds.columnLayout(self.animationsLayout, exists=True):
            Cmds.deleteUI(self.animationsLayout)

        Cmds.columnLayout(self.animationsLayout, adjustableColumn=True)

        self.CreateSeparator(self.animationsLayout, 'out')

        # Header row
        Cmds.rowLayout(numberOfColumns=4,
                       columnWidth4=[35, 185, 50, 50],
                       columnAttach=[1, 'left', 8])
        Cmds.checkBox(self.selectAll,
                      label='',
                      cc=Partial(self.SelectAll),
                      value=selected)
        Cmds.text(label=' Animation Name')
        Cmds.text(label=' Start')
        Cmds.text(label=' End')

        self.CreateSeparator(self.animationsLayout, 'out')

        # Add entries back based on the stored ordering
        for animationId in self.sequencer.Ordering:
            self.CreateAnimationEntry(self.sequencer.Animations[animationId])

        # Update
        self.Update()

        # Save
        self.Save()
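# `Cmds` is presumably maya.cmds, whose checkbox/button callbacks are invoked
# with an extra state argument; that is why Refresh() accepts extraArg=None
# and why the handlers are wrapped in Partial. Outside Maya, the same wiring
# can be sketched like this (Button is a hypothetical stand-in widget):
from functools import partial as Partial

class Button(object):
    """Hypothetical stand-in for a Cmds.button-style widget."""

    def __init__(self, command):
        self.command = command

    def click(self):
        self.command(True)   # UI toolkits pass a state argument

def on_refresh(extra_arg=None):
    print('refreshed, extra arg: %r' % (extra_arg,))

Button(Partial(on_refresh)).click()   # prints: refreshed, extra arg: True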
Example #7
 def gh_pr_commits(self, gh_issue):
     """
     Retrieves the number of commits on a pull-request, None if not a pull.
     """
     if GithubIssues.gh_issue_is_pull(gh_issue):
         num = gh_issue.number
         repo = self.get_repo()
         cache_key = 'repo_%s_pull_%s' % (self.repo_full_name, str(num))
         fetch_partial = Partial(repo.get_pull, num)
         pull = self.get_gh_obj(cache_key, fetch_partial)
         return pull.commits
     return None  # not a pull request
Example #8
 def __getitem__(self, key):
     """
     Return a standardized dict of github issue unless NoEnumerate=True
     """
     repo = self.get_repo()
     # Enforce uniform key string
     cache_key = self.get_issue_cache_key(key)
     fetch_partial = Partial(repo.get_issue, int(key))
     item = self.get_gh_obj(cache_key, fetch_partial)
     # No exception raised, update cache on disk
     self.shelf.sync()
     return item
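# The explicit shelf.sync() matters because the shelf was opened with
# writeback=True (Example #1): mutations accumulate in memory and only reach
# disk on sync() or close(), which is also why Example #3 warns that it
# "DOES NOT SYNC DATA". A minimal demonstration with a throwaway shelf:
import os
import shelve
import tempfile

path = os.path.join(tempfile.mkdtemp(), 'demo_cache')
shelf = shelve.open(path, writeback=True)
shelf['key'] = {'hits': 0}
shelf['key']['hits'] += 1   # mutation lives only in the writeback cache...
shelf.sync()                # ...until sync() persists it to disk
shelf.close()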
Example #9
 def gh_issue_comment_authors(self, gh_issue):
     """
      Return the set of comment author e-mail addresses, or None if there
      are no comments
     """
     if gh_issue.comments > 0:
         num = gh_issue.number
         cache_key = ('repo_%s_issue_%s_comments' %
                      (self.repo_full_name, num))
         fetch_partial = Partial(gh_issue.get_comments)
         authors = set()
         for comment in self.get_gh_obj(cache_key, fetch_partial):
             # Referencing user attribute requires a request, so cache it
             user_cache_key = cache_key + '_%s_user' % comment.id
             user_fetch_partial = Partial(getattr, comment, 'user')
             try:
                 user = self.get_gh_obj(user_cache_key, user_fetch_partial)
             except:
                 # Also clean up comments cache
                 self.clean_cache_entry(cache_key)
                 raise  # original exception
             authors.add(user.email)
         return authors
     else:
         return None
Example #10
def multiproc_letter_body(openings, closings, src_dirs, dst_file, processes=1):

    load_f = Partial(letters_body_to_file,
                     openings=openings,
                     closings=closings,
                     dst_file=dst_file)

    file_paths = []
    for letters_dir in src_dirs:
        file_paths.extend(
            [os.path.join(letters_dir, fn) for fn in os.listdir(letters_dir)])

    def chunks(file_paths, n):
        """Yield the contents of successive n-sized chunks of file_paths."""
        for i in range(0, len(file_paths), n):
            contents = []
            for path in file_paths[i:i + n]:
                with open(path) as letter_file:   # close handles promptly
                    contents.append(letter_file.read())
            yield contents

    # guard against a zero chunk size when processes > len(file_paths)
    chunk_size = max(1, len(file_paths) // processes)
    with Pool(processes=processes) as procs:

        procs.map(func=load_f,
                  iterable=chunks(file_paths=file_paths, n=chunk_size))
Example #11
    def __init__(self,
                 model: Union[nn.Module, Dict[str, nn.Module]],
                 optimizer: Optional[Union[Partial, Optimizer,
                                           Dict[str, Optimizer]]],
                 loss_f: Optional[Union[Callable, Dict[str, Callable]]],
                 *,
                 callbacks: Optional[Iterable[Callback]] = None,
                 scheduler: Optional[Union[Partial, Scheduler,
                                           Dict[str, Scheduler]]] = None,
                 update_scheduler_by_epoch: bool = True,
                 device: Optional[Union[torch.device, str]] = None,
                 verb: bool = True,
                 use_cudnn_benchmark: bool = True,
                 use_cuda_nonblocking: bool = False,
                 use_horovod: bool = False,
                 logger=None,
                 use_sync_bn: bool = False,
                 tqdm_ncols: int = 80,
                 **kwargs):

        if logger is None:
            logger = get_logger(__name__)
        self.logger = logger

        if device is None:
            self.device = torch.device(
                GPU) if torch.cuda.is_available() else torch.device(CPU)
        else:
            self.device = device

        if use_horovod and not is_horovod_available():
            raise RuntimeError('horovod is not available!')

        if is_distributed():
            if use_sync_bn:
                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

            rank = get_local_rank()
            torch.cuda.set_device(rank)
            if get_global_rank() > 0:
                # to avoid overwriting
                verb = False

        if isinstance(model, nn.Module):
            self.model = model
        elif isinstance(model, dict):
            self.model = nn.ModuleDict(model)
        else:
            raise TypeError(
                f"Unknown type for `model`. Expected nn.Module or Dict[str, Module] but got {type(model)}"
            )

        if GPU in str(self.device):
            self.model.to(self.device)
            torch.backends.cudnn.benchmark = use_cudnn_benchmark
            self._cuda_nonblocking = use_cuda_nonblocking
            self.logger.debug(
                f"cuda: True, cudnn.benchmark: {use_cudnn_benchmark}, "
                f"cuda.nonblocking: {use_cuda_nonblocking}")
        else:
            self._cuda_nonblocking = False
            # usually, this is not expected
            self.logger.info(
                f"cuda: False (torch.cuda.is_available()={torch.cuda.is_available()})"
            )

        if not use_horovod and is_distributed():
            self.model = nn.parallel.DistributedDataParallel(self.model,
                                                             device_ids=[rank])

        if isinstance(self.model, (nn.parallel.DistributedDataParallel,
                                   nn.DataParallel)):
            self.accessible_model = self.model.module
        else:
            self.accessible_model = self.model

        self.optimizer = None
        self.scheduler = None
        self._callbacks = None
        self.update_scheduler_by_epoch = update_scheduler_by_epoch
        self._set_optimizer(optimizer)
        self._set_scheduler(scheduler)
        self._set_callbacks(callbacks)

        if use_horovod:
            import horovod.torch as hvd

            hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(self.optimizer, root_rank=0)
            self.optimizer = hvd.DistributedOptimizer(
                self.optimizer, named_parameters=self.model.named_parameters())

        self.loss_f = loss_f
        self._verb = verb

        # called via property
        # _step and _epoch are set to -1 because they are incremented before each iteration and epoch!
        self._step = -1
        self._epoch = -1
        self._is_train = True
        # to nest, leave=False (https://github.com/tqdm/tqdm/blob/master/examples/simple_examples.py#L19)
        self._tqdm = Partial(tqdm, ncols=tqdm_ncols,
                             leave=False) if verb else lambda x: x

        _map_base = {
            MODEL: self.accessible_model,
            OPTIMIZER: self.optimizer,
            SCHEDULER: self.scheduler,
            TRAINER: self
        }
        self._iteration_map = Map(**_map_base.copy())
        self._epoch_map = Map(**_map_base.copy())
        self._all_map = Map(**_map_base.copy())

        for k, v in kwargs.items():
            if hasattr(self, k):
                raise AttributeError(f"{self} already has {k}")
            if torch.is_tensor(v):
                v = v.to(self.device)
            if isinstance(v, nn.Module):
                v.to(self.device)
            setattr(self, k, v)

        self._callbacks.before_all(self._all_map)
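# Accepting a Partial for `optimizer` suggests that _set_optimizer (not shown
# here) instantiates it lazily against the model's parameters. A hedged
# sketch of that factory pattern, assuming `Partial` is functools.partial:
from functools import partial as Partial

import torch
from torch import nn
from torch.optim import SGD

model = nn.Linear(4, 2)
optimizer_factory = Partial(SGD, lr=0.1, momentum=0.9)   # parameters unbound
optimizer = optimizer_factory(model.parameters())        # bound at setup time

optimizer.zero_grad()
model(torch.randn(8, 4)).sum().backward()
optimizer.step()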
Example #12
    def __init__(self,
                 model: Union[nn.Module, Dict[str, nn.Module]],
                 optimizer: Optional[Union[Partial, Optimizer,
                                           Dict[str, Optimizer]]],
                 loss_f: Optional[Union[Callable, Dict[str, Callable]]] = None,
                 *,
                 reporters: Optional[Union[_ReporterBase,
                                           List[_ReporterBase]]] = None,
                 scheduler: Optional[Union[Partial, Scheduler,
                                           Dict[str, Scheduler]]] = None,
                 update_scheduler_by_epoch: bool = True,
                 device: Optional[Union[torch.device, str]] = None,
                 verb: bool = True,
                 use_cudnn_benchmark: bool = True,
                 use_cuda_nonblocking: bool = False,
                 use_horovod: bool = False,
                 logger=None,
                 use_sync_bn: bool = False,
                 tqdm_ncols: int = 80,
                 **kwargs):

        if kwargs.get("callbacks"):
            raise DeprecationWarning(
                "callback is deprecated, if you need, use homura before v2020.8"
            )

        self.logger = logger or get_logger(__name__)

        self.device = device or (torch.device(
            GPU) if torch.cuda.is_available() else torch.device(CPU))

        # setup for distributed
        self._use_sync_bn = use_sync_bn
        if is_distributed():
            if self._use_sync_bn:
                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
                self.logger.info(
                    "BN layers of the model are converted to nn.SyncBatchNorm")

            rank = get_local_rank()
            torch.cuda.set_device(rank)
            if get_global_rank() > 0:
                # to avoid overwriting
                verb = False

        # setup model
        if isinstance(model, nn.Module):
            self.model = model
        elif isinstance(model, dict):
            self.model = nn.ModuleDict(model)
        else:
            raise TypeError(
                f"Unknown type for `model`. Expected nn.Module or Dict[str, Module], but got {type(model)}"
            )

        if GPU in str(self.device):
            self.model.to(self.device)
            torch.backends.cudnn.benchmark = use_cudnn_benchmark
            self._cuda_nonblocking = use_cuda_nonblocking
            self.logger.debug(
                f"cuda: True, cudnn.benchmark: {use_cudnn_benchmark}, "
                f"cuda.nonblocking: {use_cuda_nonblocking}")
        else:
            self._cuda_nonblocking = False
            # usually, this is not expected
            self.logger.info(
                f"cuda: False (torch.cuda.is_available()={torch.cuda.is_available()})"
            )

        if not use_horovod and is_distributed():
            self.model = nn.parallel.DistributedDataParallel(self.model,
                                                             device_ids=[rank])

        # self.accessible_model is useful for e.g., checkpointing
        if isinstance(self.model, (nn.parallel.DistributedDataParallel,
                                   nn.DataParallel)):
            self.accessible_model = self.model.module
        else:
            self.accessible_model = self.model

        self.loss_f = loss_f
        self._verb = verb

        # setup optimizer and scheduler
        self.optimizer = optimizer
        self.scheduler = scheduler
        self._update_scheduler_by_epoch = update_scheduler_by_epoch
        self.set_optimizer()
        self.set_scheduler()

        if use_horovod:
            if not is_horovod_available():
                raise RuntimeError("horovod is not available!")
            import horovod.torch as hvd

            hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(self.optimizer, root_rank=0)
            self.optimizer = hvd.DistributedOptimizer(
                self.optimizer, named_parameters=self.model.named_parameters())

        if reporters is not None and not isinstance(reporters, Iterable):
            reporters = [reporters]
        reporters = reporters or []

        if not any(isinstance(rep, TQDMReporter) for rep in reporters):
            # ensure a TQDMReporter is always present
            reporters.append(TQDMReporter(ncols=tqdm_ncols))
        self.reporter = ReporterList(reporters)

        # called via property
        # _step and _epoch are set to -1 because they are incremented before each iteration and epoch
        self._step = -1
        self._epoch = -1
        self._is_train = True

        # to nest, leave=False (https://github.com/tqdm/tqdm/blob/master/examples/simple_examples.py#L19)
        self._tqdm = lambda x: x
        if verb:
            self._tqdm = Partial(tqdm, ncols=tqdm_ncols, leave=False)
            _set_tqdm_print()

        for k, v in kwargs.items():
            if hasattr(self, k):
                raise AttributeError(f"{self} already has {k}")
            if torch.is_tensor(v):
                v = v.to(self.device)
            if isinstance(v, nn.Module):
                v.to(self.device)
            setattr(self, k, v)
            self.logger.debug(f"trainer sets {k} as a new attribute")
Example #13
    def search(self, criteria):
        """
        Return a list of issue-numbers that match a search criteria.

        :param criteria: Dictionary of search terms
            state - str - 'open', 'closed'
            assignee - list of str (login), "none" or "*"
            mentioned - str (login)
            labels - list of str (label name)
            sort - str - 'created', 'updated', 'comments'
            direction - str - 'asc', 'desc'
            since - datetime.datetime
        """
        valid_criteria = {}
        # use search dictionary to form hash for cached results
        search_cache_key = 'issue_search'
        # Validate & transform criteria
        if 'state' in criteria:
            state = str(criteria['state'])
            if state not in ('open', 'closed'):
                raise ValueError("'state' criteria must be 'open' or 'closed'")
            valid_criteria['state'] = state
            search_cache_key = '%s_%s' % (search_cache_key, state)

        if 'assignee' in criteria:
            assignee = str(criteria['assignee'])
            search_cache_key = '%s_%s' % (search_cache_key, assignee)
            if assignee in ('none', '*'):
                valid_criteria['assignee'] = assignee
            else:
                # returns github.NamedUser.NamedUser
                valid_criteria['assignee'] = self.get_gh_user(assignee)

        if 'mentioned' in criteria:
            mentioned = str(criteria['mentioned'])
            search_cache_key = '%s_%s' % (search_cache_key, mentioned)
            if mentioned in ('none', '*'):
                valid_criteria['mentioned'] = mentioned
            else:
                # returns github.NamedUser.NamedUser
                valid_criteria['mentioned'] = self.get_gh_user(mentioned)

        if 'labels' in criteria:
            labels = criteria['labels']
            if not isinstance(labels, list):
                raise ValueError("'labels' criteria must be a list")
            valid_criteria['labels'] = []
            for name in labels:
                search_cache_key = '%s_%s' % (search_cache_key, name)
                valid_criteria['labels'].append(self.get_gh_label(str(name)))

        if 'sort' in criteria:
            sort = str(criteria['sort'])
            if sort not in ('created', 'updated', 'comments'):
                raise ValueError("'sort' criteria must be 'created', 'updated'"
                                 ", 'comments'")
            valid_criteria['sort'] = sort
            search_cache_key = '%s_%s' % (search_cache_key, sort)

        if 'direction' in criteria:
            direction = str(criteria['direction'])
            if direction not in ('asc', 'desc'):
                raise ValueError("'direction' criteria must be 'asc', 'desc'")
            valid_criteria['direction'] = direction
            search_cache_key = '%s_%s' % (search_cache_key, direction)

        if 'since' in criteria:
            since = criteria['since']
            if not isinstance(since, datetime.datetime):
                raise ValueError("'since' criteria must be a "
                                 "datetime.datetime")
            # second and microsecond are not useful to search or cache
            since = datetime.datetime(year=since.year,
                                      month=since.month,
                                      day=since.day,
                                      hour=since.hour,
                                      minute=since.minute,
                                      second=0,
                                      microsecond=0)
            search_cache_key = '%s_%s' % (search_cache_key, since.isoformat())
            valid_criteria['since'] = since

        # Perform the search operation only if there are no cached
        # results or the cached results have expired
        fetch_partial = Partial(self.make_search_results, valid_criteria)
        # This could take an arbitrarily LONG time
        return self.get_gh_obj(search_cache_key, fetch_partial)
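# Each validated criterion is appended to search_cache_key, so identical
# searches map to the same cache slot. A minimal, self-contained sketch of
# that keying scheme (ignoring the user/label lookups and 'since' handling):
def search_cache_key_for(criteria):
    key = 'issue_search'
    for term in ('state', 'assignee', 'mentioned', 'sort', 'direction'):
        if term in criteria:
            key = '%s_%s' % (key, criteria[term])
    return key

assert (search_cache_key_for({'state': 'open', 'sort': 'updated'}) ==
        'issue_search_open_updated')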
Example #14
    def __init__(self,
                 model: Union[nn.Module, Dict[str, nn.Module]],
                 optimizer: Optional[Union[Partial, Optimizer,
                                           Dict[str, Optimizer]]],
                 loss_f: Optional[Union[Callable, Dict[str, Callable]]] = None,
                 *,
                 reporters: Optional[Union[_ReporterBase,
                                           List[_ReporterBase]]] = None,
                 scheduler: Optional[Union[Partial, Scheduler,
                                           Dict[str, Scheduler]]] = None,
                 device: Optional[Union[torch.device, str]] = None,
                 quiet: bool = True,
                 disable_cudnn_benchmark: bool = False,
                 disable_cuda_nonblocking: bool = False,
                 logger=None,
                 use_sync_bn: bool = False,
                 tqdm_ncols: int = 120,
                 debug: bool = False,
                 **kwargs):

        if kwargs.get("update_scheduler_by_epoch"):
            raise DeprecationWarning(
                "update_scheduler_by_epoch is deprecated, users need to step")

        if kwargs.get("callbacks"):
            raise DeprecationWarning(
                "callback is deprecated, if you need, use homura before v2020.8"
            )

        self.logger = logger or get_logger(__name__)
        self.device = device or (torch.device(
            GPU) if torch.cuda.is_available() else torch.device(CPU))
        self._is_debug = debug

        if self._is_debug:
            self.logger.warning(
                "Trainer is set to debug mode, which may affect performance")
            set_verb_level("debug")

        # setup for distributed
        self._use_sync_bn = use_sync_bn
        if is_distributed():
            if self._use_sync_bn:
                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
                self.logger.info(
                    "BN layers of the model are converted to nn.SyncBatchNorm")

            rank = get_local_rank()
            torch.cuda.set_device(rank)
            if get_global_rank() > 0:
                # to avoid overwriting
                quiet = True

        self.loss_f = loss_f
        self._verbose = not quiet

        # setup model
        if isinstance(model, nn.Module):
            self.model = model
        elif isinstance(model, dict):
            self.model = nn.ModuleDict(model)
            self.logger.debug(f"model is nn.ModuleDict of {self.model.keys()}")
        else:
            raise TypeError(
                f"Unknown type for `model`. Expected nn.Module or Dict[str, Module], but got {type(model)}"
            )

        if GPU in str(self.device):
            self.model.to(self.device)
            torch.backends.cudnn.benchmark = not disable_cudnn_benchmark
            self._cuda_nonblocking = not disable_cuda_nonblocking
            self.logger.debug(
                f"cuda: True, cudnn.benchmark: {not disable_cudnn_benchmark}, "
                f"cuda.nonblocking: {not disable_cuda_nonblocking}")
        else:
            self._cuda_nonblocking = False
            # usually, this is not expected
            self.logger.info(
                f"cuda: False (torch.cuda.is_available()={torch.cuda.is_available()})"
            )

        if is_distributed():
            self.model = nn.parallel.DistributedDataParallel(self.model,
                                                             device_ids=[rank])
            self.logger.debug(
                f"model converted to DistributedDataParallel at rank={rank}")

        # self.accessible_model is useful for e.g., checkpointing
        if isinstance(self.model, (nn.parallel.DistributedDataParallel,
                                   nn.DataParallel)):
            self.accessible_model = self.model.module
        else:
            self.accessible_model = self.model

        # setup optimizer and scheduler
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.set_optimizer()
        self.set_scheduler()

        if reporters is not None and not isinstance(reporters, Iterable):
            reporters = [reporters]
        reporters = reporters or []

        if not any(isinstance(rep, TQDMReporter) for rep in reporters):
            # ensure a TQDMReporter is always present
            reporters.append(TQDMReporter(ncols=tqdm_ncols))
        self.logger.debug(f"reporter is ready: {reporters}")
        self.reporter = ReporterList(reporters)

        # called via property
        # _step and _epoch are set to -1 because they are incremented before each iteration and epoch
        self._step = -1
        self._epoch = -1
        self._is_train = True

        # to nest, leave=False (https://github.com/tqdm/tqdm/blob/master/examples/simple_examples.py#L19)
        self._tqdm = lambda x: x
        if self._verbose:
            self._tqdm = Partial(tqdm, ncols=tqdm_ncols, leave=False)
            set_tqdm_stdout_stderr()
            self.logger.debug("verbose: setup tqdm")
        else:
            self.logger.debug("quiet: no tqdm")

        for k, v in kwargs.items():
            if hasattr(self, k):
                raise AttributeError(f"{self} already has {k}")
            if isinstance(v, torch.Tensor):
                v = v.to(self.device)
            if isinstance(v, nn.Module):
                v.to(self.device)
            setattr(self, k, v)
            self.logger.debug(f"trainer sets {k} as a new attribute")
Example #15
    def Create(self):
        '''
        Creates the window
        '''
        if Cmds.window(self.windowName, exists=True):
            Cmds.deleteUI(self.windowName)

        # Main window
        Cmds.window(self.windowName,
                    title=self.windowTitle,
                    widthHeight=[self.width, self.height])
        Cmds.scrollLayout(hst=16, vst=16)
        Cmds.columnLayout(self.windowLayout)

        self.CreateSeparator()

        # Buttons
        Cmds.rowLayout(
            self.buttonsLayout,
            numberOfColumns=6,
            columnWidth6=[30, 48, 55, 75, 75, 75],
            columnAlign6=['left', 'left', 'left', 'left', 'left', 'left'])
        Cmds.button(label='Add',
                    backgroundColor=[0.6, 0.9, 0.6],
                    c=Partial(self.Add))
        Cmds.button(label='Delete',
                    backgroundColor=[0.9, 0.6, 0.6],
                    c=Partial(self.Delete))
        Cmds.button(label='Move Up', c=Partial(self.MoveUp))
        Cmds.button(label='Move Down', c=Partial(self.MoveDown))
        Cmds.button(label='Refresh',
                    backgroundColor=[0.6, 0.6, 0.9],
                    c=Partial(self.Refresh))

        self.CreateSeparator()

        # Tool controls
        Cmds.frameLayout(label="Tool Controls",
                         collapsable=True,
                         collapse=False)
        Cmds.columnLayout(width=self.width - 5)
        Cmds.rowLayout(numberOfColumns=2,
                       columnWidth2=[45, 290],
                       columnAlign2=['left', 'left'])
        Cmds.text(label=' Prefix')
        Cmds.textField(self.prefixTextBox, width=288)

        Cmds.setParent('..')
        Cmds.rowLayout(numberOfColumns=2,
                       columnWidth2=[200, 48],
                       columnAlign2=['left', 'left'])
        Cmds.text(label=' To import from MoveLister')
        Cmds.button(label='Import ML',
                    c=Partial(self.ImportMoveLister),
                    backgroundColor=[0.9, 0.9, 0.8])

        Cmds.setParent('..')
        Cmds.rowLayout(numberOfColumns=2,
                       columnWidth2=[200, 48],
                       columnAlign2=['left', 'left'])
        Cmds.text(label=' To generate multiple playblasts')
        Cmds.button(label='PlayBlast',
                    c=Partial(self.GeneratePlayblast),
                    backgroundColor=[0.9, 0.9, 0.8])

        Cmds.setParent('..')
        Cmds.rowLayout(numberOfColumns=3,
                       columnWidth3=[200, 90, 48],
                       columnAlign3=['left', 'left', 'left'])
        Cmds.text(label=' To export animation list as CSV')
        Cmds.button(label='Export to CSV',
                    c=Partial(self.Export),
                    backgroundColor=[0.9, 0.9, 0.8])
        self.IncludePlayblastLinkCheckBox = Cmds.checkBox(
            label='playblast link')

        Cmds.setParent('..')
        Cmds.rowLayout(numberOfColumns=2,
                       columnWidth2=[200, 48],
                       columnAlign2=['left', 'left'])
        Cmds.text(label=' To generate animation-aware FBX')
        Cmds.button(label='Generate FBX',
                    c=Partial(self.GenerateFbx),
                    backgroundColor=[0.9, 0.9, 0.8])

        self.CreateSeparator()

        # Skyrigger controls
        Cmds.frameLayout(label="Skyrigger and Animation Controls",
                         collapsable=True,
                         collapse=False)
        Cmds.columnLayout(width=self.width - 5)
        Cmds.rowLayout(numberOfColumns=4,
                       columnWidth4=[60, 55, 55, 55],
                       columnAlign4=['left', 'left', 'left', 'left'])
        Cmds.text(label=' Start frame')
        Cmds.textField(self.startFrameTextBox)
        Cmds.text(label='End frame')
        Cmds.textField(self.endFrameTextBox)

        Cmds.setParent('..')
        Cmds.rowLayout(numberOfColumns=2,
                       columnWidth2=[200, 48],
                       columnAlign2=['left', 'left'])
        Cmds.text(label=' To bake the keys')
        Cmds.button(label='Bake Keys',
                    c=Partial(self.BakeKeys),
                    backgroundColor=[0.8, 0.8, 0.9])

        Cmds.setParent('..')
        Cmds.rowLayout(numberOfColumns=2,
                       columnWidth2=[200, 48],
                       columnAlign2=['left', 'left'])
        Cmds.text(label=' To delete the rig controls')
        Cmds.button(label='Delete Rig Controls',
                    c=Partial(self.DeleteRigControls),
                    backgroundColor=[0.9, 0.8, 0.8])

        Cmds.setParent('..')
        Cmds.rowLayout(numberOfColumns=2,
                       columnWidth2=[200, 48],
                       columnAlign2=['left', 'left'])
        Cmds.text(label=' To trim the keys between moves')
        Cmds.button(label='Trim Keys',
                    c=Partial(self.TrimKeys),
                    backgroundColor=[0.8, 0.9, 0.8])

        self.CreateSeparator('..')

        self.Refresh()