Beispiel #1
0
def make_vgg(cfg_id,
             batch_norm,
             pretrained,
             url_id,
             incl_fcs=True,
             num_classes=1000):
    """Build a VGG model and optionally initialise it from a pretrained checkpoint.

    Args:
        cfg_id: key into ``cfgs`` selecting the layer configuration.
        batch_norm: forwarded to ``VGG`` as ``batch_norm``.
        pretrained: if truthy, download the checkpoint ``model_urls[url_id]``
            and load it into the model.
        url_id: key into ``model_urls`` selecting the checkpoint to download.
        incl_fcs: forwarded to ``VGG`` as ``incl_fcs``.
        num_classes: forwarded to ``VGG``; when it differs from 1000 the
            checkpoint's ``classifier.weight`` / ``classifier.bias`` entries
            are removed before loading.

    Returns:
        The constructed ``VGG`` model.
    """
    model = VGG(cfgs[cfg_id],
                batch_norm=batch_norm,
                incl_fcs=incl_fcs,
                num_classes=num_classes)
    if pretrained:
        pretrained_model = model_zoo.load_url(model_urls[url_id])
        # Rename checkpoint entries: the entry stored under old key ``v`` is
        # moved to new key ``k`` for every pair in ``fc_mapping``.
        for k, v in fc_mapping.items():
            pretrained_model[k] = pretrained_model.pop(v)
        if num_classes != 1000:
            # NOTE(review): presumably the checkpoint head has 1000 outputs
            # (matching the default) — drop it so a differently-sized
            # classifier is not loaded. Confirm against the checkpoints.
            del pretrained_model['classifier.weight']
            del pretrained_model['classifier.bias']

        try:
            load_state_dict(model, pretrained_model)
        except KeyError:
            # Best-effort load: a KeyError (presumably from keys the model and
            # checkpoint do not share — TODO confirm) is deliberately ignored,
            # but it is recorded via ``logger.exception`` with a meaningful
            # message instead of a placeholder.
            from jacinle.logging import get_logger
            logger = get_logger(__file__)
            logger.exception(
                'Ignoring KeyError while loading pretrained weights '
                'for url_id={!r}.'.format(url_id))
    return model
Beispiel #2
0
def reset_global_seed(seed=None, verbose=False):
    """Reseed every RNG registered in ``global_rng_registry``.

    Args:
        seed: value passed to every registered seeding function. When
            ``None``, a single value is drawn from ``gen_seed()`` once, so all
            registered RNGs still receive the same seed.
        verbose: if true, log (at CRITICAL level) each registry name being
            reset, together with the current pid and the seed.
    """
    if seed is None:
        seed = gen_seed()

    # Create the logger once, outside the loop, instead of re-importing and
    # re-creating it on every iteration.
    logger = None
    if verbose:
        from jacinle.logging import get_logger
        logger = get_logger(__file__)

    for name, seed_getter in global_rng_registry.items():
        if logger is not None:
            logger.critical('Reset random seed for: {} (pid={}, seed={}).'.format(name, os.getpid(), seed))
        # Each registry value is called with no arguments to obtain the
        # seeding function, which is then called with ``seed``.
        seed_getter()(seed)
Beispiel #3
0
def git_guard(force=False):
    """Warn about uncommitted changes reported by ``git_uncommitted_files()``.

    If ``git_uncommitted_files()`` returns an empty collection, nothing
    happens. Otherwise the files are logged as a warning, optionally followed
    by an interactive confirmation, and finally ``git_status_full()`` is logged
    at INFO level.

    Args:
        force: if true and there are uncommitted changes, ask the user to
            confirm continuing; declining (the prompt's default is ``'no'``)
            terminates the process with exit status 1.
    """
    uncommitted_files = git_uncommitted_files()
    if len(uncommitted_files) == 0:
        return

    import sys
    from jacinle.logging import get_logger
    from jacinle.cli.keyboard import yes_or_no
    logger = get_logger(__file__)

    logger.warning('Uncommitted changes at the current repo:\n  ' +
                   '\n  '.join(uncommitted_files))
    if force and not yes_or_no('Are you sure you want to continue?',
                               default='no'):
        # Use ``sys.exit`` rather than the ``site``-injected ``exit`` builtin,
        # which is missing under ``python -S`` and in some frozen/embedded
        # interpreters.
        sys.exit(1)
    logger.info(git_status_full())
Beispiel #4
0
    def __init__(self,
                 counters=None,
                 display_names=None,
                 interval=1,
                 printf=None):
        """Create one ``Counter`` per name and choose the output function.

        Args:
            counters: iterable of counter names; ``None`` means
                ``['DEFAULT']``.
            display_names: stored unchanged on ``self._display_names``.
            interval: stored unchanged on ``self._interval``.
            printf: callable stored on ``self._printf``; when ``None``, the
                ``info`` method of a jacinle logger for this file is used.
        """
        names = ['DEFAULT'] if counters is None else counters

        self._display_names = display_names
        # OrderedDict keeps the counters in the order the names were given.
        self._counters = collections.OrderedDict(
            (name, Counter()) for name in names
        )
        self._interval = interval

        if printf is None:
            from jacinle.logging import get_logger
            printf = get_logger(__file__).info
        self._printf = printf
Beispiel #5
0
import torch

from nscl.datasets.definition import gdef
from nscl.datasets.common.filterable import (
    FilterableDatasetUnwrapped,
    FilterableDatasetView,
)
from nscl.datasets.common.vocab import Vocab
from nscl.datasets.common.program_translator import (
    nsclseq_to_nscltree,
    nsclseq_to_nsclqsseq,
    nscltree_to_nsclqstree,
    gen_vocab,
)

from jacinle.logging import get_logger

# Module-level logger; ``get_logger`` must be imported for this line to run
# at import time.
logger = get_logger(__file__)

# Names exported by ``from <this module> import *``.
__all__ = [
    "NSCLDataset", "ConceptRetrievalDataset", "ConceptQuantizationDataset"
]


class NSCLDatasetUnwrapped(FilterableDatasetUnwrapped):
    def __init__(
        self,
        scenes_json,
        questions_json,
        image_root,
        image_transform,
        vocab_json,
        question_transform=None,