Example #1
0
    def __init__(self,
                 session: Session,
                 task: Task,
                 layout: str,
                 part: str = 'valid',
                 name: str = 'img_segment',
                 max_img_size: Tuple[int, int] = None,
                 stack_type: str = 'vertical',
                 main_metric: str = 'dice',
                 plot_count: int = 0,
                 colors: List[Tuple] = None):
        """Configure the report builder and create its base record.

        Stores the caller's settings, builds one data provider per table on
        the shared ``session``, resolves the task's project id, replaces the
        ``layout`` name with the layout record found by that name, parses the
        layout content into ``self.layout_dict`` and finally calls
        ``self.create_base()``.

        A falsy ``name`` falls back to ``'img_segment'``.
        """
        # Caller-supplied settings (``self.layout`` is replaced further down).
        self.session, self.task, self.layout = session, task, layout
        self.part, self.name = part, (name if name else 'img_segment')
        self.max_img_size, self.stack_type = max_img_size, stack_type
        self.main_metric, self.colors = main_metric, colors
        self.plot_count = plot_count

        # Data-access providers, all bound to the same session.
        self.dag_provider = DagProvider(session)
        self.report_provider = ReportProvider(session)
        self.layout_provider = ReportLayoutProvider(session)
        self.task_provider = TaskProvider(session)
        self.report_img_provider = ReportImgProvider(session)
        self.report_task_provider = ReportTasksProvider(session)
        self.report_series_provider = ReportSeriesProvider(session)

        # Resolve the owning project and the layout record, then parse it.
        project_record = self.task_provider.project(task.id)
        self.project = project_record.id
        layout_record = self.layout_provider.by_name(layout)
        self.layout = layout_record
        self.layout_dict = yaml_load(layout_record.content)

        self.create_base()
Example #2
0
    def __init__(
            self,
            args: Args,
            report: ReportLayoutInfo,
            distr_info: dict,
            resume: dict,
            grid_config: dict,
            trace: str,
            params: dict,
            **kwargs
    ):
        """Store the run configuration and set initial per-run state.

        Extra keyword arguments are forwarded to the parent initializer.
        ``self.session`` (used by the providers below) is presumably set by
        that parent initializer — it is not assigned here.
        """
        super().__init__(**kwargs)

        # Settings supplied by the caller.
        self.args = args
        self.report = report
        self.distr_info = distr_info
        self.resume = resume
        self.grid_config = grid_config
        self.trace = trace
        self.params = params

        # Initial values for per-run state.
        self.order = 0
        self.experiment, self.runner, self.parent = None, None, None
        self.master = True
        self.checkpoint_resume, self.checkpoint_stage_epoch = False, 0
        self.last_batch_logged = self.loader_started_time = None
        self.loader_step_start = 0

        # Data-access providers bound to the session.
        self.series_provider = ReportSeriesProvider(self.session)
        self.computer_provider = ComputerProvider(self.session)
Example #3
0
    def __init__(
            self,
            args: Args,
            report: ReportLayoutInfo,
            distr_info: dict,
            resume: dict,
            grid_config: dict,
            trace: str,
            params: dict
    ):
        """Store the run configuration and set initial per-run state.

        The parent initializer is called with ``order=0``. ``self.session``
        (used by the provider below) is presumably set by that parent
        initializer — it is not assigned here.
        """
        super().__init__(order=0)

        # Settings supplied by the caller.
        self.args, self.report = args, report
        self.distr_info, self.resume = distr_info, resume
        self.grid_config = grid_config
        self.trace, self.params = trace, params

        # Initial values for per-run state.
        self.experiment = self.runner = None
        self.master = True
        self.checkpoint_resume = False
        self.checkpoint_stage_epoch = 0

        # Data-access provider bound to the session.
        self.series_provider = ReportSeriesProvider(self.session)
Example #4
0
    def __init__(self,
                 session: Session,
                 task: Task,
                 layout: str,
                 part: str = 'valid',
                 name: str = 'img_classify',
                 max_img_size: Tuple[int, int] = None,
                 main_metric: str = 'accuracy',
                 plot_count: int = 0):
        """Configure the classification report builder.

        Stores the caller's settings, builds one data provider per table on
        the shared ``session``, resolves the task's project id, replaces the
        ``layout`` name with the layout record found by that name, and parses
        the layout content into ``self.layout_dict``.

        A falsy ``name`` falls back to ``'img_classify'``.
        """
        # Caller-supplied settings (``self.layout`` is replaced further down).
        self.session, self.task, self.layout = session, task, layout
        self.part, self.name = part, (name if name else 'img_classify')
        self.max_img_size = max_img_size
        self.main_metric, self.plot_count = main_metric, plot_count

        # Data-access providers, all bound to the same session.
        self.dag_provider = DagProvider(session)
        self.report_provider = ReportProvider(session)
        self.layout_provider = ReportLayoutProvider(session)
        self.task_provider = TaskProvider(session)
        self.report_img_provider = ReportImgProvider(session)
        self.report_task_provider = ReportTasksProvider(session)
        self.report_series_provider = ReportSeriesProvider(session)

        # Resolve the owning project and the layout record, then parse it.
        project_record = self.task_provider.project(task.id)
        self.project = project_record.id
        layout_record = self.layout_provider.by_name(layout)
        self.layout = layout_record
        self.layout_dict = yaml_load(layout_record.content)
Example #5
0
def describe_metrics(dag: int, metrics: List[str], axis, last_n_epoch=None):
    """Plot metric series of a DAG, one (task, metric) series per axis.

    Args:
        dag: id of the DAG whose series are fetched via ``by_dag``.
        metrics: metric names to fetch; None is treated as an empty list.
        axis: sequence of matplotlib-style axes. Axes beyond the number of
            fetched series are switched off.
        last_n_epoch: when truthy, only the last ``last_n_epoch`` points of
            every group are plotted (the group dicts are trimmed in place).
    """
    provider = ReportSeriesProvider()
    series = provider.by_dag(dag, metrics or [])
    available = len(series)

    for index, ax in enumerate(axis):
        # More axes than series: hide the spare ones.
        if index >= available:
            ax.axis('off')
            continue

        ax.axis('on')
        task_name, metric, groups = series[index]

        for group in groups:
            if last_n_epoch:
                for field in ('epoch', 'value'):
                    group[field] = group[field][-last_n_epoch:]
            ax.plot(group['epoch'], group['value'], label=group['name'])

        ax.set_title(f'{task_name}, {metric} score')
        ax.set_ylabel(metric, labelpad=20)
        ax.set_xlabel('epoch')
        # Epochs are integers; keep tick labels integral.
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.legend()
Example #6
0
class SegmentationReportBuilder:
    """Build an image-segmentation report for one task.

    Construction resolves the task's project and the named report layout and
    immediately creates the report record (``create_base``). Afterwards,
    ``process_scores`` stores scalar scores as report series and
    ``process_pred`` renders images with their masks and stores them as
    report images (at most ``plot_count`` images over the builder's life).
    """

    def __init__(self,
                 session: Session,
                 task: Task,
                 layout: str,
                 part: str = 'valid',
                 name: str = 'img_segment',
                 max_img_size: Tuple[int, int] = None,
                 stack_type: str = 'vertical',
                 main_metric: str = 'dice',
                 plot_count: int = 0,
                 colors: List[Tuple] = None):
        """
        Args:
            session: database session shared by all providers.
            task: task the report belongs to.
            layout: name of the report layout; replaced by the layout record.
            part: dataset part stored on every report image.
            name: report name; a falsy value falls back to 'img_segment'.
            max_img_size: passed to ``resize_saving_ratio`` for every panel.
            stack_type: 'horizontal' stacks panels side by side; any other
                value stacks them vertically.
            main_metric: key of ``scores`` whose per-sample value is attached
                to each image when targets are given.
            plot_count: maximum number of images stored by ``process_pred``
                (decremented across calls).
            colors: per-class colours for masks; fallback colours are used
                when not given.
        """
        self.session = session
        self.task = task
        self.layout = layout
        self.part = part
        self.name = name or 'img_segment'
        self.max_img_size = max_img_size
        self.stack_type = stack_type
        self.main_metric = main_metric
        self.colors = colors
        self.plot_count = plot_count

        self.dag_provider = DagProvider(session)
        self.report_provider = ReportProvider(session)
        self.layout_provider = ReportLayoutProvider(session)
        self.task_provider = TaskProvider(session)
        self.report_img_provider = ReportImgProvider(session)
        self.report_task_provider = ReportTasksProvider(session)
        self.report_series_provider = ReportSeriesProvider(session)

        self.project = self.task_provider.project(task.id).id
        # Replace the layout name with its record and parse its content.
        self.layout = self.layout_provider.by_name(layout)
        self.layout_dict = yaml_load(self.layout.content)

        self.create_base()

    def create_base(self):
        """Create the report record, link it to the task and save the task."""
        report = Report(config=yaml_dump(self.layout_dict),
                        time=now(),
                        layout=self.layout.name,
                        project=self.project,
                        name=self.name)
        self.report_provider.add(report)
        self.report_task_provider.add(
            ReportTasks(report=report.id, task=self.task.id))

        self.task.report = report.id
        self.task_provider.update()

    def encode_pred(self, mask: np.array):
        """Render a stack of per-class masks as one 3-channel uint8 image.

        ``mask`` is iterated along its first axis (one 2-D map per class);
        the remaining axes give the output image size. Each class map is
        multiplied by its colour (white when ``self.colors`` is None) and the
        contributions are summed.

        The sum is accumulated in int32 and clipped to 255, so pixels covered
        by several classes saturate instead of wrapping around in uint8.
        """
        res = np.zeros((*mask.shape[1:], 3), dtype=np.int32)
        for class_idx, class_map in enumerate(mask):
            class_map = np.repeat(class_map[:, :, None], 3, axis=2)
            color = self.colors[class_idx] if self.colors is not None \
                else (255, 255, 255)
            res += (class_map * color).astype(np.uint8)

        return np.clip(res, 0, 255).astype(np.uint8)

    def plot_mask(self, img: np.array, mask: np.array):
        """Draw the contours of every class in ``mask`` onto a copy of ``img``.

        A 2-D (grayscale) ``img`` is converted to 3 channels first. ``mask``
        is iterated along its first axis, one map per class. Contours are
        drawn with the class colour (green when ``self.colors`` is falsy).
        """
        # Cast before cvtColor: cvtColor only accepts 8U/16U/32F inputs, and
        # GRAY2BGR merely replicates the channel, so the order does not
        # change the resulting pixels.
        img = img.astype(np.uint8)
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        mask = mask.astype(np.uint8)

        for class_idx, class_mask in enumerate(mask):
            # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4
            # returns (contours, hierarchy); [-2] is the contours in both.
            contours = cv2.findContours(class_mask, cv2.RETR_LIST,
                                        cv2.CHAIN_APPROX_NONE)[-2]
            color = self.colors[class_idx] if self.colors else (0, 255, 0)
            for contour in contours:
                cv2.polylines(img, contour, True, color, 2)

        return img

    def process_scores(self, scores):
        """Store each layout 'series' item whose key is present in ``scores``.

        Every layout item gets its 'name' set to its layout key. Each
        matching series item becomes one ReportSeries row with epoch 0,
        part 'valid' and stage 'stage1'.
        """
        for key, item in self.layout_dict['items'].items():
            item['name'] = key
            if item['type'] == 'series' and item['key'] in scores:
                series = ReportSeries(name=item['name'],
                                      value=scores[item['key']],
                                      epoch=0,
                                      time=now(),
                                      task=self.task.id,
                                      # NOTE(review): hard-coded, whereas the
                                      # images use self.part — confirm.
                                      part='valid',
                                      stage='stage1')

                self.report_series_provider.add(series)

    def process_pred(self,
                     imgs: np.array,
                     preds: dict,
                     targets: np.array = None,
                     attrs=None,
                     scores=None):
        """Render predictions and store them as report images.

        For every layout item of type 'img_segment', images are rendered
        until ``self.plot_count`` reaches zero. Each image is built from:

        * the input image, with target contours drawn when ``targets`` is
          given;
        * one encoded panel per entry of ``preds``.

        All panels are resized with ``resize_saving_ratio`` and stacked
        according to ``stack_type``, then stored as JPEG together with
        ``attrs[i]``. When both ``targets`` and ``scores`` are given,
        ``scores[main_metric][i]`` is stored as well. The DAG's ``img_size``
        grows by each stored image's size.
        """
        for item_name, item in self.layout_dict['items'].items():
            item['name'] = item_name
            if item['type'] != 'img_segment':
                continue

            report_imgs = []
            dag = self.dag_provider.by_id(self.task.dag)

            for i in range(len(imgs)):
                if self.plot_count <= 0:
                    break

                if targets is not None:
                    img = self.plot_mask(imgs[i], targets[i])
                else:
                    img = imgs[i]

                panels = [img]
                for pred in preds.values():
                    panels.append(self.encode_pred(pred[i]))

                panels = [resize_saving_ratio(panel, self.max_img_size)
                          for panel in panels]

                if self.stack_type == 'horizontal':
                    img = np.hstack(panels)
                else:
                    img = np.vstack(panels)

                attr = attrs[i] if attrs else {}

                # Scores are only meaningful when targets exist; guard
                # against a missing scores dict instead of raising TypeError.
                score = None
                if targets is not None and scores is not None:
                    score = scores[self.main_metric][i]

                _, buffer = cv2.imencode('.jpg', img)
                report_img = ReportImg(group=item['name'],
                                       epoch=0,
                                       task=self.task.id,
                                       img=buffer,
                                       dag=self.task.dag,
                                       part=self.part,
                                       project=self.project,
                                       score=score,
                                       **attr)

                self.plot_count -= 1
                report_imgs.append(report_img)
                dag.img_size += report_img.size

            self.dag_provider.commit()
            self.report_img_provider.bulk_save_objects(report_imgs)
Example #7
0
class Catalyst(Executor, Callback):
    """Executor that runs a Catalyst experiment and reports its progress.

    The instance is also a Catalyst callback. On the master process it is
    registered next to the experiment's own callbacks (see ``callbacks`` and
    ``work``), so it can open/close stage and epoch steps and store the
    per-epoch metric series described by the report layout.
    """

    def __init__(self, args: Args, report: ReportLayoutInfo, distr_info: dict,
                 resume: dict, grid_config: dict, trace: str, params: dict):
        """
        Args:
            args: Catalyst-style arguments (config paths, expdir, seed, ...).
            report: report layout info; its ``series`` and ``metric`` drive
                what ``on_epoch_end`` stores.
            distr_info: distributed-run settings; empty for a single process.
            resume: checkpoint resume settings, or None.
            grid_config: overrides merged into the experiment config.
            trace: output path for a traced model, or None to skip tracing.
            params: extra overrides merged into the experiment config.
        """
        super().__init__(order=0)

        self.resume = resume
        self.distr_info = distr_info
        self.args = args
        self.report = report
        self.experiment = None
        self.runner = None
        self.series_provider = ReportSeriesProvider(self.session)
        self.computer_provider = ComputerProvider(self.session)
        self.grid_config = grid_config
        self.master = True
        self.checkpoint_resume = False
        self.checkpoint_stage_epoch = 0
        self.trace = trace
        self.params = params

    def callbacks(self):
        """Return the callbacks this executor contributes (master only)."""
        result = OrderedDict()
        if self.master:
            result['catalyst'] = self

        return result

    def on_epoch_start(self, state: RunnerState):
        """Shift epoch counters for a resumed run and open step records."""
        # A resumed stage starts right after the checkpointed epoch.
        if self.checkpoint_resume and state.stage_epoch == 0:
            state.epoch += 1

        state.stage_epoch = state.stage_epoch + self.checkpoint_stage_epoch
        state.checkpoint_data = {'stage_epoch': state.stage_epoch}
        if self.master:
            if state.stage_epoch == 0:
                self.step.start(1, name=state.stage)

            self.step.start(2,
                            name=f'epoch {state.stage_epoch}',
                            index=state.stage_epoch)

    def on_epoch_end(self, state: RunnerState):
        """Store train/valid values of every report series; track best score.

        Series are attached to the parent task when there is one. When the
        series is the report's main metric and improves on the task's score
        (respecting ``minimize``), the score is updated.
        """
        self.step.end(2)

        for s in self.report.series:
            train = state.metrics.epoch_values['train'][s.key]
            val = state.metrics.epoch_values['valid'][s.key]

            task_id = self.task.parent or self.task.id
            train = ReportSeries(part='train',
                                 name=s.key,
                                 epoch=state.epoch,
                                 task=task_id,
                                 value=train,
                                 time=now(),
                                 stage=state.stage)

            val = ReportSeries(part='valid',
                               name=s.key,
                               epoch=state.epoch,
                               task=task_id,
                               value=val,
                               time=now(),
                               stage=state.stage)

            self.series_provider.add(train)
            self.series_provider.add(val)

            if s.key == self.report.metric.name:
                best = False
                task = self.task
                if task.parent:
                    task = self.task_provider.by_id(task.parent)

                if self.report.metric.minimize:
                    if task.score is None or val.value < task.score:
                        best = True
                else:
                    if task.score is None or val.value > task.score:
                        best = True
                if best:
                    task.score = val.value
                    self.task_provider.update()

    def on_stage_start(self, state: RunnerState):
        """Use only console and exception-raising loggers."""
        state.loggers = {
            'console': VerboseLogger(),
            'raise': RaiseExceptionLogger()
        }

    def on_stage_end(self, state: RunnerState):
        """Reset resume offsets (they apply to one stage only); close step."""
        self.checkpoint_resume = False
        self.checkpoint_stage_epoch = 0
        self.step.end(1)

    @classmethod
    def _from_config(cls, executor: dict, config: Config,
                     additional_info: dict):
        """Build the executor from its config section and task info.

        Argument values are stringified: 'True'/'False' become booleans and
        digit-only strings become ints. Anything else (e.g. floats, negative
        numbers) stays a string.
        """
        args = Args()
        for k, v in executor['args'].items():
            v = str(v)
            if v in ['False', 'True']:
                v = v == 'True'
            elif v.isnumeric():
                v = int(v)

            setattr(args, k, v)

        assert 'report_config' in additional_info, 'layout was not filled'
        report_config = additional_info['report_config']
        grid_cell = additional_info.get('grid_cell')
        report = ReportLayoutInfo(report_config)
        if len(args.configs) == 0:
            args.configs = [args.config]

        grid_config = {}
        if grid_cell is not None:
            grid_config = grid_cells(executor['grid'])[grid_cell][0]

        distr_info = additional_info.get('distr_info', {})
        resume = additional_info.get('resume')
        params = executor.get('params', {})

        return cls(args=args,
                   report=report,
                   grid_config=grid_config,
                   distr_info=distr_info,
                   resume=resume,
                   trace=executor.get('trace'),
                   params=params)

    def set_dist_env(self, config):
        """Export torch.distributed env variables; non-zero ranks are workers."""
        info = self.distr_info
        os.environ['MASTER_ADDR'] = info['master_addr']
        os.environ['MASTER_PORT'] = str(info['master_port'])
        os.environ['WORLD_SIZE'] = str(info['world_size'])

        os.environ['RANK'] = str(info['rank'])
        distributed_params = config.get('distributed_params', {})
        distributed_params['rank'] = 0
        config['distributed_params'] = distributed_params

        if info['rank'] > 0:
            self.master = False

    def parse_args_uargs(self):
        """Parse args into a config and merge grid and param overrides."""
        args, config = parse_args_uargs(self.args, [])
        config = merge_dicts_smart(config, self.grid_config)
        config = merge_dicts_smart(config, self.params)

        if self.distr_info:
            self.set_dist_env(config)
        return args, config

    def _checkpoint_fix_config(self, experiment):
        """Point the experiment at a checkpoint to resume from, if any.

        The checkpoint described by ``self.resume`` is located, and copied
        from the master computer when this host differs. Then:

        * stages that are already finished are dropped;
        * the stage being resumed is shortened;
        * the checkpoint path is set as ``resume`` on the first remaining
          stage's CheckpointCallback.
        """
        resume = self.resume
        if not resume:
            return
        # Without a log directory there is nowhere to look for checkpoints
        # (join() below would fail on None).
        if experiment.logdir is None:
            return

        checkpoint_dir = join(experiment.logdir, 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)

        file = 'last_full.pth' if resume.get('load_last') else 'best_full.pth'

        path = join(checkpoint_dir, file)
        computer = socket.gethostname()
        if computer != resume['master_computer']:
            master_computer = self.computer_provider.by_name(
                resume['master_computer'])
            path_from = join(master_computer.root_folder,
                             str(resume['master_task_id']), 'log',
                             'checkpoints', file)
            self.info(f'copying checkpoint from: computer = '
                      f'{resume["master_computer"]} path_from={path_from} '
                      f'path_to={path}')

            success = copy_remote(session=self.session,
                                  computer_from=resume['master_computer'],
                                  path_from=path_from,
                                  path_to=path)

            if not success:
                self.error(f'copying from '
                           f'{resume["master_computer"]}/'
                           f'{path_from} failed')
            else:
                self.info('checkpoint copied successfully')

        elif self.task.id != resume['master_task_id']:
            path = join(TASK_FOLDER, str(resume['master_task_id']), 'log',
                        'checkpoints', file)
            self.info(f'master_task_id!=task.id, using checkpoint'
                      f' from task_id = {resume["master_task_id"]}')

        if not os.path.exists(path):
            self.info(f'no checkpoint at {path}')
            return

        ckpt = load_checkpoint(path)
        stages_config = experiment.stages_config
        for k, v in list(stages_config.items()):
            if k == ckpt['stage']:
                stage_epoch = ckpt['checkpoint_data']['stage_epoch'] + 1

                # The stage is finished (or only the best weights are
                # wanted). ">=" also covers a checkpoint that is past a
                # since-reduced num_epochs, which would otherwise leave a
                # negative epoch count.
                if stage_epoch >= v['state_params']['num_epochs'] \
                        or resume.get('load_best'):
                    del stages_config[k]
                    break

                self.checkpoint_stage_epoch = stage_epoch
                v['state_params']['num_epochs'] -= stage_epoch
                break
            # Stages before the checkpointed one are already done.
            del stages_config[k]

        stage = experiment.stages_config[experiment.stages[0]]
        for k, v in stage['callbacks_params'].items():
            if v.get('callback') == 'CheckpointCallback':
                v['resume'] = path

        self.info(f'found checkpoint at {path}')

    def _checkpoint_fix_callback(self, callbacks: dict):
        """Note resume state; disable checkpoint saving on worker processes."""
        def mock(state):
            pass

        for k, c in callbacks.items():
            if not isinstance(c, CheckpointCallback):
                continue

            if c.resume:
                self.checkpoint_resume = True

            if not self.master:
                c.on_epoch_end = mock
                c.on_stage_end = mock

    def work(self):
        """Run the experiment and return its last stage and all stages."""
        args, config = self.parse_args_uargs()
        set_global_seed(args.seed)

        Experiment, R = import_experiment_and_runner(Path(args.expdir))

        runner_params = config.pop('runner_params', {})

        experiment = Experiment(config)
        runner: Runner = R(**runner_params)

        register()

        self.experiment = experiment
        self.runner = runner

        stages = experiment.stages[:]

        if self.master:
            task = self.task if not self.task.parent \
                else self.task_provider.by_id(self.task.parent)
            task.steps = len(stages)
            self.task_provider.commit()

        self._checkpoint_fix_config(experiment)

        _get_callbacks = experiment.get_callbacks

        # Wrap the experiment's callbacks so this executor runs first.
        def get_callbacks(stage):
            res = self.callbacks()
            for k, v in _get_callbacks(stage).items():
                res[k] = v

            self._checkpoint_fix_callback(res)
            return res

        experiment.get_callbacks = get_callbacks

        if experiment.logdir is not None:
            dump_environment(config, experiment.logdir, args.configs)

        if self.distr_info:
            # Record how a later run should resume from the master's best
            # checkpoint, and run only the first stage now.
            info = yaml_load(self.task.additional_info)
            info['resume'] = {
                'master_computer': self.distr_info['master_computer'],
                'master_task_id': self.task.id - self.distr_info['rank'],
                'load_best': True
            }
            self.task.additional_info = yaml_dump(info)
            self.task_provider.commit()

            experiment.stages_config = {
                k: v
                for k, v in experiment.stages_config.items()
                if k == experiment.stages[0]
            }

        runner.run_experiment(experiment, check=args.check)

        if self.master and self.trace:
            traced = trace_model_from_checkpoint(self.experiment.logdir, self)
            torch.jit.save(traced, self.trace)

        return {'stage': experiment.stages[-1], 'stages': stages}
Example #8
0
class Catalyst(Executor, Callback):
    """Executor that runs a Catalyst experiment and reports its progress.

    The instance is also a Catalyst callback. On the master process it is
    registered next to the experiment's own callbacks (see ``callbacks`` and
    ``work``) and does three things:

    * publishes batch progress on the (parent) task;
    * stores per-epoch metric series;
    * tracks the best value of the report's main metric.
    """

    def __init__(self, args: Args, report: ReportLayoutInfo, distr_info: dict,
                 resume: dict, grid_config: dict, trace: str, params: dict,
                 **kwargs):
        """
        Args:
            args: Catalyst-style arguments (config paths, expdir, seed, ...).
            report: report layout info; its ``metric`` decides which series
                updates the task score.
            distr_info: distributed-run settings; empty for a single process.
            resume: checkpoint resume settings, or None.
            grid_config: overrides merged into the experiment config.
            trace: output path for a traced model, or None to skip tracing.
            params: extra overrides merged into the experiment config.
            **kwargs: forwarded to the parent initializer.
        """
        super().__init__(**kwargs)

        self.series_provider = ReportSeriesProvider(self.session)
        self.computer_provider = ComputerProvider(self.session)
        self.memory_provider = MemoryProvider(self.session)

        self.order = 0
        self.resume = resume
        self.distr_info = distr_info
        self.args = args
        self.report = report
        self.experiment = None
        self.runner = None
        self.grid_config = grid_config
        self.master = True
        self.trace = trace
        self.params = params
        self.last_batch_logged = None
        self.loader_started_time = None
        self.parent = None
        self.node = CallbackNode.All
        # Set by the checkpoint helpers; initialised so the attributes
        # always exist.
        self.checkpoint_resume = False
        self.checkpoint_stage_epoch = 0

    def get_parent_task(self):
        """Return the parent task if one was loaded, else the own task."""
        if self.parent:
            return self.parent
        return self.task

    def callbacks(self):
        """Return the callbacks this executor contributes (master only)."""
        result = OrderedDict()
        if self.master:
            result['catalyst'] = self

        return result

    def on_loader_start(self, state: State):
        """Remember when the loader started, for duration estimates."""
        self.loader_started_time = now()

    def on_epoch_start(self, state: State):
        """Open the stage step and the epoch step."""
        stage_index = self.experiment.stages.index(state.stage_name)
        self.step.start(1, name=state.stage_name, index=stage_index)

        self.step.start(2, name=f'epoch {state.epoch}', index=state.epoch - 1)

    def on_batch_start(self, state: State):
        """Publish batch progress on the parent task.

        Updates are throttled to one per 10 seconds, except for the last
        batch of a loader, which is always published.
        """
        if self.last_batch_logged and state.loader_step != state.loader_len:
            if (now() - self.last_batch_logged).total_seconds() < 10:
                return

        task = self.get_parent_task()
        task.batch_index = state.loader_step
        task.batch_total = state.loader_len
        task.loader_name = state.loader_name

        duration = int((now() - self.loader_started_time).total_seconds())
        task.epoch_duration = duration
        # Linear extrapolation of the loader time; skipped while no batch
        # has been counted yet to avoid dividing by zero.
        if task.batch_index:
            task.epoch_time_remaining = int(
                duration *
                (task.batch_total / task.batch_index)) - task.epoch_duration
        if state.epoch_metrics.get('train_loss') is not None:
            task.loss = float(state.epoch_metrics['train_loss'])
        # The valid loss, when present, takes precedence over the train loss.
        if state.epoch_metrics.get('valid_loss') is not None:
            task.loss = float(state.epoch_metrics['valid_loss'])

        self.task_provider.update()
        self.last_batch_logged = now()

    @staticmethod
    def _split_metric_key(key: str, loaders):
        """Split an epoch-metric key such as 'valid_loss' into (part, name).

        The longest loader name that prefixes ``key`` becomes the part. It is
        stripped from the front only, together with one following
        underscore. Keys matching no loader give ('', key).
        """
        part = ''
        for loader in loaders:
            if key.startswith(loader) and len(loader) >= len(part):
                part = loader

        name = key
        if part:
            name = key[len(part):]
            if name.startswith('_'):
                name = name[1:]
        return part, name

    def on_epoch_end(self, state: State):
        """Store every epoch metric as a series and track the best score.

        Series are attached to the parent task when there is one. When the
        metric is the report's main metric and improves on the task's score
        (respecting ``minimize``), the score is updated.
        """
        self.step.end(2)

        values = state.epoch_metrics

        for k, v in values.items():
            part, name = self._split_metric_key(k, state.loaders)

            task_id = self.task.parent or self.task.id
            series = ReportSeries(part=part,
                                  name=name,
                                  epoch=state.epoch - 1,
                                  task=task_id,
                                  value=v,
                                  time=now(),
                                  stage=state.stage_name)
            self.series_provider.add(series)

            if name == self.report.metric.name:
                best = False
                task = self.task
                if task.parent:
                    task = self.task_provider.by_id(task.parent)

                if self.report.metric.minimize:
                    if task.score is None or v < task.score:
                        best = True
                else:
                    if task.score is None or v > task.score:
                        best = True
                if best:
                    task.score = v
                    self.task_provider.update()

    def on_stage_end(self, state: State):
        """Close the stage step."""
        self.step.end(1)

    @classmethod
    def _from_config(cls, executor: dict, config: Config,
                     additional_info: dict):
        """Build the executor from its config section and task info.

        Argument values are stringified: 'True'/'False' become booleans and
        digit-only strings become ints. Anything else (e.g. floats, negative
        numbers) stays a string. The caller's ``executor`` dict is not
        modified.
        """
        args = Args()
        for k, v in executor['args'].items():
            v = str(v)
            if v in ['False', 'True']:
                v = v == 'True'
            elif v.isnumeric():
                v = int(v)

            setattr(args, k, v)

        assert 'report_config' in additional_info, 'layout was not filled'
        report_config = additional_info['report_config']
        report = ReportLayoutInfo(report_config)
        if len(args.configs) == 0:
            args.configs = [args.config]

        distr_info = additional_info.get('distr_info', {})
        resume = additional_info.get('resume')
        # Copy before merging so the task-specific params do not leak into
        # the executor config shared with other tasks.
        params = dict(executor.get('params') or {})
        params.update(additional_info.get('params', {}))

        grid_config = executor.copy()
        grid_config.pop('args', '')
        if 'params' in grid_config:
            # grid_config still carries the merged params, as before.
            grid_config['params'] = params

        return cls(args=args,
                   report=report,
                   grid_config=grid_config,
                   distr_info=distr_info,
                   resume=resume,
                   trace=executor.get('trace'),
                   params=params)

    def set_dist_env(self, config):
        """Set up torch.distributed for this rank (one GPU per process)."""
        info = self.distr_info
        os.environ['MASTER_ADDR'] = info['master_addr']
        os.environ['MASTER_PORT'] = str(info['master_port'])
        os.environ['WORLD_SIZE'] = str(info['world_size'])

        os.environ['RANK'] = str(info['rank'])
        os.environ['LOCAL_RANK'] = "0"
        distributed_params = config.get('distributed_params', {})
        distributed_params['rank'] = info['rank']
        config['distributed_params'] = distributed_params

        torch.cuda.set_device(0)

        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")

        if info['rank'] > 0:
            self.master = False
            self.node = CallbackNode.Worker
        else:
            self.node = CallbackNode.Master

    def parse_args_uargs(self):
        """Parse args into a config and merge grid and param overrides."""
        args, config = parse_args_uargs(self.args, [])
        config = merge_dicts_smart(config, self.grid_config)
        config = merge_dicts_smart(config, self.params)

        if self.distr_info:
            self.set_dist_env(config)
        return args, config

    def _fix_memory(self, experiment):
        """Pick per-stage batch sizes from recorded GPU memory usage.

        For each stage, memory records matching the model and data params
        are looked up. Among those that fit into GPU 0, the batch size of
        the most memory-hungry one is used. Stages without a fitting record
        are left as they are.
        """
        if not torch.cuda.is_available():
            return
        max_memory = torch.cuda.get_device_properties(0).total_memory / (2**30)
        stages_config = experiment.stages_config
        for k, v in list(stages_config.items()):
            query = {}
            # noinspection PyProtectedMember
            for kk, vv in experiment._config['model_params'].items():
                query[kk] = vv
            for kk, vv in v['data_params'].items():
                query[kk] = vv
            variants = self.memory_provider.find(query)
            variants = [x for x in variants if x.memory < max_memory]
            if len(variants) == 0:
                continue
            variant = max(variants, key=lambda x: x.memory)
            v['data_params']['batch_size'] = variant.batch_size

    def _checkpoint_fix_config(self, experiment):
        """Point the experiment at a checkpoint to resume from, if any.

        The checkpoint described by ``self.resume`` is located, and copied
        from the master computer when this host differs. Then:

        * stages that are already finished are dropped;
        * the stage being resumed is shortened;
        * the checkpoint path is set as ``resume`` on the first remaining
          stage's CheckpointCallback.
        """
        resume = self.resume
        if not resume:
            return
        if experiment.logdir is None:
            return

        checkpoint_dir = join(experiment.logdir, 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)

        file = 'last_full.pth' if resume.get('load_last') else 'best_full.pth'

        path = join(checkpoint_dir, file)
        computer = socket.gethostname()
        if computer != resume['master_computer']:
            master_computer = self.computer_provider.by_name(
                resume['master_computer'])
            path_from = join(master_computer.root_folder,
                             str(resume['master_task_id']), experiment.logdir,
                             'checkpoints', file)
            self.info(f'copying checkpoint from: computer = '
                      f'{resume["master_computer"]} path_from={path_from} '
                      f'path_to={path}')

            success = copy_remote(session=self.session,
                                  computer_from=resume['master_computer'],
                                  path_from=path_from,
                                  path_to=path)

            if not success:
                self.error(f'copying from '
                           f'{resume["master_computer"]}/'
                           f'{path_from} failed')
            else:
                self.info('checkpoint copied successfully')

        elif self.task.id != resume['master_task_id']:
            path = join(TASK_FOLDER, str(resume['master_task_id']),
                        experiment.logdir, 'checkpoints', file)
            self.info(f'master_task_id!=task.id, using checkpoint'
                      f' from task_id = {resume["master_task_id"]}')

        if not os.path.exists(path):
            self.info(f'no checkpoint at {path}')
            return

        ckpt = load_checkpoint(path)
        stages_config = experiment.stages_config
        for k, v in list(stages_config.items()):
            if k == ckpt['stage']:
                stage_epoch = ckpt['checkpoint_data']['epoch'] + 1

                # if it is the last epoch in the stage
                if stage_epoch >= v['state_params']['num_epochs'] \
                        or resume.get('load_best'):
                    del stages_config[k]
                    break

                self.checkpoint_stage_epoch = stage_epoch
                v['state_params']['num_epochs'] -= stage_epoch
                break
            # Stages before the checkpointed one are already done.
            del stages_config[k]

        stage = experiment.stages_config[experiment.stages[0]]
        for k, v in stage['callbacks_params'].items():
            if v.get('callback') == 'CheckpointCallback':
                v['resume'] = path

        self.info(f'found checkpoint at {path}')

    def _checkpoint_fix_callback(self, callbacks: dict):
        """Note resume state; disable checkpoint hooks on worker processes."""
        def mock(state):
            pass

        for k, c in callbacks.items():
            if not isinstance(c, CheckpointCallback):
                continue

            if c.resume:
                self.checkpoint_resume = True

            if not self.master:
                c.on_epoch_end = mock
                c.on_stage_end = mock
                c.on_batch_start = mock

    def work(self):
        """Run the experiment and return its last stage and all stages.

        Raises:
            Any exception recorded on the runner state during the run.
        """
        args, config = self.parse_args_uargs()
        set_global_seed(args.seed)

        Experiment, R = import_experiment_and_runner(Path(args.expdir))

        runner_params = config.pop('runner_params', {})

        experiment = Experiment(config)
        runner: Runner = R(**runner_params)

        self.experiment = experiment
        self.runner = runner

        stages = experiment.stages[:]

        if self.task.parent:
            self.parent = self.task_provider.by_id(self.task.parent)

        if self.master:
            task = self.get_parent_task()
            task.steps = len(stages)
            self.task_provider.commit()

        self._checkpoint_fix_config(experiment)
        self._fix_memory(experiment)

        _get_callbacks = experiment.get_callbacks

        # Wrap the experiment's callbacks so this executor runs first.
        def get_callbacks(stage):
            res = self.callbacks()
            for k, v in _get_callbacks(stage).items():
                res[k] = v

            self._checkpoint_fix_callback(res)
            return res

        experiment.get_callbacks = get_callbacks

        if experiment.logdir is not None:
            dump_environment(config, experiment.logdir, args.configs)

        if self.distr_info:
            # Record how a later run should resume from the master's best
            # checkpoint, and run only the first stage now.
            info = yaml_load(self.task.additional_info)
            info['resume'] = {
                'master_computer': self.distr_info['master_computer'],
                'master_task_id': self.task.id - self.distr_info['rank'],
                'load_best': True
            }
            self.task.additional_info = yaml_dump(info)
            self.task_provider.commit()

            experiment.stages_config = {
                k: v
                for k, v in experiment.stages_config.items()
                if k == experiment.stages[0]
            }

        runner.run_experiment(experiment)
        if runner.state.exception:
            raise runner.state.exception

        if self.master and self.trace:
            traced = trace_model_from_checkpoint(self.experiment.logdir, self)
            torch.jit.save(traced, self.trace)
        return {'stage': experiment.stages[-1], 'stages': stages}
Example #9
0
class ClassificationReportBuilder:
    """Persist the classification results of one task into a report.

    The report layout named ``layout`` is loaded and its YAML content drives
    what gets stored:

    * ``series`` items become ``ReportSeries`` rows (:meth:`process_scores`);
    * ``img_classify`` items become ``ReportImg`` rows, plus an optional
      pickled confusion matrix (:meth:`process_pred`).
    """

    def __init__(self,
                 session: Session,
                 task: Task,
                 layout: str,
                 part: str = 'valid',
                 name: str = 'img_classify',
                 max_img_size: Tuple[int, int] = None,
                 main_metric: str = 'accuracy',
                 plot_count: int = 0):
        """
        :param session: DB session shared by every provider created here.
        :param task: task whose results are reported.
        :param layout: name of the report layout to load.
        :param part: value stored in the ``part`` field of image rows.
        :param name: report name; a falsy value falls back to 'img_classify'.
        :param max_img_size: passed to ``resize_saving_ratio`` before encoding.
        :param main_metric: key of ``scores`` used as the per-image score.
        :param plot_count: number of samples :meth:`process_pred` may still
            store images for, summed over all of its calls.
        """
        self.session = session
        self.task = task
        self.part = part
        self.name = name or 'img_classify'
        self.max_img_size = max_img_size
        self.main_metric = main_metric
        self.plot_count = plot_count

        self.dag_provider = DagProvider(session)
        self.report_provider = ReportProvider(session)
        self.layout_provider = ReportLayoutProvider(session)
        self.task_provider = TaskProvider(session)
        self.report_img_provider = ReportImgProvider(session)
        self.report_task_provider = ReportTasksProvider(session)
        self.report_series_provider = ReportSeriesProvider(session)

        self.project = self.task_provider.project(task.id).id
        # ``self.layout`` holds the layout object returned by the provider,
        # not the name that was passed in.
        self.layout = self.layout_provider.by_name(layout)
        self.layout_dict = yaml_load(self.layout.content)

    def create_base(self):
        """Create the ``Report`` row, link it to the task via ``ReportTasks``
        and store the new report id on ``self.task.report``."""
        report = Report(config=yaml_dump(self.layout_dict),
                        time=now(),
                        layout=self.layout.name,
                        project=self.project,
                        name=self.name)
        self.report_provider.add(report)
        self.report_task_provider.add(
            ReportTasks(report=report.id, task=self.task.id))

        self.task.report = report.id
        self.task_provider.update()

    def process_scores(self, scores):
        """Add one ``ReportSeries`` row for every ``series`` layout item whose
        ``key`` is present in ``scores``.

        :param scores: mapping of metric key -> value convertible to float.
        """
        for key, item in self.layout_dict['items'].items():
            item['name'] = key
            if item['type'] == 'series' and item['key'] in scores:
                # NOTE(review): part/stage are fixed to 'valid'/'stage1' and
                # do not follow ``self.part`` -- confirm this is intended.
                series = ReportSeries(name=item['name'],
                                      value=float(scores[item['key']]),
                                      epoch=0,
                                      time=now(),
                                      task=self.task.id,
                                      part='valid',
                                      stage='stage1')

                self.report_series_provider.add(series)

    def process_pred(self,
                     imgs: np.array,
                     preds: np.array,
                     targets: np.array = None,
                     attrs=None,
                     scores=None):
        """Store sample images (and optionally a confusion matrix) for every
        ``img_classify`` layout item.

        At most ``self.plot_count`` samples are stored in total across calls.
        Every ``img_classify`` item in this call receives the same leading
        samples of the batch.

        :param imgs: images indexed by sample; each one is resized and
            JPEG-encoded.
        :param preds: per-sample class scores; ``argmax`` gives the predicted
            class and ``preds.shape[1]`` the number of classes.
        :param targets: optional true labels, one per sample.
        :param attrs: optional per-sample dicts of extra ``ReportImg`` fields.
        :param scores: optional mapping metric name -> per-sample values; the
            ``main_metric`` entry is stored as the image score when
            ``targets`` is given.
        """
        # The budget is computed once per call and charged once, so several
        # img_classify items do not starve each other, and the per-image cap
        # is actually enforced (previously a whole batch was stored whenever
        # plot_count > 0).
        n_plot = min(max(self.plot_count, 0), len(imgs))

        for key, item in self.layout_dict['items'].items():
            item['name'] = key
            if item['type'] != 'img_classify':
                continue

            report_imgs = []
            dag = self.dag_provider.by_id(self.task.dag)

            for i in range(n_plot):
                img = resize_saving_ratio(imgs[i], self.max_img_size)
                attr = attrs[i] if attrs else {}

                y = None
                score = None
                if targets is not None:
                    y = targets[i]
                    # Guard: targets without scores used to raise TypeError.
                    if scores is not None:
                        score = float(scores[self.main_metric][i])

                y_pred = preds[i].argmax()
                _, buffer = cv2.imencode('.jpg', img)
                report_img = ReportImg(group=item['name'],
                                       epoch=0,
                                       task=self.task.id,
                                       img=buffer,
                                       dag=self.task.dag,
                                       part=self.part,
                                       project=self.project,
                                       y_pred=y_pred,
                                       y=y,
                                       score=score,
                                       **attr)

                report_imgs.append(report_img)
                # Track the accumulated image storage on the DAG row.
                dag.img_size += report_img.size

            self.dag_provider.commit()
            self.report_img_provider.bulk_save_objects(report_imgs)

            if targets is not None and item.get('confusion_matrix'):
                matrix = confusion_matrix(targets,
                                          preds.argmax(axis=1),
                                          labels=np.arange(preds.shape[1]))
                matrix = np.array(matrix)
                c = {'data': matrix}
                obj = ReportImg(group=item['name'] + '_confusion',
                                epoch=0,
                                task=self.task.id,
                                img=pickle.dumps(c),
                                project=self.project,
                                dag=self.task.dag,
                                part=self.part)
                self.report_img_provider.add(obj)

        self.plot_count -= n_plot
Example #10
0
def describe(dag: int,
             metrics=None,
             last_n_epoch=None,
             computer: str = None,
             max_log_text: int = 45,
             fig_size=(12, 10),
             grid_spec: dict = None,
             log_count=5,
             log_col_widths: List[float] = None,
             wait=True,
             wait_interval=5,
             task_with_metric_count=0):
    """Draw a dashboard for ``dag`` and optionally keep refreshing it.

    The figure shows panels for tasks, the DAG, logs, computer resources and
    one cell per metric series. When ``wait`` is true, the dashboard is
    redrawn every ``wait_interval`` seconds until ``describe_tasks`` returns
    a truthy value.

    :param dag: id of the DAG to describe.
    :param metrics: metric names passed to ``ReportSeriesProvider.by_dag``.
    :param last_n_epoch: forwarded to ``describe_metrics``.
    :param computer: host whose resources are shown; defaults to this host.
    :param max_log_text: forwarded to ``describe_logs``.
    :param fig_size: matplotlib figure size.
    :param grid_spec: overrides merged over the default layout. Keys
        'tasks', 'dag', 'logs', 'resources' and integer series indices map to
        dicts with 'loc', 'rowspan' and 'colspan'. The 'size' key is the
        (rows, cols) grid shape.
    :param log_count: forwarded to ``describe_logs``.
    :param log_col_widths: forwarded to ``describe_logs``.
    :param wait: keep refreshing until finished; if false, draw once.
    :param wait_interval: seconds to sleep between refreshes.
    :param task_with_metric_count: multiplied by ``len(metrics)`` to decide
        how many metric cells the default layout reserves.
    """
    grid_spec = grid_spec or {}
    metrics = metrics or []
    computer = computer or gethostname()

    series_count = task_with_metric_count * len(metrics)
    default_grid_spec = {
        'tasks': {
            'rowspan': 1,
            'colspan': 2,
            'loc': (0, 0)
        },
        'dag': {
            'rowspan': 1,
            'colspan': 2,
            'loc': (1, 0)
        },
        'logs': {
            'rowspan': 1,
            'colspan': 2,
            'loc': (2, 0)
        },
        'resources': {
            'rowspan': 1,
            'colspan': 2,
            'loc': (3, 0)
        },
        'size': (4 + ceil(series_count / 2), 2)
    }

    # Metric cells fill the rows below the four fixed panels, two per row.
    for i in range(series_count):
        default_grid_spec[i] = {
            'rowspan': 1,
            'colspan': 1,
            'loc': (4 + i // 2, i % 2)
        }

    default_grid_spec.update(grid_spec)
    grid_spec = default_grid_spec
    # Use the merged value so a caller-supplied 'size' is honoured
    # (previously it was silently ignored).
    size = grid_spec['size']

    fig = plt.figure(figsize=fig_size)

    def grid_cell(spec: dict):
        return plt.subplot2grid(size,
                                spec['loc'],
                                colspan=spec['colspan'],
                                rowspan=spec['rowspan'],
                                fig=fig)

    try:
        while True:
            # Start each pass from an empty figure. Without this, axes from
            # the previous refresh are stacked under the new ones, because
            # newer matplotlib no longer auto-removes overlapping axes.
            fig.clf()

            task_axis = grid_cell(grid_spec['tasks'])
            dag_axis = grid_cell(grid_spec['dag'])
            resources_axis = grid_cell(grid_spec['resources'])
            logs_axis = grid_cell(grid_spec['logs'])

            finish = describe_tasks(dag, task_axis)
            describe_dag(dag, dag_axis)
            errors = describe_logs(dag,
                                   axis=logs_axis,
                                   max_log_text=max_log_text,
                                   log_count=log_count,
                                   col_withds=log_col_widths)
            describe_resources(computer=computer, axis=resources_axis)

            series_provider = ReportSeriesProvider()
            series = series_provider.by_dag(dag, metrics)

            metric_axis = [
                grid_cell(grid_spec[i]) for i, _ in enumerate(series)
            ]

            describe_metrics(series,
                             last_n_epoch=last_n_epoch,
                             axis=metric_axis)

            plt.tight_layout()

            display.clear_output(wait=True)

            for error in errors:
                print(error.time)
                print(error.message)

            display.display(fig)

            if not wait or finish:
                break

            time.sleep(wait_interval)
    finally:
        # Release the figure even when the loop is interrupted
        # (e.g. KeyboardInterrupt during the sleep).
        plt.close(fig)