def __enter__(self):
    """Enter the context by creating a scratch area for sample images.

    A ``tempfile.TemporaryDirectory`` is created under ``self.tmp_dir_root``
    and kept on ``self.tmp_dir_obj``. A ``samples`` subdirectory is created
    inside it, and the running sample counter is reset to zero.

    Returns:
        self, so the object can be used as ``with obj as x:``.
    """
    scratch = tempfile.TemporaryDirectory(dir=self.tmp_dir_root)
    # NOTE(review): nothing is cleaned up here; presumably __exit__ releases
    # tmp_dir_obj — confirm.
    self.tmp_dir_obj = scratch
    self.sample_dir = join(scratch.name, 'samples')
    make_dir(self.sample_dir)
    # Reset the running counter used when naming written samples.
    self.sample_ind = 0
    return self
def save_model_bundle(self):
    """Package the last model weights and the config into a zip bundle.

    ``self.last_model_path`` is copied as ``model.pth`` and
    ``self.config_path`` as ``config.json`` into ``<tmp_dir>/model-bundle``.
    That directory is then zipped to ``self.model_bundle_path``.
    """
    bundle_dir = join(self.tmp_dir, 'model-bundle')
    make_dir(bundle_dir)

    weights_dst = join(bundle_dir, 'model.pth')
    shutil.copyfile(self.last_model_path, weights_dst)

    config_dst = join(bundle_dir, 'config.json')
    shutil.copyfile(self.config_path, config_dst)

    zipdir(bundle_dir, self.model_bundle_path)
def plot_batch(self, x, y, output_path, z=None):
    """Plot every item of a batch on a square grid and save the figure.

    Args:
        x: batch of inputs. ``x.shape[0]`` is taken as the batch size and
            items are read as ``x[i]``.
        y: batch of targets, indexed in parallel with ``x``.
        output_path: file the figure is saved to. Its parent directory is
            created if needed.
        z: optional extra batch, indexed in parallel with ``x``. When given,
            ``z[i]`` is passed to ``plot_xyz`` as the ``z`` keyword.
    """
    batch_sz = x.shape[0]
    # Smallest square grid that holds every item; each cell is 3x3 inches.
    ncols = nrows = math.ceil(math.sqrt(batch_sz))
    fig = plt.figure(constrained_layout=True, figsize=(3 * ncols, 3 * nrows))
    try:
        grid = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)
        for i in range(batch_sz):
            ax = fig.add_subplot(grid[i])
            if z is None:
                self.plot_xyz(ax, x[i], y[i])
            else:
                self.plot_xyz(ax, x[i], y[i], z=z[i])
        make_dir(output_path, use_dirname=True)
        # Save this figure explicitly rather than whatever figure is
        # "current" in pyplot's global state.
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls do not accumulate open matplotlib figures.
        plt.close(fig)
def __init__(self, cfg: LearnerConfig, tmp_dir, model_path=None):
    """Initialize the learner for inference or for training.

    Args:
        cfg: learner configuration.
        tmp_dir: local scratch directory. When ``cfg.output_uri`` is an S3
            URI, the local copy of the output directory is placed under it.
        model_path: optional path to a saved model ``state_dict``. If given,
            the weights are loaded, the model is put in eval mode, and no
            training state is set up: no output files, data, optimizer,
            schedulers or TensorBoard writer.

    Raises:
        FileNotFoundError: if ``model_path`` is given but is not a file.
    """
    self.cfg = cfg
    self.tmp_dir = tmp_dir

    # NOTE(review): cache locations are hard-coded, and TORCH_HOME is set
    # process-wide as a side effect of constructing a learner.
    torch_cache_dir = '/opt/data/torch-cache'
    os.environ['TORCH_HOME'] = torch_cache_dir
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.data_cache_dir = '/opt/data/data-cache'
    make_dir(self.data_cache_dir)

    self.model = self.build_model()
    self.model.to(self.device)

    if model_path is not None:
        # Inference-only setup: load the saved weights and stop here.
        if not isfile(model_path):
            # FileNotFoundError subclasses Exception, so existing
            # ``except Exception`` handlers still catch it.
            raise FileNotFoundError(
                'Model could not be found at {}'.format(model_path))
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device))
        self.model.eval()
        return

    log.info(self.cfg)

    # Datasets (ds) and data loaders (dl); presumably populated by
    # setup_data() below — train_ds must be sized by then.
    self.train_ds = None
    self.train_dl = None
    self.valid_ds = None
    self.valid_dl = None
    self.test_ds = None
    self.test_dl = None

    if cfg.output_uri.startswith('s3://'):
        # Work in an emptied local mirror of the remote output directory.
        self.output_dir = get_local_path(cfg.output_uri, tmp_dir)
        make_dir(self.output_dir, force_empty=True)
        if not cfg.overfit_mode:
            self.sync_from_cloud()
    else:
        self.output_dir = cfg.output_uri
        make_dir(self.output_dir)

    self.last_model_path = join(self.output_dir, 'last-model.pth')
    self.config_path = join(self.output_dir, 'config.json')
    self.train_state_path = join(self.output_dir, 'train-state.json')
    self.log_path = join(self.output_dir, 'log.csv')
    model_bundle_fn = basename(cfg.get_model_bundle_uri())
    self.model_bundle_path = join(self.output_dir, model_bundle_fn)
    self.metric_names = self.build_metric_names()

    json_to_file(self.cfg.dict(), self.config_path)
    self.load_init_weights()
    self.load_checkpoint()
    self.opt = self.build_optimizer()
    self.setup_data()
    self.start_epoch = self.get_start_epoch()
    # NOTE(review): this is 0 when the training set is smaller than a batch.
    self.steps_per_epoch = len(
        self.train_ds) // self.cfg.solver.batch_sz
    self.step_scheduler = self.build_step_scheduler()
    self.epoch_scheduler = self.build_epoch_scheduler()
    self.setup_tensorboard()
def write_sample(self, sample: DataSample):
    """Save one training/validation chip and its label raster as PNG files.

    The chip is written to ``<sample_dir>/<split>/img/<scene_id>-<ind>.png``.
    The labels are written to
    ``<sample_dir>/<split>/labels/<scene_id>-<ind>.png``. ``<split>`` is
    ``train`` when ``sample.is_train`` is truthy, otherwise ``valid``.
    ``<ind>`` is ``self.sample_ind``, which is incremented after every call.
    """
    split = 'train' if sample.is_train else 'valid'
    # Label raster for the sample's window, as 8-bit values for PNG output.
    label_arr = sample.labels.get_label_arr(sample.window).astype(np.uint8)

    out_dirs = [join(self.sample_dir, split, sub) for sub in ('img', 'labels')]
    for out_dir in out_dirs:
        make_dir(out_dir)
    img_dir, labels_dir = out_dirs

    # Chip and labels share the same file name in sibling directories.
    file_name = '{}-{}.png'.format(sample.scene_id, self.sample_ind)
    save_img(sample.chip, join(img_dir, file_name))
    save_img(label_arr, join(labels_dir, file_name))
    self.sample_ind += 1
def setup_tensorboard(self):
    """Create a TensorBoard ``SummaryWriter`` if the config asks for one.

    ``self.tb_writer`` is always set: it is ``None`` when
    ``cfg.log_tensorboard`` is falsy. Otherwise it is a writer logging to
    ``<output_dir>/tb-logs``, and that path is also stored on
    ``self.tb_log_dir``.
    """
    self.tb_writer = None
    if not self.cfg.log_tensorboard:
        return

    self.tb_log_dir = join(self.output_dir, 'tb-logs')
    make_dir(self.tb_log_dir)
    self.tb_writer = SummaryWriter(log_dir=self.tb_log_dir)
def save(self, labels):
    """Save labels to ``self.uri`` as a GeoTIFF, plus optional GeoJSON.

    The raster has three bands when ``self.class_trans`` is set, written from
    ``class_trans.class_to_rgb``. Otherwise it has a single band holding the
    label array cast to uint8. When ``self.vector_output`` is set, a
    full-extent label mask is also built. Each vector output with a ``uri``
    then gets that class's mask (optionally denoised) vectorized and written
    as GeoJSON.

    Args:
        labels - (SemanticSegmentationLabels) labels to be saved

    Raises:
        ValueError: if a vector output that has a ``uri`` uses a mode other
            than ``'buildings'`` or ``'polygons'``.
    """
    local_path = get_local_path(self.uri, self.tmp_dir)
    make_dir(local_path, use_dirname=True)

    transform = self.crs_transformer.get_affine_transform()
    crs = self.crs_transformer.get_image_crs()

    band_count = 3 if self.class_trans else 1
    dtype = np.uint8

    # Full-extent label mask; only needed when vectorizing.
    mask = (np.zeros((self.extent.ymax, self.extent.xmax), dtype=np.uint8)
            if self.vector_output else None)

    # https://github.com/mapbox/rasterio/blob/master/docs/quickstart.rst
    # https://rasterio.readthedocs.io/en/latest/topics/windowed-rw.html
    with rasterio.open(
            local_path,
            'w',
            driver='GTiff',
            height=self.extent.ymax,
            width=self.extent.xmax,
            count=band_count,
            dtype=dtype,
            transform=transform,
            crs=crs) as dataset:
        for window in labels.get_windows():
            label_arr = labels.get_label_arr(window)
            # Clip windows that extend past the raster extent.
            window = window.intersection(self.extent)
            label_arr = label_arr[0:window.get_height(), 0:window.get_width()]

            if mask is not None:
                mask[window.ymin:window.ymax,
                     window.xmin:window.xmax] = label_arr

            window = window.rasterio_format()
            if self.class_trans:
                rgb_labels = self.class_trans.class_to_rgb(label_arr)
                for chan in range(3):
                    dataset.write_band(
                        chan + 1, rgb_labels[:, :, chan], window=window)
            else:
                img = label_arr.astype(dtype)
                dataset.write_band(1, img, window=window)

    upload_or_copy(local_path, self.uri)

    if self.vector_output:
        # Optional dependency, imported only when vector output is requested.
        import mask_to_polygons.vectorification as vectorification
        import mask_to_polygons.processing.denoise as denoise

        # Distinct name so it does not shadow the affine ``transform`` above.
        def pixel_to_map(x, y):
            return self.crs_transformer.pixel_to_map((x, y))

        for vo in self.vector_output:
            uri = vo.uri
            if not uri:
                # No destination for this output, so nothing to vectorize.
                continue

            denoise_radius = vo.denoise
            mode = vo.get_mode()
            class_mask = np.array(mask == vo.class_id, dtype=np.uint8)
            if denoise_radius > 0:
                class_mask = denoise.denoise(class_mask, denoise_radius)

            # Every branch must bind ``geojson`` (or raise). Otherwise an
            # unknown mode would reuse the previous output's GeoJSON, or
            # raise NameError on the first output.
            if mode == 'buildings':
                geojson = vectorification.geojson_from_mask(
                    mask=class_mask,
                    transform=pixel_to_map,
                    mode=mode,
                    min_aspect_ratio=vo.min_aspect_ratio,
                    min_area=vo.min_area,
                    width_factor=vo.element_width_factor,
                    thickness=vo.element_thickness)
            elif mode == 'polygons':
                geojson = vectorification.geojson_from_mask(
                    mask=class_mask, transform=pixel_to_map, mode=mode)
            else:
                raise ValueError(
                    'Unsupported vector output mode: {}'.format(mode))

            local_geojson_path = get_local_path(uri, self.tmp_dir)
            if local_geojson_path:
                with open(local_geojson_path, 'w') as file_out:
                    file_out.write(geojson)
                upload_or_copy(local_geojson_path, uri)