Beispiel #1
0
    def __init__(self, config, logger=None, test=False):
        """Create the model handlers, optimizers, metrics and data generators.

        Args:
            config: nested mapping providing the sections 'module', 'metric'
                and 'generator'; it is kept on ``self.config``.
            logger: stored on ``self.logger``.
            test: when true, each data list is cut down to its first entry.
        """
        self.logger = logger
        self.step = {}
        self.config = config

        # Models and optimizers: one handler/optimizer pair per module entry.
        self.handlers = {}
        self.optims = {}
        for name, module_cfg in self.config['module'].items():
            # A missing 'ckpt' entry means no checkpoint is passed (None).
            self.handlers[name] = ModelHandler(
                module_cfg['config'],
                checkpoint=module_cfg.get('ckpt'),
            )
            build_optimizer = Optimizer(module_cfg['optim'])
            self.optims[name] = build_optimizer(self.handlers[name].model)
            self.optims[name].zero_grad()

        # Metrics: one MetricFlow per entry of the 'metric' section.
        metrics = {}
        for name, metric_cfg in self.config['metric'].items():
            metrics[name] = MetricFlow(metric_cfg)
        self.metrics = metrics

        # Data generators: each entry names a json5 file holding the data
        # list ('list') and the loader settings ('loader').
        self.generators = {}
        for name, generator_cfg in self.config['generator'].items():
            with open(generator_cfg['data']) as f:
                data_config = json5.load(f)

            entries = data_config['list']
            if test:
                entries = entries[:1]

            # 'name' is taken out of the loader settings and passed
            # positionally; everything left over becomes keyword arguments.
            loader_kwargs = data_config['loader']
            loader = DataLoader(loader_kwargs.pop('name'), **loader_kwargs)
            loader.set_data_list(entries)
            self.generators[name] = DataGenerator(loader, generator_cfg['struct'])
Beispiel #2
0
# load config
# The top-level config file is parsed with json5.
with open(args.config) as f:
    config = json5.load(f)

# build up the data generator
# The generator section points at a second file that provides the data list
# ('list') and the loader settings ('loader').
with open(config['generator']['data']) as f:
    data_config = json5.load(f)
data_list = data_config['list']
# In test mode only the first entry of the data list is kept.
if args.test:
    data_list = data_list[:1]
loader_config = data_config['loader']
# 'name' is removed from the loader settings and passed positionally; the
# remaining settings are forwarded as keyword arguments.
loader_name = loader_config.pop('name')
data_loader = DataLoader(loader_name, **loader_config)
data_loader.set_data_list(data_list)
data_gen = DataGenerator(data_loader, config['generator']['struct'])

# build up the reverter
reverter = Reverter(data_gen)
# Entries of the generator structure, looked up by key.
# NOTE(review): presumably DL/PG/BG/AG are stages of the data pipeline
# (loader, patch, batch, augmentation?) - confirm against DataGenerator.
DL = data_gen.struct['DL']
PG = data_gen.struct['PG']
BG = data_gen.struct['BG']
# ensure the order
# PG may use several workers only if it is flagged as ordered; BG and the
# optional AG entry must each run with exactly one worker.
if PG.n_workers > 1:
    assert PG.ordered
assert BG.n_workers == 1
if 'AG' in data_gen.struct:
    assert data_gen.struct['AG'].n_workers == 1

# - GPUs
if 'gpus' in config:
Beispiel #3
0
    config = yaml.safe_load(f)
# The generator settings are kept per stage; the data config lives in a
# separate YAML file referenced by config['data'].
generator_config = config['generator']
with open(config['data']) as f:
    data_config = yaml.safe_load(f)
data_list = data_config['list']
loader_config = data_config['loader']

# - data pipeline
# One DataLoader/DataGenerator pair is built for each of the two stages.
data_gen = dict()
# 'name' is removed from the loader settings once, before the loop, so the
# same remaining keyword arguments are reused for both stages.
loader_name = loader_config.pop('name')
ROIs = None
for stage in ['train', 'valid']:
    data_loader = DataLoader(loader_name, **loader_config)
    # A stage whose data list is None keeps whatever list the loader
    # starts with.
    if data_list[stage] is not None:
        data_loader.set_data_list(data_list[stage])
    data_gen[stage] = DataGenerator(data_loader, generator_config[stage])

    # Only assigned while still None, so the first non-None value sticks.
    if ROIs is None:
        ROIs = data_loader.ROIs

# - GPUs
# The environment variable is set from config['gpus'], converted to a string.
os.environ['CUDA_VISIBLE_DEVICES'] = str(config['gpus'])

# - model
# The class is looked up by name in the `models` module; popping 'name'
# leaves only the constructor keyword arguments in config['model'].
model_name = config['model'].pop('name')
model = getattr(models, model_name)(**config['model'])

# - optimizer
# Optimizer(...) returns a callable that is applied to the model.
optimizer = Optimizer(config['optimizer'])(model)

# - scheduler
Beispiel #4
0
                    help='generator config')
parser.add_argument('--loader-config', required=True, help='loader config')
parser.add_argument('--output-dir',
                    default='outputs',
                    help='directory to store output images')
args = parser.parse_args()

# Timestamp taken before any configuration is loaded.
timer = time.time()
# The loader settings come from a YAML file; 'name' is removed and passed
# positionally, the remaining settings are forwarded as keyword arguments.
with open(args.loader_config) as f:
    loader_config = yaml.safe_load(f)
loader_name = loader_config.pop('name')
data_loader = DataLoader(loader_name, **loader_config)

# The generator structure is read from its own YAML file.
with open(args.generator_config) as f:
    generator_config = yaml.safe_load(f)
data_gen = DataGenerator(data_loader, generator_config)

# Entries of the generator structure, looked up by key.
# NOTE(review): presumably DL/PG/BG/AG are stages of the data pipeline
# (loader, patch, batch, augmentation?) - confirm against DataGenerator.
DL = data_gen.struct['DL']
PG = data_gen.struct['PG']
BG = data_gen.struct['BG']
# ensure the order
# PG may use several workers only if it is flagged as ordered; BG and the
# optional AG entry must each run with exactly one worker.
if PG.n_workers > 1:
    assert PG.ordered
assert BG.n_workers == 1
if 'AG' in data_gen.struct:
    assert data_gen.struct['AG'].n_workers == 1

# The output directory is created up front; an existing one is reused.
os.makedirs(args.output_dir, exist_ok=True)
# Start with an empty queue and an empty score table.
queue = []
scores = dict()
Beispiel #5
0
    def __init__(self, config, logger=None, test=False):
        """Set up tasks, models, optimizers, schedulers, metrics and data.

        Args:
            config: nested mapping providing the sections 'task', 'module',
                'metric' and 'generator'. It is kept on ``self.config``; each
                task entry is annotated in place with a 'need_backward' flag.
            logger: stored on ``self.logger``.
            test: when true, each data list is cut down to its first entry.

        Raises:
            AssertionError: if a module's 'weight_clip' is not a two-element
                list/tuple with the first value below the second.
        """

        self.logger = logger
        self.step = dict()
        self.config = config

        # task variables
        self.running_task = None
        self.tasks = config['task']
        for task_config in self.tasks.values():
            # Flagged as soon as at least one toggle value is truthy.
            task_config['need_backward'] = any(task_config['toggle'].values())

        # setup models and optimizers
        self.handlers = dict()
        self.optims = dict()
        self.lr_schedulers = dict()
        self.weight_clip = dict()
        for (key, cfg) in self.config['module'].items():
            # A missing 'ckpt' entry means no checkpoint is passed (None).
            self.handlers[key] = ModelHandler(
                cfg['config'],
                checkpoint=cfg.get('ckpt'),
            )
            if 'weight_clip' in cfg:
                clip = cfg['weight_clip']
                assert isinstance(clip, (tuple, list)), \
                    'weight_clip must be a list or tuple'
                assert len(clip) == 2, \
                    'weight_clip must hold exactly two values'
                assert clip[0] < clip[1], \
                    'weight_clip lower bound must be below the upper bound'
                print('Weight clip {} on the model {}'.format(
                    cfg['weight_clip'], key))
                self.weight_clip[key] = clip

            self.optims[key] = Optimizer(cfg['optim'])(
                self.handlers[key].model)
            self.optims[key].zero_grad()
            if 'lr_scheduler' in cfg:
                # Pop 'name' from a shallow copy: popping from the original
                # would strip the scheduler name from self.config, so the
                # same config could not be used to build the object again.
                scheduler_kwargs = dict(cfg['lr_scheduler'])
                scheduler_cls = getattr(
                    torch.optim.lr_scheduler,
                    scheduler_kwargs.pop('name'),
                )
                self.lr_schedulers[key] = scheduler_cls(
                    self.optims[key], **scheduler_kwargs)

        # One MetricFlow per entry of the 'metric' section.
        self.metrics = {
            key: MetricFlow(metric_cfg)
            for (key, metric_cfg) in self.config['metric'].items()
        }

        # setup data generators
        # Each entry names a json5 file holding the data list ('list') and
        # the loader settings ('loader').
        self.generators = dict()
        for (key, cfg) in self.config['generator'].items():
            with open(cfg['data']) as f:
                data_config = json5.load(f)
            data_list = data_config['list']
            if test:
                data_list = data_list[:1]
            # The loader settings are freshly parsed from the file, so
            # popping 'name' here does not touch the caller's config.
            loader_config = data_config['loader']
            loader_name = loader_config.pop('name')
            data_loader = DataLoader(loader_name, **loader_config)
            data_loader.set_data_list(data_list)
            self.generators[key] = DataGenerator(data_loader, cfg['struct'])