Example #1
def run(my_trainer: trainer.Trainer,
        watching_validator_gpu=None,
        write_val_to_stdout=False):
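    # If a validation GPU was given, offload validation to a separate watcher process
    # and register cleanup handlers so the watcher is shut down when this script exits.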
    if watching_validator_gpu is not None:
        atexit.register(find_and_kill_watcher, my_trainer.exporter.out_dir)
        pid, pidout_filename, writer = validator.offload_validation_to_watcher(
            my_trainer,
            watching_validator_gpu,
            as_subprocess=not DEBUG_WATCHER,
            write_val_to_stdout=write_val_to_stdout)
        atexit.unregister(find_and_kill_watcher)
        atexit.register(terminate_watcher, pid, writer)

    try:
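        # Train; on Ctrl-C, ask whether to finish the script (and evaluate) or abort.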
        my_trainer.train()
        misc.color_text('Training is complete!', 'OKGREEN')
        atexit.unregister(query_remove_logdir)
    except KeyboardInterrupt:
        if y_or_n_input('I\'ve stopped training.  Finish script?',
                        default='y') == 'n':
            raise

    print('Evaluating final model')
    val_loss, eval_metrics, (segmentation_visualizations, score_visualizations) = \
        my_trainer.validate_split(should_export_visualizations=False)
    if eval_metrics is not None:
        eval_metrics = np.array(eval_metrics)
        eval_metrics *= 100
    viz = visualization_utils.get_tile_image(segmentation_visualizations)
    write_np_array_as_img(os.path.join(here, 'viz_evaluate.png'), viz)

    return eval_metrics
Example #2
def get_validation_progress_bar(self) -> ValProgressWatcher:
    watcher_log_dir = os.path.join(self.exporter.out_dir, 'model_checkpoints-val-log')
    if not os.path.exists(watcher_log_dir):
        print('Waiting for validator to start...')
        time.sleep(10)  # give the subprocess a chance to make its log directory.
    if not os.path.exists(watcher_log_dir):
        t_val = None
        print('No directory exists at {}'.format(watcher_log_dir))
        if self.skip_validation:
            misc.color_text("Validation might not be happening: Couldn't find a watcher log directory at {}".format(
                watcher_log_dir))
    else:
        t_val = ValProgressWatcher(watcher_log_directory=watcher_log_dir,
                                   trainer_model_directory=self.exporter.model_history_saver.model_checkpoint_dir)
    return t_val
Example #3
def test_configure(dataset_name,
                   checkpoint_path,
                   config_idx,
                   sampler_name,
                   script_py_file='unknownscript.py',
                   cfg_override_args=None,
                   additional_logdir_tag=''):
    parent_directory = get_test_dir_parent_from_traindir(checkpoint_path)
    script_basename = os.path.basename(script_py_file).replace('.py', '')
    cfg, cfg_to_print = get_cfgs(dataset_name=dataset_name,
                                 config_idx=config_idx,
                                 cfg_override_args=cfg_override_args)
    if sampler_name is not None or 'sampler' not in cfg:
        cfg['sampler'] = sampler_name
    else:
        sampler_name = cfg['sampler']
    assert cfg['dataset'] == dataset_name, 'Debug Error: cfg[\'dataset\']: {}, ' \
                                           'args.dataset: {}'.format(cfg['dataset'], dataset_name)
    if cfg['dataset_instance_cap'] == 'match_model':
        cfg['dataset_instance_cap'] = cfg['n_instances_per_class']
    sampler_cfg = sampler_cfg_registry.get_sampler_cfg_set(sampler_name)
    out_dir = get_log_dir(os.path.join(parent_directory, script_basename),
                          cfg_to_print,
                          additional_tag=additional_logdir_tag)

    configs.save_config(out_dir, cfg)
    print(
        misc.color_text('logdir: {}'.format(out_dir), misc.TermColors.OKGREEN))
    return cfg, out_dir, sampler_cfg
Example #4
def offload_validation_to_watcher(my_trainer,
                                  watching_validator_gpu,
                                  as_subprocess=True,
                                  write_val_to_stdout=False):
    assert my_trainer.t_val is None, 'Watcher already exists'
    starting_model_checkpoint = my_trainer.exporter.save_checkpoint(
        my_trainer.state.epoch,
        my_trainer.state.iteration,
        my_trainer.model,
        my_trainer.optim,
        my_trainer.best_mean_iu,
        mean_iu=None)
    pidout_filename = os.path.join(my_trainer.exporter.out_dir,
                                   'watcher_output_subprocess.log')
    writer = open(pidout_filename, 'wb')
    if not as_subprocess:  # debug
        validator = get_validator(my_trainer.exporter.out_dir,
                                  watching_validator_gpu,
                                  starting_model_checkpoint)
        watch_and_validate.main(
            my_trainer.exporter.out_dir,
            watching_validator_gpu,  # loops forever
            starting_model_checkpoint=starting_model_checkpoint)
        return
    else:
        # pid = subprocess.Popen(['screen', '-d', '-RR' '-d', '-m', 'python', 'scripts/watch_and_validate.py',
        #                         my_trainer.exporter.out_dir,
        #                         '--gpu', '{}'.format(watching_validator_gpu), '--starting_model_checkpoint',
        #                         starting_model_checkpoint], stdout=writer, stderr=subprocess.STDOUT)
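        # Launch the watcher script on the requested GPU inside a detached screen session.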
        cmd = ' '.join([
            'python', 'scripts/watch_and_validate.py',
            my_trainer.exporter.out_dir, '--gpu',
            '{}'.format(watching_validator_gpu), '--starting_model_checkpoint',
            starting_model_checkpoint
        ])
        pid = start_screen_session_with_cmd(cmd)
    misc.color_text(
        'Offloaded validation to GPU {}'.format(watching_validator_gpu),
        color='OKBLUE')
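    # The trainer stops validating itself; it only tracks the watcher's progress from here on.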
    my_trainer.skip_validation = True
    my_trainer.t_val = my_trainer.get_validation_progress_bar()
    try:
        assert my_trainer.t_val is not None
    except AssertionError:
        print("Watcher didn't start.  Error log at {}".format(pidout_filename))
        raise
    return pid, pidout_filename, writer
Example #5
    def train(self):
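        # Convert the iteration budget into a whole number of epochs (rounding up).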
        max_epoch = int(math.ceil(1. * self.state.max_iteration / len(self.dataloaders['train'])))
        if self.t_val is None:
            self.t_val = self.get_validation_progress_bar()

        for epoch in tqdm.trange(self.state.epoch, max_epoch,
                                 desc='Train', ncols=80, leave=False):
            self.state.epoch = epoch
            self.train_epoch()
            if self.state.training_complete():
                self.exporter.save_checkpoint(self.state.epoch, self.state.iteration, self.model, self.optim,
                                              self.best_mean_iu, None)
                break
        if self.t_val is not None:
            self.t_val.close()
            if not self.t_val.finished():
                misc.color_text('Validation is continuing.', color='WARNING')
            else:
                misc.color_text('Validation is complete.', color='OKGREEN')
Example #6
def main(replacement_dict_for_sys_args=None):
    script_utils.check_clean_work_tree()
    args, cfg_override_args = parse_args(replacement_dict_for_sys_args)
    cfg, out_dir, sampler_cfg = configure(dataset_name=args.dataset,
                                          config_idx=args.config,
                                          sampler_name=args.sampler,
                                          script_py_file=__file__,
                                          cfg_override_args=cfg_override_args)
    atexit.register(query_remove_logdir, out_dir)

    trainer_gpu = args.gpu
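    # An empty or missing 'validation_gpu' means validation is not offloaded to a watcher.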
    watchingval_gpu = None if cfg['validation_gpu'] is None or len(cfg['validation_gpu']) == 0 \
        else int(cfg['validation_gpu'])

    if cfg['train_batch_size'] == 1 and len(trainer_gpu) > 1:
        print(
            misc.color_text(
                'Batch size is 1; another GPU won\'t speed things up.  We recommend assigning the other '
                'GPU to validation for speed: --validation_gpu <gpu_num>',
                'WARNING'))

    trainer = setup_train(args.dataset,
                          cfg,
                          out_dir,
                          sampler_cfg,
                          gpu=trainer_gpu,
                          checkpoint_path=args.resume,
                          semantic_init=args.semantic_init)

    if cfg['debug_dataloader_only']:
        n_debug_images = None if cfg['n_debug_images'] is None else int(
            cfg['n_debug_images'])
        debug_helper.debug_dataloader(trainer,
                                      split='train',
                                      n_debug_images=n_debug_images)
        atexit.unregister(query_remove_logdir)
    else:
        metrics = run(trainer, watchingval_gpu)
        # atexit.unregister(query_remove_logdir)
        if metrics is not None:
            print('''\
                Accuracy: {0}
                Accuracy Class: {1}
                Mean IU: {2}
                FWAV Accuracy: {3}'''.format(*metrics))
    return out_dir
Example #7
def map_raw_inst_labels_to_instance_count(inst_lbl,
                                          sem_lbl_for_verification=None):
    """
    Specifically for Cityscapes.
    Warning: inst_lbl must be an int/long for this to work
    """
    orig_inst_lbl = inst_lbl.copy()
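    # Raw ids >= 1000 encode (semantic_class * 1000 + instance_index); smaller values are non-instance pixels.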
    inst_lbl[inst_lbl < 1000] = 0
    sem_lbl_of_objects = np.int32(inst_lbl / 1000)
    for sem_val in np.unique(sem_lbl_of_objects):
        if sem_val == 0:
            continue
        unique_inst_ids = sorted(
            np.unique(orig_inst_lbl[sem_lbl_of_objects == sem_val]))
        if max(unique_inst_ids) != sem_val * 1000 + len(unique_inst_ids) - 1:
            new_consecutive_inst_ids = range(
                1000 * sem_val, 1000 * sem_val + len(unique_inst_ids))
            print(
                misc.color_text(
                    'Instance values were in a weird format! Values present: {}.  Missing: {}'
                    .format(
                        unique_inst_ids,
                        set(new_consecutive_inst_ids) - set(unique_inst_ids)),
                    misc.TermColors.WARNING))
            fast_remap(inst_lbl,
                       old_vals=unique_inst_ids,
                       new_vals=new_consecutive_inst_ids)

    inst_lbl -= np.int32(sem_lbl_of_objects) * np.int32(1000)  # more efficient
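    # Shift so instance ids start from 1; 0 is left for non-instance pixels.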
    inst_lbl[orig_inst_lbl >= 1000] += 1

    if sem_lbl_for_verification is not None:
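        # Sanity-check: the semantic labels recovered from the raw instance ids
        # should match the semantic label map passed in.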
        try:
            sem_lbl_reconstructed = sem_lbl_of_objects
            assert np.all(
                (sem_lbl_reconstructed == 0) == (orig_inst_lbl < 1000))
            sem_lbl_reconstructed[orig_inst_lbl < 1000] = orig_inst_lbl[
                orig_inst_lbl < 1000]
            assert np.all(sem_lbl_reconstructed == sem_lbl_for_verification)
        except AssertionError:
            import ipdb
            ipdb.set_trace()
            raise

    return inst_lbl
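
A minimal usage sketch of the function above, with made-up label values and the optional sem_lbl_for_verification check; these values keep the instance ids consecutive, so the fast_remap/misc warning path is never exercised:

import numpy as np

# Hypothetical raw Cityscapes-style labels: 7 is a non-instance ("stuff") pixel,
# 26000-26002 encode semantic class 26 with instance indices 0, 1, 2.
inst_lbl = np.array([7, 26000, 26001, 26002], dtype=np.int32)
sem_lbl = np.array([7, 26, 26, 26], dtype=np.int32)
remapped = map_raw_inst_labels_to_instance_count(inst_lbl, sem_lbl_for_verification=sem_lbl)
print(remapped)  # expected: [0 1 2 3]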
Example #8
def map_raw_inst_labels_to_instance_count(inst_lbl):
    """
    Specifically for Cityscapes.
    Warning: inst_lbl must be an int/long for this to work
    """
    inst_lbl[inst_lbl < 1000] = 0
    inst_lbl[inst_lbl >= 1000] += 1  # We make instance labels start from 1 instead of 0.
    inst_lbl -= np.int32((inst_lbl / 1000)) * np.int32(1000)  # more efficient
    max_lbl = inst_lbl.max()
    if max_lbl > 0:
        # Check if instance values are consecutive, starting from 1.  If not, shift them all.
        consecutive_instance_values = list(range(1, max_lbl + 1))
        is_present = [
            bool((inst_lbl == val).sum() > 0)
            for val in consecutive_instance_values
        ]

        # Shift if needed
        if not all(is_present):
            old_instance_values = [
                val for val, p in zip(consecutive_instance_values, is_present)
                if p
            ]
            num_present = sum([1 for p in is_present if p])
            new_instance_values = list(range(1, num_present + 1))
            assert num_present < len(
                consecutive_instance_values), AssertionError('debug')
            assert num_present == len(new_instance_values) == len(
                old_instance_values), AssertionError('debug')
            print(
                misc.color_text(
                    'Instance values were in a weird format! Values present: {}'
                    .format(old_instance_values), misc.TermColors.WARNING))
            for old_val, new_val in zip(old_instance_values,
                                        new_instance_values):
                inst_lbl[inst_lbl == old_val] = new_val

    return inst_lbl
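
A rough usage sketch of this variant, with made-up values where instance index 2 of class 26 is missing, to show the ids being compacted; this path prints a "weird format" warning through the repo's misc helper, so misc must be importable alongside the function:

import numpy as np

# Hypothetical raw labels: 26003 skips instance index 2, so the ids get compacted to 1..3.
inst_lbl = np.array([7, 26000, 26001, 26003], dtype=np.int32)
remapped = map_raw_inst_labels_to_instance_count(inst_lbl)
print(remapped)  # expected: [0 1 2 3]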