Esempio n. 1
0
    def __init__(self, split, **kwargs):
        """Load the HDF5 dataset, normalize labels, and build the requested split.

        Args:
            split: dataset split name; 'val' selects a fixed random subset of
                size round(VAL_SPLIT_RATIO * len(data)) under seed 46.
            **kwargs: optional 'img_size' — an int (square) or an (H, W) pair.
        """
        self.data_path = coerce_to_path_and_check_exist(self.root / self.name / HDF5_FILE)
        self.split = split
        data, labels = load_hdf5_file(self.data_path)
        data = data.swapaxes(1, 3)  # NHWC
        if self.transposed:  # swap H and W axes
            data = data.swapaxes(1, 2)
        unique_labels = sorted(np.unique(labels))
        consecutive_labels = (np.diff(unique_labels) == 1).all()
        if not consecutive_labels:
            # Remap the k-th smallest label to k (1-based) in a single
            # vectorized pass. The previous sequential in-place loop
            # (labels[labels == l] = k) could alias classes: with unique
            # labels [0, 1, 5], mapping 0 -> 1 merged it with the genuine
            # label 1 before that one was remapped to 2.
            labels = (np.searchsorted(unique_labels, labels) + 1).astype(labels.dtype)

        if split == 'val':
            n_val = round(VAL_SPLIT_RATIO * len(data))
            with use_seed(46):
                indices = np.random.choice(range(len(data)), n_val, replace=False)
            data, labels = data[indices], labels[indices]
        self.data, self.labels = data, labels
        self.size = len(self.labels)

        img_size = kwargs.get('img_size')
        if img_size is not None:
            self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size
            assert len(self.img_size) == 2
Esempio n. 2
0
    def __init__(self, split, img_size, tag, **kwargs):
        """Index the image files of a split/tag folder and validate img_size.

        An int img_size means a square target obtained by cropping; a pair
        means an exact (H, W) resize without cropping.
        """
        base_dir = coerce_to_path_and_check_exist(self.root / self.name / tag)
        self.data_path = base_dir / split
        self.split = split
        self.tag = tag
        try:
            files = get_files_from_dir(self.data_path,
                                       IMG_EXTENSIONS,
                                       sort=True)
        except FileNotFoundError:
            # Missing split directory: behave as an empty dataset.
            files = []
        self.input_files = files
        self.labels = [-1 for _ in files]  # unlabeled by convention
        self.n_classes = 0
        self.size = len(files)

        if isinstance(img_size, int):
            self.img_size = (img_size, img_size)
            self.crop = True
        else:
            assert len(img_size) == 2
            self.img_size = img_size
            self.crop = False

        if self.size > 0:
            # Sanity-check the target size against one sample image.
            sample_w_h = Image.open(files[0]).size
            if min(self.img_size) > min(sample_w_h):
                raise ValueError(
                    "img_size too big compared to a sampled image size, adjust it or upscale dataset"
                )
 def __init__(self, input_dir, output_dir, json_file='via_region_data.json', out_ext='png', color=ILLUSTRATION_COLOR,
              verbose=True):
     """Prepare annotation export: resolve paths and pick the drawing mode."""
     self.input_dir = coerce_to_path_and_check_exist(input_dir)
     self.annotations = self.load_json(self.input_dir / json_file)
     self.output_dir = coerce_to_path_and_create_dir(output_dir)
     self.out_ext = out_ext
     self.color = color
     # A single int selects grayscale rendering; anything else is RGB.
     grayscale = isinstance(color, int)
     self.mode = 'L' if grayscale else 'RGB'
     self.background_color = 0 if grayscale else (0, 0, 0)
     self.verbose = verbose
Esempio n. 4
0
def get_logger(log_dir, name):
    """Create (or retrieve) a logger writing INFO+ records to ``log_dir/<name>.log``.

    Args:
        log_dir: existing directory where the log file is written.
        name: logger name, also used as the log file stem.

    Returns:
        logging.Logger: the configured logger.
    """
    log_dir = coerce_to_path_and_check_exist(log_dir)
    logger = logging.getLogger(name)
    # logging.getLogger returns a process-wide singleton per name: without
    # this guard, calling get_logger twice with the same name would attach
    # two FileHandlers and emit every record twice.
    if not logger.handlers:
        file_path = log_dir / "{}.log".format(name)
        hdlr = logging.FileHandler(file_path)
        formatter = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
        hdlr.setFormatter(formatter)
        logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)
    return logger
Esempio n. 5
0
 def __init__(self,
              input_dir,
              output_dir,
              color_label_mapping=COLOR_TO_LABEL_MAPPING,
              img_extension='png',
              verbose=True):
     """Collect the input images and prepare the output directory."""
     source = coerce_to_path_and_check_exist(input_dir)
     self.input_dir = source
     self.files = get_files_from_dir(source, valid_extensions=img_extension)
     self.output_dir = coerce_to_path_and_create_dir(output_dir)
     self.color_label_mapping = color_label_mapping
     self.verbose = verbose
Esempio n. 6
0
 def __init__(self,
              input_dir,
              output_dir,
              suffix_fmt='-{}',
              out_ext='jpg',
              create_sub_dir=False,
              verbose=True):
     """Gather the PDF files to convert and record the output options."""
     self.input_dir = coerce_to_path_and_check_exist(input_dir)
     self.files = get_files_from_dir(self.input_dir, valid_extensions='pdf')
     self.output_dir = coerce_to_path_and_create_dir(output_dir)
     self.suffix_fmt = suffix_fmt
     self.out_ext = out_ext
     self.create_sub_dir = create_sub_dir
     self.verbose = verbose
     if verbose:
         print_info("Pdf2Image initialised: found {} files".format(len(self.files)))
Esempio n. 7
0
def load_model_from_path(model_path,
                         dataset,
                         device=None,
                         attributes_to_return=None):
    """Rebuild a model from a checkpoint file and load its weights.

    When ``attributes_to_return`` is given (a key or list of keys), the
    corresponding checkpoint entries are returned alongside the model.
    """
    if device is None:
        use_cuda = torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
    path = coerce_to_path_and_check_exist(model_path)
    checkpoint = torch.load(path, map_location=device.type)
    model_cls = get_model(checkpoint['model_name'])
    model = model_cls(dataset, **checkpoint['model_kwargs']).to(device)
    model.load_state_dict(safe_model_state_dict(checkpoint['model_state']))
    if attributes_to_return is None:
        return model
    keys = ([attributes_to_return] if isinstance(attributes_to_return, str)
            else attributes_to_return)
    return model, [checkpoint.get(key) for key in keys]
Esempio n. 8
0
 def load_from_tag(self, tag, resume=False):
     """Restore model weights (and optionally training state) from run `tag`.

     With resume=False only the model weights are loaded; with resume=True
     the optimizer, scheduler and epoch/batch counters are restored too.
     """
     self.print_and_log_info("Loading model from run {}".format(tag))
     path = coerce_to_path_and_check_exist(MODELS_PATH / tag / MODEL_FILE)
     checkpoint = torch.load(path, map_location=self.device)
     try:
         self.model.load_state_dict(checkpoint["model_state"])
     except RuntimeError:
         # Fallback: presumably the checkpoint keys don't match the wrapped
         # model (e.g. a DataParallel 'module.' prefix mismatch) and
         # safe_model_state_dict normalizes them — verify against the helper.
         state = safe_model_state_dict(checkpoint["model_state"])
         self.model.module.load_state_dict(state)
     self.start_epoch, self.start_batch = 1, 1
     if resume:
         # Resume one batch after the saved position; a missing "batch"
         # entry falls back to 0, i.e. restart the epoch at batch 1.
         self.start_epoch, self.start_batch = checkpoint[
             "epoch"], checkpoint.get("batch", 0) + 1
         self.optimizer.load_state_dict(checkpoint["optimizer_state"])
         self.scheduler.load_state_dict(checkpoint["scheduler_state"])
         # NOTE(review): scheduler.get_lr() is deprecated in recent PyTorch
         # in favour of get_last_lr() — confirm the installed version.
         self.cur_lr = self.scheduler.get_lr()
     self.print_and_log_info(
         "Checkpoint loaded at epoch {}, batch {}".format(
             self.start_epoch, self.start_batch - 1))
Esempio n. 9
0
def load_model_from_path(model_path,
                         device=None,
                         attributes_to_return=None,
                         eval_mode=True):
    """Instantiate a segmentation model from a checkpoint and load its weights.

    ``pretrained_encoder`` is forced off because the trained encoder weights
    already live in the checkpoint. Extra checkpoint entries listed in
    ``attributes_to_return`` are returned alongside the model.
    """
    if device is None:
        has_cuda = torch.cuda.is_available()
        device = torch.device('cuda' if has_cuda else 'cpu')
    checkpoint = torch.load(coerce_to_path_and_check_exist(model_path),
                            map_location=device.type)
    # Encoder weights come from the checkpoint itself: skip the download.
    checkpoint['model_kwargs']['pretrained_encoder'] = False
    model_factory = get_model(checkpoint['model_name'])
    model = model_factory(checkpoint['n_classes'], **checkpoint['model_kwargs'])
    model = model.to(device)
    model.load_state_dict(safe_model_state_dict(checkpoint['model_state']))
    if eval_mode:
        model.eval()
    if attributes_to_return is None:
        return model
    if isinstance(attributes_to_return, str):
        attributes_to_return = [attributes_to_return]
    return model, [checkpoint.get(key) for key in attributes_to_return]
Esempio n. 10
0
 def __init__(self, split, restricted_labels, **kwargs):
     """Set up a segmentation split restricted to a subset of labels.

     Label indices are 1-based: index 0 is implicitly the background class,
     hence n_classes = len(restricted_labels) + 1.
     """
     self.data_path = coerce_to_path_and_check_exist(self.root_path) / split
     self.split = split
     self.input_files, self.label_files = self._get_input_label_files()
     self.size = len(self.input_files)
     self.restricted_labels = sorted(restricted_labels)
     self.restricted_colors = [LABEL_TO_COLOR_MAPPING[lbl] for lbl in self.restricted_labels]
     self.label_idx_color_mapping = {self.restricted_labels.index(lbl) + 1: col
                                     for lbl, col in zip(self.restricted_labels,
                                                         self.restricted_colors)}
     self.color_label_idx_mapping = {col: idx
                                     for idx, col in self.label_idx_color_mapping.items()}
     self.fill_background = BACKGROUND_LABEL in self.restricted_labels
     self.n_classes = len(self.restricted_labels) + 1
     self.img_size = kwargs.get('img_size')
     self.sampling_max_nb_pixels = kwargs.get('sampling_max_nb_pixels')
     # Data augmentation only ever applies to the training split.
     self.data_augmentation = kwargs.get('data_augmentation', True) and split == 'train'
     # Remaining options share the pattern "kwargs value or project default".
     for option, default in [
         ('keep_aspect_ratio', True),
         ('baseline_dilation_iter', 1),
         ('normalize', True),
         ('blur_radius_range', BLUR_RADIUS_RANGE),
         ('brightness_factor_range', BRIGHTNESS_FACTOR_RANGE),
         ('contrast_factor_range', CONTRAST_FACTOR_RANGE),
         ('rotation_angle_range', ROTATION_ANGLE_RANGE),
         ('sampling_ratio_range', SAMPLING_RATIO_RANGE),
         ('transposition_weights', TRANPOSITION_WEIGHTS),
     ]:
         setattr(self, option, kwargs.get(option, default))
Esempio n. 11
0
 def load_json(json_file):
     """Read a JSON file and return the decoded object."""
     path = coerce_to_path_and_check_exist(json_file)
     with open(path) as handle:
         return json.load(handle)
Esempio n. 12
0
    def __init__(self,
                 input_dir,
                 output_dir,
                 tag="default",
                 seg_fmt=SEG_GROUND_TRUTH_FMT,
                 labels_to_eval=None,
                 save_annotations=True,
                 labels_to_annot=None,
                 predict_bbox=False,
                 verbose=True):
        """Set up a segmentation evaluator over all images found in `input_dir`.

        Loads the trained model identified by `tag`, restricts evaluation to
        `labels_to_eval` (defaults to illustrations only), and prepares
        per-file running metrics plus optional annotation saving.
        """
        self.input_dir = coerce_to_path_and_check_exist(input_dir).absolute()
        self.files = get_files_from_dir(self.input_dir,
                                        valid_extensions=VALID_EXTENSIONS,
                                        recursive=True,
                                        sort=True)
        self.output_dir = coerce_to_path_and_create_dir(output_dir).absolute()
        self.seg_fmt = seg_fmt
        self.logger = get_logger(self.output_dir, name='evaluator')
        model_path = coerce_to_path_and_check_exist(MODELS_PATH / tag /
                                                    MODEL_FILE)
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # Recover the training-time settings stored inside the checkpoint so
        # inference preprocessing matches training exactly.
        self.model, (self.img_size, restricted_labels,
                     self.normalize) = load_model_from_path(
                         model_path,
                         device=self.device,
                         attributes_to_return=[
                             'train_resolution', 'restricted_labels',
                             'normalize'
                         ])
        self.model.eval()

        self.restricted_labels = sorted(restricted_labels)
        self.labels_to_eval = [
            ILLUSTRATION_LABEL
        ] if labels_to_eval is None else sorted(labels_to_eval)
        # Labels the model knows about but that are excluded from evaluation.
        self.labels_to_rm = set(self.restricted_labels).difference(
            self.labels_to_eval)
        # Every label to evaluate must be one the model was trained on.
        assert len(
            set(self.labels_to_eval).intersection(
                self.restricted_labels)) == len(self.labels_to_eval)

        # 1-based label indices (0 is background) and their colors, plus the
        # inverse color -> index lookup used when reading annotation images.
        self.restricted_colors = [
            LABEL_TO_COLOR_MAPPING[l] for l in self.restricted_labels
        ]
        self.label_idx_color_mapping = {
            self.restricted_labels.index(l) + 1: c
            for l, c in zip(self.restricted_labels, self.restricted_colors)
        }
        self.color_label_idx_mapping = {
            c: l
            for l, c in self.label_idx_color_mapping.items()
        }

        # One RunningMetrics per key, created lazily on first access.
        self.metrics = defaultdict(lambda: RunningMetrics(
            self.restricted_labels, self.labels_to_eval))
        self.save_annotations = save_annotations
        self.labels_to_annot = labels_to_annot or self.labels_to_eval
        self.predict_bbox = predict_bbox
        self.verbose = verbose

        self.print_and_log_info('Output dir: {}'.format(
            self.output_dir.absolute()))
        self.print_and_log_info('Evaluator initialised with kwargs {}'.format({
            'labels_to_eval':
            self.labels_to_eval,
            'save_annotations':
            save_annotations
        }))
        self.print_and_log_info('Model tag: {}'.format(model_path.parent.name))
        self.print_and_log_info(
            'Model characteristics: train_resolution={}, restricted_labels={}'.
            format(self.img_size, self.restricted_labels))
        self.print_and_log_info('Found {} input files to process'.format(
            len(self.files)))
Esempio n. 13
0
                        nargs='+',
                        type=int,
                        default=[1],
                        help='Labels to eval')
    parser.add_argument('-s',
                        '--save_annot',
                        action='store_true',
                        help='Whether to save annotations')
    parser.add_argument('-lta',
                        '--labels_to_annot',
                        nargs='+',
                        type=int,
                        default=None,
                        help='Labels to annotate')
    parser.add_argument('-b',
                        '--pred_bbox',
                        action='store_true',
                        help='Whether to predict bounding boxes')
    args = parser.parse_args()

    input_dir = coerce_to_path_and_check_exist(args.input_dir)
    # Annotation saving is forced on whenever explicit labels_to_annot are
    # passed, regardless of the --save_annot flag.
    evaluator = Evaluator(input_dir,
                          args.output_dir,
                          tag=args.tag,
                          labels_to_eval=args.labels,
                          save_annotations=args.save_annot
                          if args.labels_to_annot is None else True,
                          labels_to_annot=args.labels_to_annot,
                          predict_bbox=args.pred_bbox)
    evaluator.run()
Esempio n. 14
0
 def save(self, prefix_name, output_dir):
     """Write images and then labels under `output_dir` using `prefix_name`."""
     target = coerce_to_path_and_check_exist(output_dir)
     self._save_images(prefix_name, target)
     self._save_labels(prefix_name, target)
Esempio n. 15
0
                    "\t".join(map("{:.4f}".format, scores.values())) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Pipeline to train a NN model specified by a YML config")
    parser.add_argument("-t",
                        "--tag",
                        nargs="?",
                        type=str,
                        required=True,
                        help="Run tag of the experiment")
    parser.add_argument("-c",
                        "--config",
                        nargs="?",
                        type=str,
                        required=True,
                        help="Config file name")
    args = parser.parse_args()

    # NOTE(review): both flags are required=True, so this assert is redundant
    # defensive checking — argparse already rejects missing values.
    assert args.tag is not None and args.config is not None
    config = coerce_to_path_and_check_exist(CONFIGS_PATH / args.config)
    with open(config) as fp:
        cfg = yaml.load(fp, Loader=yaml.FullLoader)
    seed = cfg["training"].get("seed", 4321)
    dataset = cfg["dataset"]["name"]
    run_dir = RUNS_PATH / dataset / args.tag

    # NOTE(review): Trainer is constructed with a seed kwarg here — confirm
    # the Trainer variant in use accepts it (some __init__ signatures take
    # only (config_path, run_dir)).
    trainer = Trainer(config, run_dir, seed=seed)
    trainer.run(seed=seed)
Esempio n. 16
0
    def __init__(self, config_path, run_dir):
        """Build the full supervised training pipeline from a YAML config.

        Sets up, in order: logging, device, datasets/dataloaders, model
        (wrapped in DataParallel), optimizer, scheduler, loss, checkpoint
        loading (pretrained or resume), and train/val metric files.
        """
        self.config_path = coerce_to_path_and_check_exist(config_path)
        self.run_dir = coerce_to_path_and_create_dir(run_dir)
        self.logger = get_logger(self.run_dir, name="trainer")
        self.print_and_log_info(
            "Trainer initialisation: run directory is {}".format(run_dir))

        # Keep a copy of the config alongside the run outputs for provenance.
        shutil.copy(self.config_path, self.run_dir)
        self.print_and_log_info("Config {} copied to run directory".format(
            self.config_path))

        with open(self.config_path) as fp:
            cfg = yaml.load(fp, Loader=yaml.FullLoader)

        if torch.cuda.is_available():
            type_device = "cuda"
            nb_device = torch.cuda.device_count()
            # XXX: set to False when input image sizes are not fixed
            torch.backends.cudnn.benchmark = cfg["training"].get(
                "cudnn_benchmark", True)

        else:
            type_device = "cpu"
            nb_device = None
        self.device = torch.device(type_device)
        self.print_and_log_info("Using {} device, nb_device is {}".format(
            type_device, nb_device))

        # Datasets and dataloaders
        self.dataset_kwargs = cfg["dataset"]
        self.dataset_name = self.dataset_kwargs.pop("name")
        train_dataset = get_dataset(self.dataset_name)("train",
                                                       **self.dataset_kwargs)
        val_dataset = get_dataset(self.dataset_name)("val",
                                                     **self.dataset_kwargs)
        # +1 accounts for the implicit background class (index 0).
        self.restricted_labels = sorted(
            self.dataset_kwargs["restricted_labels"])
        self.n_classes = len(self.restricted_labels) + 1
        self.is_val_empty = len(val_dataset) == 0
        self.print_and_log_info("Dataset {} instantiated with {}".format(
            self.dataset_name, self.dataset_kwargs))
        self.print_and_log_info(
            "Found {} classes, {} train samples, {} val samples".format(
                self.n_classes, len(train_dataset), len(val_dataset)))

        self.batch_size = cfg["training"]["batch_size"]
        self.n_workers = cfg["training"]["n_workers"]
        self.train_loader = DataLoader(train_dataset,
                                       batch_size=self.batch_size,
                                       num_workers=self.n_workers,
                                       shuffle=True)
        self.val_loader = DataLoader(val_dataset,
                                     batch_size=self.batch_size,
                                     num_workers=self.n_workers)
        self.print_and_log_info(
            "Dataloaders instantiated with batch_size={} and n_workers={}".
            format(self.batch_size, self.n_workers))

        # Exactly one of n_iterations / n_epoches may be given; the other is
        # derived from the number of batches per epoch.
        self.n_batches = len(self.train_loader)
        self.n_iterations, self.n_epoches = cfg["training"].get(
            "n_iterations"), cfg["training"].get("n_epoches")
        assert not (self.n_iterations is not None
                    and self.n_epoches is not None)
        if self.n_iterations is not None:
            self.n_epoches = max(self.n_iterations // self.n_batches, 1)
        else:
            self.n_iterations = self.n_epoches * len(self.train_loader)

        # Model
        self.model_kwargs = cfg["model"]
        self.model_name = self.model_kwargs.pop("name")
        model = get_model(self.model_name)(self.n_classes,
                                           **self.model_kwargs).to(self.device)
        self.model = torch.nn.DataParallel(model,
                                           device_ids=range(
                                               torch.cuda.device_count()))
        self.print_and_log_info("Using model {} with kwargs {}".format(
            self.model_name, self.model_kwargs))
        self.print_and_log_info('Number of trainable parameters: {}'.format(
            f'{count_parameters(self.model):,}'))

        # Optimizer
        optimizer_params = cfg["training"]["optimizer"] or {}
        optimizer_name = optimizer_params.pop("name", None)
        self.optimizer = get_optimizer(optimizer_name)(model.parameters(),
                                                       **optimizer_params)
        self.print_and_log_info("Using optimizer {} with kwargs {}".format(
            optimizer_name, optimizer_params))

        # Scheduler
        scheduler_params = cfg["training"].get("scheduler", {}) or {}
        scheduler_name = scheduler_params.pop("name", None)
        self.scheduler_update_range = scheduler_params.pop(
            "update_range", "epoch")
        assert self.scheduler_update_range in ["epoch", "batch"]
        # Float milestones are interpreted as fractions of the total number
        # of scheduler steps and converted to absolute step indices.
        if scheduler_name == "multi_step" and isinstance(
                scheduler_params["milestones"][0], float):
            n_tot = self.n_epoches if self.scheduler_update_range == "epoch" else self.n_iterations
            scheduler_params["milestones"] = [
                round(m * n_tot) for m in scheduler_params["milestones"]
            ]
        self.scheduler = get_scheduler(scheduler_name)(self.optimizer,
                                                       **scheduler_params)
        self.cur_lr = -1  # sentinel: real lr is set on the first update
        self.print_and_log_info("Using scheduler {} with parameters {}".format(
            scheduler_name, scheduler_params))

        # Loss
        loss_name = cfg["training"]["loss"]
        self.criterion = get_loss(loss_name)()
        self.print_and_log_info("Using loss {}".format(self.criterion))

        # Pretrained / Resume: mutually exclusive checkpoint options.
        checkpoint_path = cfg["training"].get("pretrained")
        checkpoint_path_resume = cfg["training"].get("resume")
        assert not (checkpoint_path is not None
                    and checkpoint_path_resume is not None)
        if checkpoint_path is not None:
            self.load_from_tag(checkpoint_path)
        elif checkpoint_path_resume is not None:
            self.load_from_tag(checkpoint_path_resume, resume=True)
        else:
            self.start_epoch, self.start_batch = 1, 1

        # Train metrics
        # NOTE(review): the default intervals below can be 0 for very short
        # runs (n_epoches * n_batches < 200) — confirm downstream code
        # guards against a zero interval.
        train_iter_interval = cfg["training"].get(
            "train_stat_interval", self.n_epoches * self.n_batches // 200)
        self.train_stat_interval = train_iter_interval
        self.train_time = AverageMeter()
        self.train_loss = AverageMeter()
        self.train_metrics_path = self.run_dir / TRAIN_METRICS_FILE
        with open(self.train_metrics_path, mode="w") as f:
            f.write(
                "iteration\tepoch\tbatch\ttrain_loss\ttrain_time_per_img\n")

        # Val metrics
        val_iter_interval = cfg["training"].get(
            "val_stat_interval", self.n_epoches * self.n_batches // 100)
        self.val_stat_interval = val_iter_interval
        self.val_loss = AverageMeter()
        self.val_metrics = RunningMetrics(self.restricted_labels)
        self.val_current_score = None
        self.val_metrics_path = self.run_dir / VAL_METRICS_FILE
        with open(self.val_metrics_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\tval_loss\t" +
                    "\t".join(self.val_metrics.names) + "\n")
Esempio n. 17
0
 def convert(filename, dpi=100):
     """Rasterize a PDF into a list of JPEG page images at the given dpi."""
     pdf_path = coerce_to_path_and_check_exist(filename)
     return convert_from_path(pdf_path, dpi=dpi, use_cropbox=True, fmt='jpg')
Esempio n. 18
0
 def load_json(json_file):
     """Parse an annotation JSON file, unwrapping '_via_img_metadata' if present."""
     path = coerce_to_path_and_check_exist(json_file)
     with open(path) as handle:
         payload = json.load(handle)
     # Some VIA exports nest the annotations under '_via_img_metadata';
     # fall back to the whole document otherwise.
     return payload.get('_via_img_metadata', payload)
Esempio n. 19
0
 def __init__(self, input_dir=SYNTHETIC_RESRC_PATH):
     """Resolve the resource directory and build the lookup table."""
     resource_root = coerce_to_path_and_check_exist(input_dir)
     self.input_dir = resource_root
     self.table = self._initialize_table()
Esempio n. 20
0
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Pipeline to test a NN model on the test split of a dataset"
    )
    parser.add_argument("-t",
                        "--tag",
                        nargs="?",
                        type=str,
                        help="Model tag to test",
                        required=True)
    parser.add_argument("-d",
                        "--dataset",
                        nargs="?",
                        type=str,
                        default="syndoc",
                        help="Name of the dataset to test")
    args = parser.parse_args()

    run_dir = coerce_to_path_and_check_exist(MODELS_PATH / args.tag)
    output_dir = run_dir / "test_{}".format(args.dataset)
    # Reuse the dataset kwargs saved with the training run's config.
    # NOTE(review): [0] raises IndexError if the run dir holds no *.yml —
    # acceptable for a CLI, but the error message is unhelpful.
    config_path = list(run_dir.glob("*.yml"))[0]
    with open(config_path) as fp:
        cfg = yaml.load(fp, Loader=yaml.FullLoader)
    dataset_kwargs = cfg["dataset"]
    dataset_kwargs.pop("name")

    tester = Tester(output_dir, run_dir / MODEL_FILE, args.dataset,
                    dataset_kwargs)
    tester.run()
Esempio n. 21
0
    def __init__(self, config_path, run_dir):
        """Build the prototype/clustering training pipeline from a YAML config.

        Sets up, in order: logging, device, datasets/dataloaders, model
        (with separate cluster/transformer optimizer parameter groups),
        scheduler, checkpoint loading, metric/score files, prototype and
        transformation output directories, and an optional Visdom visualizer.
        """
        self.config_path = coerce_to_path_and_check_exist(config_path)
        self.run_dir = coerce_to_path_and_create_dir(run_dir)
        self.logger = get_logger(self.run_dir, name="trainer")
        self.print_and_log_info(
            "Trainer initialisation: run directory is {}".format(run_dir))

        # Keep a copy of the config alongside the run outputs for provenance.
        shutil.copy(self.config_path, self.run_dir)
        self.print_and_log_info("Config {} copied to run directory".format(
            self.config_path))

        with open(self.config_path) as fp:
            cfg = yaml.load(fp, Loader=yaml.FullLoader)

        if torch.cuda.is_available():
            type_device = "cuda"
            nb_device = torch.cuda.device_count()
        else:
            type_device = "cpu"
            nb_device = None
        self.device = torch.device(type_device)
        self.print_and_log_info("Using {} device, nb_device is {}".format(
            type_device, nb_device))

        # Datasets and dataloaders
        self.dataset_kwargs = cfg["dataset"]
        self.dataset_name = self.dataset_kwargs.pop("name")
        train_dataset = get_dataset(self.dataset_name)("train",
                                                       **self.dataset_kwargs)
        val_dataset = get_dataset(self.dataset_name)("val",
                                                     **self.dataset_kwargs)
        self.n_classes = train_dataset.n_classes
        self.is_val_empty = len(val_dataset) == 0
        self.print_and_log_info("Dataset {} instantiated with {}".format(
            self.dataset_name, self.dataset_kwargs))
        self.print_and_log_info(
            "Found {} classes, {} train samples, {} val samples".format(
                self.n_classes, len(train_dataset), len(val_dataset)))

        self.img_size = train_dataset.img_size
        self.batch_size = cfg["training"]["batch_size"]
        self.n_workers = cfg["training"].get("n_workers", 4)
        self.train_loader = DataLoader(train_dataset,
                                       batch_size=self.batch_size,
                                       num_workers=self.n_workers,
                                       shuffle=True)
        self.val_loader = DataLoader(val_dataset,
                                     batch_size=self.batch_size,
                                     num_workers=self.n_workers)
        self.print_and_log_info(
            "Dataloaders instantiated with batch_size={} and n_workers={}".
            format(self.batch_size, self.n_workers))

        # Exactly one of n_iterations / n_epoches may be given; the other is
        # derived from the number of batches per epoch.
        self.n_batches = len(self.train_loader)
        self.n_iterations, self.n_epoches = cfg["training"].get(
            "n_iterations"), cfg["training"].get("n_epoches")
        assert not (self.n_iterations is not None
                    and self.n_epoches is not None)
        if self.n_iterations is not None:
            self.n_epoches = max(self.n_iterations // self.n_batches, 1)
        else:
            self.n_iterations = self.n_epoches * len(self.train_loader)

        # Model
        self.model_kwargs = cfg["model"]
        self.model_name = self.model_kwargs.pop("name")
        self.is_gmm = 'gmm' in self.model_name
        self.model = get_model(self.model_name)(
            self.train_loader.dataset, **self.model_kwargs).to(self.device)
        self.print_and_log_info("Using model {} with kwargs {}".format(
            self.model_name, self.model_kwargs))
        self.print_and_log_info('Number of trainable parameters: {}'.format(
            f'{count_parameters(self.model):,}'))
        self.n_prototypes = self.model.n_prototypes

        # Optimizer: cluster and transformer parameters get their own
        # optional kwargs (e.g. different learning rates) on top of the
        # shared optimizer settings.
        opt_params = cfg["training"]["optimizer"] or {}
        optimizer_name = opt_params.pop("name")
        cluster_kwargs = opt_params.pop('cluster', {})
        tsf_kwargs = opt_params.pop('transformer', {})
        self.optimizer = get_optimizer(optimizer_name)([
            dict(params=self.model.cluster_parameters(), **cluster_kwargs),
            dict(params=self.model.transformer_parameters(), **tsf_kwargs)
        ], **opt_params)
        self.model.set_optimizer(self.optimizer)
        self.print_and_log_info("Using optimizer {} with kwargs {}".format(
            optimizer_name, opt_params))
        self.print_and_log_info("cluster kwargs {}".format(cluster_kwargs))
        self.print_and_log_info("transformer kwargs {}".format(tsf_kwargs))

        # Scheduler
        scheduler_params = cfg["training"].get("scheduler", {}) or {}
        scheduler_name = scheduler_params.pop("name", None)
        self.scheduler_update_range = scheduler_params.pop(
            "update_range", "epoch")
        assert self.scheduler_update_range in ["epoch", "batch"]
        # Float milestones are interpreted as fractions of the total number
        # of scheduler steps and converted to absolute step indices.
        if scheduler_name == "multi_step" and isinstance(
                scheduler_params["milestones"][0], float):
            n_tot = self.n_epoches if self.scheduler_update_range == "epoch" else self.n_iterations
            scheduler_params["milestones"] = [
                round(m * n_tot) for m in scheduler_params["milestones"]
            ]
        self.scheduler = get_scheduler(scheduler_name)(self.optimizer,
                                                       **scheduler_params)
        self.cur_lr = self.scheduler.get_last_lr()[0]
        self.print_and_log_info("Using scheduler {} with parameters {}".format(
            scheduler_name, scheduler_params))

        # Pretrained / Resume: mutually exclusive checkpoint options.
        checkpoint_path = cfg["training"].get("pretrained")
        checkpoint_path_resume = cfg["training"].get("resume")
        assert not (checkpoint_path is not None
                    and checkpoint_path_resume is not None)
        if checkpoint_path is not None:
            self.load_from_tag(checkpoint_path)
        elif checkpoint_path_resume is not None:
            self.load_from_tag(checkpoint_path_resume, resume=True)
        else:
            self.start_epoch, self.start_batch = 1, 1

        # Train metrics & check_cluster interval: one proportion metric per
        # prototype tracks cluster assignment balance.
        metric_names = ['time/img', 'loss']
        metric_names += [f'prop_clus{i}' for i in range(self.n_prototypes)]
        train_iter_interval = cfg["training"]["train_stat_interval"]
        self.train_stat_interval = train_iter_interval
        self.train_metrics = Metrics(*metric_names)
        self.train_metrics_path = self.run_dir / TRAIN_METRICS_FILE
        with open(self.train_metrics_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\t" +
                    "\t".join(self.train_metrics.names) + "\n")
        self.check_cluster_interval = cfg["training"]["check_cluster_interval"]

        # Val metrics & scores
        val_iter_interval = cfg["training"]["val_stat_interval"]
        self.val_stat_interval = val_iter_interval
        self.val_metrics = Metrics('loss_val')
        self.val_metrics_path = self.run_dir / VAL_METRICS_FILE
        with open(self.val_metrics_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\t" +
                    "\t".join(self.val_metrics.names) + "\n")

        self.val_scores = Scores(self.n_classes, self.n_prototypes)
        self.val_scores_path = self.run_dir / VAL_SCORES_FILE
        with open(self.val_scores_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\t" +
                    "\t".join(self.val_scores.names) + "\n")

        # Prototypes & Variances: one output sub-directory per prototype
        # (and per variance map for GMM models).
        self.prototypes_path = coerce_to_path_and_create_dir(self.run_dir /
                                                             'prototypes')
        [
            coerce_to_path_and_create_dir(self.prototypes_path / f'proto{k}')
            for k in range(self.n_prototypes)
        ]
        if self.is_gmm:
            self.variances_path = coerce_to_path_and_create_dir(self.run_dir /
                                                                'variances')
            [
                coerce_to_path_and_create_dir(self.variances_path / f'var{k}')
                for k in range(self.n_prototypes)
            ]

        # Transformation predictions: a fixed batch of training images whose
        # per-prototype transformations are rendered during training.
        self.transformation_path = coerce_to_path_and_create_dir(
            self.run_dir / 'transformations')
        self.images_to_tsf = next(iter(
            self.train_loader))[0][:N_TRANSFORMATION_PREDICTIONS].to(
                self.device)
        for k in range(self.images_to_tsf.size(0)):
            out = coerce_to_path_and_create_dir(self.transformation_path /
                                                f'img{k}')
            convert_to_img(self.images_to_tsf[k]).save(out / 'input.png')
            [
                coerce_to_path_and_create_dir(out / f'tsf{k}')
                for k in range(self.n_prototypes)
            ]

        # Visdom
        viz_port = cfg["training"].get("visualizer_port")
        if viz_port is not None:
            from visdom import Visdom
            os.environ["http_proxy"] = ""
            self.visualizer = Visdom(
                port=viz_port,
                env=f'{self.run_dir.parent.name}_{self.run_dir.name}')
            self.visualizer.delete_env(
                self.visualizer.env)  # Clean env before plotting
            self.print_and_log_info(f"Visualizer initialised at {viz_port}")
        else:
            self.visualizer = None
            self.print_and_log_info("No visualizer initialized")
Esempio n. 22
0
    def __init__(self,
                 input_dir,
                 output_dir,
                 labels_to_extract=None,
                 in_ext=VALID_EXTENSIONS,
                 out_ext='jpg',
                 tag='default',
                 save_annotations=True,
                 straight_bbox=False,
                 add_margin=True,
                 draw_margin=False,
                 verbose=True):
        """Initialise the extractor: collect input files, load the trained
        segmentation model identified by `tag` and validate the labels to
        extract against the labels the model was trained on.

        Args:
            input_dir: directory scanned recursively for files with `in_ext`
                extensions; must exist.
            output_dir: output directory, created if missing; also receives
                the log file.
            labels_to_extract: iterable of label ids to extract; defaults to
                [1, 4] when None. Must be a subset of the model's labels.
            in_ext: accepted input file extensions.
            out_ext: extension used when saving extracted outputs.
            tag: model sub-directory name under MODELS_PATH.
            save_annotations: whether annotation files are saved.
            straight_bbox: use axis-aligned bounding boxes.
            add_margin: add a margin around extracted regions.
            draw_margin: draw the margin; only effective when `add_margin`.
            verbose: enable verbose logging.

        Raises:
            ValueError: if `labels_to_extract` is not a subset of the labels
                the model was trained with.
        """
        self.input_dir = coerce_to_path_and_check_exist(input_dir).absolute()
        self.files = get_files_from_dir(self.input_dir,
                                        valid_extensions=in_ext,
                                        recursive=True,
                                        sort=True)
        self.output_dir = coerce_to_path_and_create_dir(output_dir).absolute()
        self.out_extension = out_ext
        self.logger = get_logger(self.output_dir, name='extractor')

        # Load the model together with the attributes needed to preprocess
        # inputs consistently with how it was trained.
        model_path = coerce_to_path_and_check_exist(MODELS_PATH / tag /
                                                    MODEL_FILE)
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model, (self.img_size, restricted_labels,
                     self.normalize) = load_model_from_path(
                         model_path,
                         device=self.device,
                         attributes_to_return=[
                             'train_resolution', 'restricted_labels',
                             'normalize'
                         ])
        self.model.eval()

        self.restricted_labels = sorted(restricted_labels)
        # NOTE(review): the [1, 4] default presumably targets the label ids
        # of interest for extraction — confirm against the label definitions.
        if labels_to_extract is None:
            self.labels_to_extract = [1, 4]
        else:
            self.labels_to_extract = sorted(labels_to_extract)
        if not set(self.labels_to_extract).issubset(self.restricted_labels):
            raise ValueError(
                'Incompatible `labels_to_extract` and `tag` arguments: '
                f'model was trained using {self.restricted_labels} labels only'
            )

        self.save_annotations = save_annotations
        self.straight_bbox = straight_bbox
        self.add_margin = add_margin
        # Drawing a margin only makes sense when one is actually added.
        self.draw_margin = add_margin and draw_margin
        self.verbose = verbose

        # Key order matters: the dict repr is logged and should stay stable.
        init_kwargs = {
            'tag': tag,
            'labels_to_extract': self.labels_to_extract,
            'save_annotations': save_annotations,
            'straight_bbox': straight_bbox,
            'add_margin': add_margin,
            'draw_margin': draw_margin,
        }
        self.print_and_log_info(
            f'Extractor initialised with kwargs {init_kwargs}')
        self.print_and_log_info(
            f'Model characteristics: train_resolution={self.img_size}, '
            f'restricted_labels={self.restricted_labels}')
        self.print_and_log_info(
            f'Found {len(self.files)} input files to process')