예제 #1
0
    def __init__(self, config, logger=None, test=False):
        """Build model handlers, optimizers, metrics and data generators.

        Args:
            config: Experiment configuration. Keys read here are
                ``'module'``, ``'metric'`` and ``'generator'``, each mapping
                a name to a sub-config. A ``'module'`` entry needs
                ``'config'`` and ``'optim'`` and may carry ``'ckpt'``.
                A ``'generator'`` entry needs ``'data'`` (path to a json5
                file with ``'list'`` and ``'loader'`` keys) and ``'struct'``.
            logger: Optional logger, stored as ``self.logger``.
            test: When true, each generator's data list is truncated to
                its first entry.
        """
        self.logger = logger
        self.step = dict()
        self.config = config

        # setup models and optimizers
        self.handlers = dict()
        self.optims = dict()
        for (key, cfg) in self.config['module'].items():
            self.handlers[key] = ModelHandler(
                cfg['config'],
                checkpoint=cfg.get('ckpt'),  # None when no checkpoint given
            )
            self.optims[key] = Optimizer(cfg['optim'])(
                self.handlers[key].model)
            self.optims[key].zero_grad()

        # one MetricFlow per metric entry; the loop variable is named
        # metric_cfg so it does not shadow the ``config`` argument
        self.metrics = {
            key: MetricFlow(metric_cfg)
            for (key, metric_cfg) in self.config['metric'].items()
        }

        # setup data generators
        self.generators = dict()
        for (key, cfg) in self.config['generator'].items():
            with open(cfg['data']) as f:
                data_config = json5.load(f)
            data_list = data_config['list']
            if test:
                data_list = data_list[:1]
            # 'name' selects the loader; remaining entries are its kwargs
            loader_config = data_config['loader']
            loader_name = loader_config.pop('name')
            data_loader = DataLoader(loader_name, **loader_config)
            data_loader.set_data_list(data_list)
            self.generators[key] = DataGenerator(data_loader, cfg['struct'])
예제 #2
0
timer = time.time()
start = timer

# load config
with open(args.config) as f:
    config = json5.load(f)

# build up the data generator
with open(config['generator']['data']) as f:
    data_config = json5.load(f)
data_list = data_config['list']
if args.test:
    data_list = data_list[:1]  # keep only the first sample in test mode
# 'name' selects the loader; the remaining entries are its keyword args
loader_config = data_config['loader']
loader_name = loader_config.pop('name')
data_loader = DataLoader(loader_name, **loader_config)
data_loader.set_data_list(data_list)
data_gen = DataGenerator(data_loader, config['generator']['struct'])

# build up the reverter
reverter = Reverter(data_gen)
DL = data_gen.struct['DL']
PG = data_gen.struct['PG']
BG = data_gen.struct['BG']

# Ensure the order: PG may run several workers only when it is ordered,
# and BG (plus AG, when present) must be single-worker. These checks
# validate the user-supplied generator config, so they raise explicitly
# instead of using `assert`, which `python -O` strips.
# NOTE(review): presumably the reverter relies on this ordering -- confirm.
if PG.n_workers > 1 and not PG.ordered:
    raise ValueError(
        'PG uses {} workers but is not ordered'.format(PG.n_workers))
if BG.n_workers != 1:
    raise ValueError(
        'BG must use exactly 1 worker, got {}'.format(BG.n_workers))
if 'AG' in data_gen.struct and data_gen.struct['AG'].n_workers != 1:
    raise ValueError(
        'AG must use exactly 1 worker, got {}'.format(
            data_gen.struct['AG'].n_workers))
예제 #3
0
parser = argparse.ArgumentParser()
parser.add_argument('--generator-config',
                    required=True,
                    help='generator config')
parser.add_argument('--loader-config', required=True, help='loader config')
parser.add_argument('--output-dir',
                    default='outputs',
                    help='directory to store output images')
args = parser.parse_args()

timer = time.time()
# 'name' selects the loader; the remaining entries are its keyword args
with open(args.loader_config) as f:
    loader_config = yaml.safe_load(f)
loader_name = loader_config.pop('name')
data_loader = DataLoader(loader_name, **loader_config)

with open(args.generator_config) as f:
    generator_config = yaml.safe_load(f)
data_gen = DataGenerator(data_loader, generator_config)

DL = data_gen.struct['DL']
PG = data_gen.struct['PG']
BG = data_gen.struct['BG']

# Ensure the order: PG may run several workers only when it is ordered,
# and BG (plus AG, when present) must be single-worker. These checks
# validate the user-supplied generator config, so they raise explicitly
# instead of using `assert`, which `python -O` strips.
if PG.n_workers > 1 and not PG.ordered:
    raise ValueError(
        'PG uses {} workers but is not ordered'.format(PG.n_workers))
if BG.n_workers != 1:
    raise ValueError(
        'BG must use exactly 1 worker, got {}'.format(BG.n_workers))
if 'AG' in data_gen.struct and data_gen.struct['AG'].n_workers != 1:
    raise ValueError(
        'AG must use exactly 1 worker, got {}'.format(
            data_gen.struct['AG'].n_workers))
예제 #4
0
import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument(
    '--loader-config',
    required=True,
    help='loader config'
)
args = parser.parse_args()

with open(args.loader_config) as f:
    loader_config = yaml.safe_load(f)

# 'name' selects the loader; the remaining entries are its keyword args
loader_name = loader_config.pop('name')
data_loader = DataLoader(loader_name, **loader_config)

# Use the first sample. Raise instead of `assert` so the check survives
# `python -O`.
if len(data_loader.data_list) == 0:
    raise ValueError('the loader config yields an empty data list')
idx = data_loader.data_list[0]

# [image, label] of the first sample
images = [
    data_loader.get_image(idx),
    data_loader.get_label(idx)
]

# Third-axis coordinate of the first non-zero label voxel (np.where returns
# coordinates in row-major order). Indexing [2] assumes the label has at
# least 3 dimensions. An all-zero label gets a clear error instead of a
# bare IndexError.
roi_axis2 = np.where(images[1] != 0)[2]
if roi_axis2.size == 0:
    raise ValueError('label of sample {} has no non-zero voxels'.format(idx))
idx_containing_ROI = roi_axis2[0]
fig, axes = plt.subplots(
    1,
    len(images),
    constrained_layout=True,
예제 #5
0
    def __init__(self, config, logger=None, test=False):
        """Build tasks, models, optimizers, schedulers, metrics and data.

        Args:
            config: Experiment configuration. Keys read here are ``'task'``,
                ``'module'``, ``'metric'`` and ``'generator'``.
                Each ``'module'`` entry needs ``'config'`` and ``'optim'``.
                It may also carry ``'ckpt'``, ``'weight_clip'`` (a
                ``(low, high)`` tuple/list with ``low < high``) and
                ``'lr_scheduler'``. The latter is a dict whose ``'name'``
                is a class in ``torch.optim.lr_scheduler``; its other
                entries are that class's keyword arguments.
                Each ``'generator'`` entry needs ``'data'`` (path to a
                json5 file with ``'list'`` and ``'loader'`` keys) and
                ``'struct'``.
            logger: Optional logger, stored as ``self.logger``.
            test: When true, each generator's data list is truncated to
                its first entry.

        Note:
            Every entry of ``config['task']`` is updated in place with a
            ``'need_backward'`` flag. The flag is true when any of that
            task's ``'toggle'`` values is truthy.
        """
        self.logger = logger
        self.step = dict()
        self.config = config

        # task variables
        self.running_task = None
        self.tasks = config['task']
        for task_config in self.tasks.values():
            task_config['need_backward'] = any(
                task_config['toggle'].values())

        # setup models and optimizers
        self.handlers = dict()
        self.optims = dict()
        self.lr_schedulers = dict()
        self.weight_clip = dict()
        for (key, cfg) in self.config['module'].items():
            self.handlers[key] = ModelHandler(
                cfg['config'],
                checkpoint=cfg.get('ckpt'),  # None when no checkpoint given
            )
            if 'weight_clip' in cfg:
                clip = cfg['weight_clip']
                assert isinstance(clip, (tuple, list))
                assert len(clip) == 2
                assert clip[0] < clip[1]
                print('Weight clip {} on the model {}'.format(clip, key))
                self.weight_clip[key] = clip

            self.optims[key] = Optimizer(cfg['optim'])(
                self.handlers[key].model)
            self.optims[key].zero_grad()
            if 'lr_scheduler' in cfg:
                # Pop 'name' from a copy. Popping from cfg['lr_scheduler']
                # itself mutates the caller's config, and a second instance
                # built from the same config then fails with KeyError.
                sched_kwargs = dict(cfg['lr_scheduler'])
                sched_name = sched_kwargs.pop('name')
                self.lr_schedulers[key] = getattr(
                    torch.optim.lr_scheduler,
                    sched_name,
                )(self.optims[key], **sched_kwargs)

        # one MetricFlow per metric entry; the loop variable is named
        # metric_cfg so it does not shadow the ``config`` argument
        self.metrics = {
            key: MetricFlow(metric_cfg)
            for (key, metric_cfg) in self.config['metric'].items()
        }

        # setup data generators
        self.generators = dict()
        for (key, cfg) in self.config['generator'].items():
            with open(cfg['data']) as f:
                data_config = json5.load(f)
            data_list = data_config['list']
            if test:
                data_list = data_list[:1]
            # 'name' selects the loader; remaining entries are its kwargs
            loader_config = data_config['loader']
            loader_name = loader_config.pop('name')
            data_loader = DataLoader(loader_name, **loader_config)
            data_loader.set_data_list(data_list)
            self.generators[key] = DataGenerator(data_loader, cfg['struct'])
예제 #6
0
# SSL
# NOTE: the 'train_ssl' stage is optional (an earlier version required it).
# When requested without its own data list it falls back to the 'valid'
# list below.
if 'train_ssl' in stages:
    print('SSL is included in the training.')

with open(config['data']) as f:
    data_config = yaml.safe_load(f)
data_list = data_config['list']
loader_config = data_config['loader']

# - data pipeline
# One DataLoader/DataGenerator per stage; 'name' selects the loader class and
# the remaining loader_config entries are its keyword arguments.
data_gen = dict()
loader_name = loader_config.pop('name')
ROIs = None
for stage in stages:
    data_loader = DataLoader(loader_name, **loader_config)
    if stage == 'train_ssl' and stage not in data_list:
        data_loader.set_data_list(data_list['valid'])
    else:
        # explicit raise: `assert` is stripped under `python -O`
        if stage not in data_list:
            raise ValueError(
                'stage {!r} has no entry in the data list {}'.format(
                    stage, config['data']))
        data_loader.set_data_list(data_list[stage])
    data_gen[stage] = DataGenerator(data_loader, generator_config[stage])

    # ROIs are taken from the first stage's loader
    if ROIs is None:
        ROIs = data_loader.ROIs

# FIXME: assumes 'valid' is one of the stages (KeyError otherwise)
reverter = Reverter(data_gen['valid'])

# - GPUs
os.environ['CUDA_VISIBLE_DEVICES'] = str(config['gpus'])
예제 #7
0
import yaml
import json

parser = argparse.ArgumentParser()
parser.add_argument('--config', required=True, help='config')
parser.add_argument('--output', default='box.json', help='output bbox')
args = parser.parse_args()

# load config
with open(args.config) as f:
    config = yaml.safe_load(f)
data_list = config['list']
loader_config = config['loader']

# 'name' selects the loader; the remaining entries are its keyword args
loader_name = loader_config.pop('name')
data_loader = DataLoader(loader_name, **loader_config)
if data_list is not None:
    data_loader.set_data_list(data_list)

# Running bounding box (per-axis min/max coordinate) of every label voxel
# > 0 across the dataset. Seeded with np.inf: the np.Inf alias was removed
# in NumPy 2.0.
box = {
    'corner1': np.ones(3) * np.inf,
    'corner2': np.zeros(3),
}
for data_idx in tqdm(data_loader.data_list):
    indices = np.where(data_loader.get_label(data_idx) > 0)
    if indices[0].size == 0:
        # No positive voxels: this label cannot widen the box, and
        # min()/max() of an empty array would raise.
        continue
    corner1 = np.array([idx.min() for idx in indices])
    corner2 = np.array([idx.max() for idx in indices])
    box['corner1'] = np.minimum(box['corner1'], corner1)
    box['corner2'] = np.maximum(box['corner2'], corner2)
box['center'] = (box['corner2'] + box['corner1']) / 2
box['size'] = box['corner2'] - box['corner1']
예제 #8
0
#!/usr/bin/env python3

from MIDP import DataLoader
import argparse
import yaml
from tqdm import tqdm

# Command line: where to read the loader definition, where to write output.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--loader-config',
    help='loader config',
    required=True,
)
parser.add_argument(
    '--output-dir',
    help='output dir',
    default='output',
)
args = parser.parse_args()

# The loader definition's 'name' entry picks the DataLoader; every other
# entry is forwarded to it as a keyword argument.
with open(args.loader_config) as f:
    loader_config = yaml.safe_load(f)
loader_name = loader_config.pop('name')
data_loader = DataLoader(loader_name, **loader_config)

# Pass each sample's label to save_prediction, targeting --output-dir.
for data_idx in tqdm(data_loader.data_list):
    label = data_loader.get_label(data_idx)
    data_loader.save_prediction(
        data_idx,
        label,
        args.output_dir,
    )