class Trainer:
    """Pipeline to train a NN model using a certain dataset, both specified by an YML config."""

    @use_seed()
    def __init__(self, config_path, run_dir):
        """Build datasets, model, optimizer, scheduler, metrics and output dirs from the YML config.

        Args:
            config_path: path to the YML configuration file (must exist).
            run_dir: directory where logs, metrics, checkpoints and visuals are written (created).
        """
        self.config_path = coerce_to_path_and_check_exist(config_path)
        self.run_dir = coerce_to_path_and_create_dir(run_dir)
        self.logger = get_logger(self.run_dir, name="trainer")
        self.print_and_log_info(
            "Trainer initialisation: run directory is {}".format(run_dir))

        # Keep a copy of the config alongside the run outputs for reproducibility.
        shutil.copy(self.config_path, self.run_dir)
        self.print_and_log_info("Config {} copied to run directory".format(
            self.config_path))

        # NOTE(review): FullLoader can construct arbitrary Python objects; the config is
        # assumed trusted (it is user-authored, not external input).
        with open(self.config_path) as fp:
            cfg = yaml.load(fp, Loader=yaml.FullLoader)

        if torch.cuda.is_available():
            type_device = "cuda"
            nb_device = torch.cuda.device_count()
        else:
            type_device = "cpu"
            nb_device = None
        self.device = torch.device(type_device)
        self.print_and_log_info("Using {} device, nb_device is {}".format(
            type_device, nb_device))

        # Datasets and dataloaders
        self.dataset_kwargs = cfg["dataset"]
        self.dataset_name = self.dataset_kwargs.pop("name")
        train_dataset = get_dataset(self.dataset_name)("train", **self.dataset_kwargs)
        val_dataset = get_dataset(self.dataset_name)("val", **self.dataset_kwargs)
        self.n_classes = train_dataset.n_classes
        self.is_val_empty = len(val_dataset) == 0
        self.print_and_log_info("Dataset {} instantiated with {}".format(
            self.dataset_name, self.dataset_kwargs))
        self.print_and_log_info(
            "Found {} classes, {} train samples, {} val samples".format(
                self.n_classes, len(train_dataset), len(val_dataset)))

        self.img_size = train_dataset.img_size
        self.batch_size = cfg["training"]["batch_size"]
        self.n_workers = cfg["training"].get("n_workers", 4)
        self.train_loader = DataLoader(train_dataset,
                                       batch_size=self.batch_size,
                                       num_workers=self.n_workers,
                                       shuffle=True)
        self.val_loader = DataLoader(val_dataset,
                                     batch_size=self.batch_size,
                                     num_workers=self.n_workers)
        self.print_and_log_info(
            "Dataloaders instantiated with batch_size={} and n_workers={}".format(
                self.batch_size, self.n_workers))

        # Training length: exactly one of n_iterations / n_epoches must be given;
        # the other is derived from the number of batches per epoch.
        self.n_batches = len(self.train_loader)
        self.n_iterations, self.n_epoches = cfg["training"].get(
            "n_iterations"), cfg["training"].get("n_epoches")
        assert not (self.n_iterations is not None and self.n_epoches is not None)
        if self.n_iterations is not None:
            self.n_epoches = max(self.n_iterations // self.n_batches, 1)
        else:
            self.n_iterations = self.n_epoches * len(self.train_loader)

        # Model
        self.model_kwargs = cfg["model"]
        self.model_name = self.model_kwargs.pop("name")
        self.is_gmm = 'gmm' in self.model_name
        self.model = get_model(self.model_name)(
            self.train_loader.dataset, **self.model_kwargs).to(self.device)
        self.print_and_log_info("Using model {} with kwargs {}".format(
            self.model_name, self.model_kwargs))
        self.print_and_log_info('Number of trainable parameters: {}'.format(
            f'{count_parameters(self.model):,}'))
        self.n_prototypes = self.model.n_prototypes

        # Optimizer: separate parameter groups (and hyper-params) for cluster
        # parameters and transformer parameters.
        opt_params = cfg["training"]["optimizer"] or {}
        optimizer_name = opt_params.pop("name")
        cluster_kwargs = opt_params.pop('cluster', {})
        tsf_kwargs = opt_params.pop('transformer', {})
        self.optimizer = get_optimizer(optimizer_name)([
            dict(params=self.model.cluster_parameters(), **cluster_kwargs),
            dict(params=self.model.transformer_parameters(), **tsf_kwargs)
        ], **opt_params)
        self.model.set_optimizer(self.optimizer)
        self.print_and_log_info("Using optimizer {} with kwargs {}".format(
            optimizer_name, opt_params))
        self.print_and_log_info("cluster kwargs {}".format(cluster_kwargs))
        self.print_and_log_info("transformer kwargs {}".format(tsf_kwargs))

        # Scheduler: may be stepped per epoch or per batch; float milestones for
        # multi_step are interpreted as fractions of the total epochs/iterations.
        scheduler_params = cfg["training"].get("scheduler", {}) or {}
        scheduler_name = scheduler_params.pop("name", None)
        self.scheduler_update_range = scheduler_params.pop("update_range", "epoch")
        assert self.scheduler_update_range in ["epoch", "batch"]
        if scheduler_name == "multi_step" and isinstance(
                scheduler_params["milestones"][0], float):
            n_tot = self.n_epoches if self.scheduler_update_range == "epoch" else self.n_iterations
            scheduler_params["milestones"] = [
                round(m * n_tot) for m in scheduler_params["milestones"]
            ]
        self.scheduler = get_scheduler(scheduler_name)(self.optimizer,
                                                       **scheduler_params)
        self.cur_lr = self.scheduler.get_last_lr()[0]
        self.print_and_log_info("Using scheduler {} with parameters {}".format(
            scheduler_name, scheduler_params))

        # Pretrained / Resume: mutually exclusive; "pretrained" loads weights only,
        # "resume" additionally restores optimizer/scheduler and epoch/batch counters.
        checkpoint_path = cfg["training"].get("pretrained")
        checkpoint_path_resume = cfg["training"].get("resume")
        assert not (checkpoint_path is not None and checkpoint_path_resume is not None)
        if checkpoint_path is not None:
            self.load_from_tag(checkpoint_path)
        elif checkpoint_path_resume is not None:
            self.load_from_tag(checkpoint_path_resume, resume=True)
        else:
            self.start_epoch, self.start_batch = 1, 1

        # Train metrics & check_cluster interval
        metric_names = ['time/img', 'loss']
        metric_names += [f'prop_clus{i}' for i in range(self.n_prototypes)]
        train_iter_interval = cfg["training"]["train_stat_interval"]
        self.train_stat_interval = train_iter_interval
        self.train_metrics = Metrics(*metric_names)
        self.train_metrics_path = self.run_dir / TRAIN_METRICS_FILE
        with open(self.train_metrics_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\t" +
                    "\t".join(self.train_metrics.names) + "\n")
        self.check_cluster_interval = cfg["training"]["check_cluster_interval"]

        # Val metrics & scores
        val_iter_interval = cfg["training"]["val_stat_interval"]
        self.val_stat_interval = val_iter_interval
        self.val_metrics = Metrics('loss_val')
        self.val_metrics_path = self.run_dir / VAL_METRICS_FILE
        with open(self.val_metrics_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\t" +
                    "\t".join(self.val_metrics.names) + "\n")
        self.val_scores = Scores(self.n_classes, self.n_prototypes)
        self.val_scores_path = self.run_dir / VAL_SCORES_FILE
        with open(self.val_scores_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\t" +
                    "\t".join(self.val_scores.names) + "\n")

        # Prototypes & Variances: one snapshot directory per prototype.
        self.prototypes_path = coerce_to_path_and_create_dir(self.run_dir / 'prototypes')
        [
            coerce_to_path_and_create_dir(self.prototypes_path / f'proto{k}')
            for k in range(self.n_prototypes)
        ]
        if self.is_gmm:
            self.variances_path = coerce_to_path_and_create_dir(self.run_dir / 'variances')
            [
                coerce_to_path_and_create_dir(self.variances_path / f'var{k}')
                for k in range(self.n_prototypes)
            ]

        # Transformation predictions: keep a fixed batch of train images and create
        # one output directory per (image, prototype) pair.
        self.transformation_path = coerce_to_path_and_create_dir(
            self.run_dir / 'transformations')
        self.images_to_tsf = next(iter(
            self.train_loader))[0][:N_TRANSFORMATION_PREDICTIONS].to(self.device)
        for k in range(self.images_to_tsf.size(0)):
            out = coerce_to_path_and_create_dir(self.transformation_path / f'img{k}')
            convert_to_img(self.images_to_tsf[k]).save(out / 'input.png')
            # Use a distinct index name to avoid shadowing the outer image index `k`.
            [
                coerce_to_path_and_create_dir(out / f'tsf{j}')
                for j in range(self.n_prototypes)
            ]

        # Visdom
        viz_port = cfg["training"].get("visualizer_port")
        if viz_port is not None:
            from visdom import Visdom
            os.environ["http_proxy"] = ""
            self.visualizer = Visdom(
                port=viz_port,
                env=f'{self.run_dir.parent.name}_{self.run_dir.name}')
            self.visualizer.delete_env(self.visualizer.env)  # Clean env before plotting
            self.print_and_log_info(f"Visualizer initialised at {viz_port}")
        else:
            self.visualizer = None
            self.print_and_log_info("No visualizer initialized")

    def print_and_log_info(self, string):
        """Print `string` to stdout and append it to the run logger."""
        print_info(string)
        self.logger.info(string)

    def load_from_tag(self, tag, resume=False):
        """Load model weights from a previous run; with `resume`, also restore
        optimizer/scheduler state and the epoch/batch counters."""
        self.print_and_log_info("Loading model from run {}".format(tag))
        path = coerce_to_path_and_check_exist(RUNS_PATH / self.dataset_name / tag / MODEL_FILE)
        checkpoint = torch.load(path, map_location=self.device)
        try:
            self.model.load_state_dict(checkpoint["model_state"])
        except RuntimeError:
            # Checkpoint was saved from a DataParallel-style wrapper: strip prefixes
            # and load into the wrapped module instead.
            state = safe_model_state_dict(checkpoint["model_state"])
            self.model.module.load_state_dict(state)
        self.start_epoch, self.start_batch = 1, 1
        if resume:
            self.start_epoch, self.start_batch = checkpoint[
                "epoch"], checkpoint.get("batch", 0) + 1
            self.optimizer.load_state_dict(checkpoint["optimizer_state"])
            self.scheduler.load_state_dict(checkpoint["scheduler_state"])
            self.cur_lr = self.scheduler.get_last_lr()[0]
        self.print_and_log_info(
            "Checkpoint loaded at epoch {}, batch {}".format(
                self.start_epoch, self.start_batch - 1))

    @property
    def score_name(self):
        """Name of the main validation score (delegated to `Scores`)."""
        return self.val_scores.score_name

    def print_memory_usage(self, prefix):
        """Log current/max CUDA memory statistics, prefixed with `prefix`."""
        usage = {}
        for attr in [
                "memory_allocated", "max_memory_allocated", "memory_cached",
                "max_memory_cached"
        ]:
            usage[attr] = getattr(torch.cuda, attr)() * 0.000001
        self.print_and_log_info("{}:\t{}".format(
            prefix, " / ".join(
                ["{}: {:.0f}MiB".format(k, v) for k, v in usage.items()])))

    @use_seed()
    def run(self):
        """Main training loop: iterate epochs/batches, periodically log train/val
        metrics, reassign empty clusters, save images and checkpoints, then evaluate."""
        cur_iter = (self.start_epoch - 1) * self.n_batches + self.start_batch - 1
        prev_train_stat_iter, prev_val_stat_iter = cur_iter, cur_iter
        prev_check_cluster_iter = cur_iter
        if self.start_epoch == self.n_epoches:
            # Resumed at the final epoch: nothing left to train.
            self.print_and_log_info("No training, only evaluating")
            self.evaluate()
            self.print_and_log_info("Training run is over")
            return None

        for epoch in range(self.start_epoch, self.n_epoches + 1):
            # When resuming mid-epoch, skip the batches already done.
            batch_start = self.start_batch if epoch == self.start_epoch else 1
            for batch, (images, labels) in enumerate(self.train_loader, start=1):
                if batch < batch_start:
                    continue
                cur_iter += 1
                if cur_iter > self.n_iterations:
                    break

                self.single_train_batch_run(images)
                if self.scheduler_update_range == "batch":
                    self.update_scheduler(epoch, batch=batch)

                if (cur_iter - prev_train_stat_iter) >= self.train_stat_interval:
                    prev_train_stat_iter = cur_iter
                    self.log_train_metrics(cur_iter, epoch, batch)

                if (cur_iter - prev_check_cluster_iter) >= self.check_cluster_interval:
                    prev_check_cluster_iter = cur_iter
                    self.check_cluster(cur_iter, epoch, batch)

                if (cur_iter - prev_val_stat_iter) >= self.val_stat_interval:
                    prev_val_stat_iter = cur_iter
                    if not self.is_val_empty:
                        self.run_val()
                        self.log_val_metrics(cur_iter, epoch, batch)
                    self.log_images(cur_iter)
                    self.save(epoch=epoch, batch=batch)

            self.model.step()
            if self.scheduler_update_range == "epoch" and batch_start == 1:
                self.update_scheduler(epoch + 1, batch=1)

        self.save_training_metrics()
        self.evaluate()
        self.print_and_log_info("Training run is over")

    def update_scheduler(self, epoch, batch):
        """Step the LR scheduler and log when the learning rate changes."""
        self.scheduler.step()
        lr = self.scheduler.get_last_lr()[0]
        if lr != self.cur_lr:
            self.cur_lr = lr
            self.print_and_log_info(
                PRINT_LR_UPD_FMT(epoch, self.n_epoches, batch, self.n_batches, lr))

    def single_train_batch_run(self, images):
        """One optimization step on a batch; updates loss/time and per-cluster
        assignment-proportion metrics."""
        start_time = time.time()
        B = images.size(0)
        self.model.train()
        images = images.to(self.device)

        self.optimizer.zero_grad()
        loss, distances = self.model(images)
        loss.backward()
        self.optimizer.step()

        with torch.no_grad():
            # One-hot mask of the closest prototype per sample -> cluster proportions.
            argmin_idx = distances.min(1)[1]
            mask = torch.zeros(B, self.n_prototypes, device=self.device).scatter(
                1, argmin_idx[:, None], 1)
            proportions = mask.sum(0).cpu().numpy() / B

        self.train_metrics.update({
            'time/img': (time.time() - start_time) / B,
            'loss': loss.item(),
        })
        self.train_metrics.update(
            {f'prop_clus{i}': p for i, p in enumerate(proportions)})

    @torch.no_grad()
    def log_images(self, cur_iter):
        """Save prototype/variance/transformation snapshots and push them to Visdom."""
        self.save_prototypes(cur_iter)
        self.update_visualizer_images(self.model.prototypes, 'prototypes', nrow=5)
        if self.is_gmm:
            self.save_variances(cur_iter)
            variances = self.model.variances
            # Normalize each variance map to [0, 1] for display.
            M = variances.flatten(1).max(1)[0][:, None, None, None]
            variances = (variances - self.model.var_min) / (M - self.model.var_min + 1e-7)
            self.update_visualizer_images(variances, 'variances', nrow=5)

        tsf_imgs = self.save_transformed_images(cur_iter)
        C, H, W = tsf_imgs.shape[2:]
        self.update_visualizer_images(tsf_imgs.view(-1, C, H, W),
                                      'transformations',
                                      nrow=self.n_prototypes + 1)

    def save_prototypes(self, cur_iter=None):
        """Save prototype images: per-iteration jpg snapshots when `cur_iter` is
        given, otherwise final png files."""
        prototypes = self.model.prototypes
        for k in range(self.n_prototypes):
            img = convert_to_img(prototypes[k])
            if cur_iter is not None:
                img.save(self.prototypes_path / f'proto{k}' / f'{cur_iter}.jpg')
            else:
                img.save(self.prototypes_path / f'prototype{k}.png')

    def save_variances(self, cur_iter=None):
        """Save variance images (GMM models only); same snapshot/final convention
        as `save_prototypes`."""
        variances = self.model.variances
        for k in range(self.n_prototypes):
            img = convert_to_img(variances[k])
            if cur_iter is not None:
                img.save(self.variances_path / f'var{k}' / f'{cur_iter}.jpg')
            else:
                img.save(self.variances_path / f'variance{k}.png')

    @torch.no_grad()
    def save_transformed_images(self, cur_iter=None):
        """Transform the fixed image batch with every prototype, save the results
        and return the stacked (input + transformed) tensor."""
        self.model.eval()
        output = self.model.transform(self.images_to_tsf)
        transformed_imgs = torch.cat([self.images_to_tsf.unsqueeze(1), output], 1)
        for k in range(transformed_imgs.size(0)):
            # Index 0 is the raw input; transformed outputs start at 1.
            for j, img in enumerate(transformed_imgs[k][1:]):
                if cur_iter is not None:
                    convert_to_img(img).save(self.transformation_path / f'img{k}' /
                                             f'tsf{j}' / f'{cur_iter}.jpg')
                else:
                    convert_to_img(img).save(self.transformation_path / f'img{k}' /
                                             f'tsf{j}.png')
        return transformed_imgs

    def update_visualizer_images(self, images, title, nrow):
        """Send an image grid to Visdom (no-op when the visualizer is disabled)."""
        if self.visualizer is None:
            return None
        # Downscale large images so the Visdom payload stays small.
        if max(images.shape[1:]) > VIZ_MAX_IMG_SIZE:
            images = torch.nn.functional.interpolate(images,
                                                     size=VIZ_MAX_IMG_SIZE,
                                                     mode='bilinear')
        self.visualizer.images(images.clamp(0, 1),
                               win=title,
                               nrow=nrow,
                               opts=dict(title=title,
                                         store_history=True,
                                         width=VIZ_WIDTH,
                                         height=VIZ_HEIGHT))

    def check_cluster(self, cur_iter, epoch, batch):
        """Reassign empty clusters based on recent assignment proportions, then
        reset the proportion metrics."""
        proportions = [
            self.train_metrics[f'prop_clus{i}'].avg
            for i in range(self.n_prototypes)
        ]
        reassigned, idx = self.model.reassign_empty_clusters(proportions)
        msg = PRINT_CHECK_CLUSTERS_FMT(epoch, self.n_epoches, batch,
                                       self.n_batches, reassigned, idx)
        self.print_and_log_info(msg)
        self.train_metrics.reset(
            *[f'prop_clus{i}' for i in range(self.n_prototypes)])

    def log_train_metrics(self, cur_iter, epoch, batch):
        """Print, persist (TSV) and plot the averaged train metrics, then reset them."""
        # Print & write metrics to file
        stat = PRINT_TRAIN_STAT_FMT(epoch, self.n_epoches, batch,
                                    self.n_batches, self.train_metrics)
        self.print_and_log_info(stat)
        with open(self.train_metrics_path, mode="a") as f:
            f.write("{}\t{}\t{}\t".format(cur_iter, epoch, batch) + "\t".join(
                map("{:.4f}".format, self.train_metrics.avg_values)) + "\n")

        self.update_visualizer_metrics(cur_iter, train=True)
        self.train_metrics.reset('time/img', 'loss')

    def update_visualizer_metrics(self, cur_iter, train):
        """Plot loss curves (and cluster proportions or validation scores) in Visdom."""
        if self.visualizer is None:
            return None

        split, metrics = ('train', self.train_metrics) if train else ('val', self.val_metrics)
        loss_names = [n for n in metrics.names if 'loss' in n]
        y, x = [[metrics[n].avg for n in loss_names]], [[cur_iter] * len(loss_names)]
        self.visualizer.line(y, x,
                             win=f'{split}_loss',
                             update='append',
                             opts=dict(title=f'{split}_loss',
                                       legend=loss_names,
                                       width=VIZ_WIDTH,
                                       height=VIZ_HEIGHT))

        if train:
            if self.n_prototypes > 1:
                # Cluster proportions
                proportions = [
                    metrics[f'prop_clus{i}'].avg
                    for i in range(self.n_prototypes)
                ]
                # NOTE(review): width=VIZ_HEIGHT looks intentional (square bar chart)
                # but may be a typo for VIZ_WIDTH — confirm against original intent.
                self.visualizer.bar(proportions,
                                    win='train_cluster_prop',
                                    opts=dict(
                                        title='train_cluster_proportions',
                                        width=VIZ_HEIGHT,
                                        height=VIZ_HEIGHT))
        else:
            # Scores
            names = list(
                filter(lambda name: 'cls' not in name, self.val_scores.names))
            y, x = [[self.val_scores[n] for n in names]], [[cur_iter] * len(names)]
            self.visualizer.line(y, x,
                                 win='global_scores',
                                 update='append',
                                 opts=dict(title='global_scores',
                                           legend=names,
                                           width=VIZ_WIDTH,
                                           height=VIZ_HEIGHT))

            y, x = [[
                self.val_scores[f'acc_cls{i}'] for i in range(self.n_classes)
            ]], [[cur_iter] * self.n_classes]
            # Fixed: `heigh=` typo — Visdom silently ignores unknown opts keys,
            # so the per-class accuracy plot never got its height applied.
            self.visualizer.line(
                y, x,
                win='acc_by_cls',
                update='append',
                opts=dict(title='acc_by_cls',
                          legend=[f'cls{i}' for i in range(self.n_classes)],
                          width=VIZ_WIDTH,
                          height=VIZ_HEIGHT))

    @torch.no_grad()
    def run_val(self):
        """Run the model on the validation set, accumulating loss and cluster scores."""
        self.model.eval()
        for images, labels in self.val_loader:
            images = images.to(self.device)
            distances = self.model(images)[1]
            dist_min_by_sample, argmin_idx = distances.min(1)
            loss_val = dist_min_by_sample.mean()

            self.val_metrics.update({'loss_val': loss_val.item()})
            self.val_scores.update(labels.long().numpy(), argmin_idx.cpu().numpy())

    def log_val_metrics(self, cur_iter, epoch, batch):
        """Print, persist and plot validation metrics/scores, then reset them."""
        stat = PRINT_VAL_STAT_FMT(epoch, self.n_epoches, batch, self.n_batches,
                                  self.val_metrics)
        self.print_and_log_info(stat)
        with open(self.val_metrics_path, mode="a") as f:
            f.write(
                "{}\t{}\t{}\t".format(cur_iter, epoch, batch) +
                "\t".join(map("{:.4f}".format, self.val_metrics.avg_values)) + "\n")

        scores = self.val_scores.compute()
        self.print_and_log_info(
            "val_scores: " +
            ", ".join(["{}={:.4f}".format(k, v) for k, v in scores.items()]))
        with open(self.val_scores_path, mode="a") as f:
            f.write("{}\t{}\t{}\t".format(cur_iter, epoch, batch) +
                    "\t".join(map("{:.4f}".format, scores.values())) + "\n")

        self.update_visualizer_metrics(cur_iter, train=False)
        self.val_scores.reset()
        self.val_metrics.reset()

    def save(self, epoch, batch):
        """Write a full checkpoint (model, optimizer, scheduler, counters) to the run dir."""
        state = {
            "epoch": epoch,
            "batch": batch,
            "model_name": self.model_name,
            "model_kwargs": self.model_kwargs,
            "model_state": self.model.state_dict(),
            "n_prototypes": self.n_prototypes,
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
        }
        save_path = self.run_dir / MODEL_FILE
        torch.save(state, save_path)
        self.print_and_log_info("Model saved at {}".format(save_path))

    def save_training_metrics(self):
        """Plot loss/proportion/score curves from the metric TSVs, convert per-iteration
        image snapshots into GIFs, and clean up the snapshot directories."""
        df_train = pd.read_csv(self.train_metrics_path, sep="\t", index_col=0)
        df_val = pd.read_csv(self.val_metrics_path, sep="\t", index_col=0)
        df_scores = pd.read_csv(self.val_scores_path, sep="\t", index_col=0)
        if len(df_train) == 0:
            self.print_and_log_info("No metrics or plots to save")
            return

        # Losses
        losses = list(
            filter(lambda s: s.startswith('loss'), self.train_metrics.names))
        df = df_train.join(df_val[['loss_val']], how="outer")
        fig = plot_lines(df, losses + ['loss_val'], title="Loss")
        fig.savefig(self.run_dir / "loss.pdf")

        # Cluster proportions
        names = list(
            filter(lambda s: s.startswith('prop_'), self.train_metrics.names))
        fig = plot_lines(df, names, title="Cluster proportions")
        fig.savefig(self.run_dir / "cluster_proportions.pdf")
        s = df[names].iloc[-1]
        s.index = list(map(lambda n: n.replace('prop_clus', ''), names))
        fig = plot_bar(s, title="Final cluster proportions")
        fig.savefig(self.run_dir / "cluster_proportions_final.pdf")

        # Validation
        if not self.is_val_empty:
            names = list(
                filter(lambda name: 'cls' not in name, self.val_scores.names))
            fig = plot_lines(df_scores, names, title="Global scores", unit_yaxis=True)
            fig.savefig(self.run_dir / 'global_scores.pdf')

            fig = plot_lines(df_scores,
                             [f'acc_cls{i}' for i in range(self.n_classes)],
                             title="Scores by cls",
                             unit_yaxis=True)
            fig.savefig(self.run_dir / "scores_by_cls.pdf")

        # Prototypes & Variances
        size = MAX_GIF_SIZE if MAX_GIF_SIZE < max(self.img_size) else self.img_size
        self.save_prototypes()
        if self.is_gmm:
            self.save_variances()
        for k in range(self.n_prototypes):
            save_gif(self.prototypes_path / f'proto{k}', f'prototype{k}.gif', size=size)
            shutil.rmtree(str(self.prototypes_path / f'proto{k}'))
            if self.is_gmm:
                save_gif(self.variances_path / f'var{k}', f'variance{k}.gif', size=size)
                shutil.rmtree(str(self.variances_path / f'var{k}'))

        # Transformation predictions
        if self.model.transformer.is_identity:
            # no need to keep transformation predictions
            shutil.rmtree(str(self.transformation_path))
            coerce_to_path_and_create_dir(self.transformation_path)
        else:
            self.save_transformed_images()
            for i in range(self.images_to_tsf.size(0)):
                for k in range(self.n_prototypes):
                    save_gif(self.transformation_path / f'img{i}' / f'tsf{k}',
                             f'tsf{k}.gif',
                             size=size)
                    shutil.rmtree(
                        str(self.transformation_path / f'img{i}' / f'tsf{k}'))

        self.print_and_log_info("Training metrics and visuals saved")

    def evaluate(self):
        """Dispatch to qualitative (unlabeled data) or quantitative evaluation."""
        self.model.eval()
        # A label of -1 on the first sample marks an unlabeled dataset.
        no_label = self.train_loader.dataset[0][1] == -1
        if no_label:
            self.qualitative_eval()
        else:
            self.quantitative_eval()
        self.print_and_log_info("Evaluation is over")

    @torch.no_grad()
    def qualitative_eval(self):
        """Routine to save qualitative results"""
        loss = AverageMeter()
        scores_path = self.run_dir / FINAL_SCORES_FILE
        with open(scores_path, mode="w") as f:
            f.write("loss\n")
        cluster_path = coerce_to_path_and_create_dir(self.run_dir / 'clusters')
        dataset = self.train_loader.dataset
        train_loader = DataLoader(dataset,
                                  batch_size=self.batch_size,
                                  num_workers=self.n_workers,
                                  shuffle=False)

        # Compute results
        distances, cluster_idx = np.array([]), np.array([], dtype=np.int32)
        averages = {k: AverageTensorMeter() for k in range(self.n_prototypes)}
        for images, labels in train_loader:
            images = images.to(self.device)
            dist = self.model(images)[1]
            dist_min_by_sample, argmin_idx = map(lambda t: t.cpu().numpy(),
                                                 dist.min(1))
            loss.update(dist_min_by_sample.mean(), n=len(dist_min_by_sample))
            argmin_idx = argmin_idx.astype(np.int32)
            distances = np.hstack([distances, dist_min_by_sample])
            cluster_idx = np.hstack([cluster_idx, argmin_idx])
            transformed_imgs = self.model.transform(images).cpu()
            # Accumulate, per cluster, the transformed images assigned to it.
            for k in range(self.n_prototypes):
                imgs = transformed_imgs[argmin_idx == k, k]
                averages[k].update(imgs)

        self.print_and_log_info("final_loss: {:.5}".format(loss.avg))
        with open(scores_path, mode="a") as f:
            f.write("{:.5}\n".format(loss.avg))

        # Save results
        with open(cluster_path / 'cluster_counts.tsv', mode='w') as f:
            f.write('\t'.join([str(k) for k in range(self.n_prototypes)]) + '\n')
            f.write('\t'.join(
                [str(averages[k].count) for k in range(self.n_prototypes)]) + '\n')
        for k in range(self.n_prototypes):
            path = coerce_to_path_and_create_dir(cluster_path / f'cluster{k}')
            indices = np.where(cluster_idx == k)[0]
            # Samples closest to the prototype first.
            top_idx = np.argsort(distances[indices])[:N_CLUSTER_SAMPLES]
            for j, idx in enumerate(top_idx):
                inp = dataset[indices[idx]][0].unsqueeze(0).to(self.device)
                convert_to_img(inp).save(path / f'top{j}_raw.png')
                if not self.model.transformer.is_identity:
                    convert_to_img(self.model.transform(inp)[0, k]).save(
                        path / f'top{j}_tsf.png')
                    convert_to_img(
                        self.model.transform(inp, inverse=True)[0, k]).save(
                            path / f'top{j}_tsf_inp.png')
            # Plus a random sample of cluster members.
            if len(indices) <= N_CLUSTER_SAMPLES:
                random_idx = indices
            else:
                random_idx = np.random.choice(indices, N_CLUSTER_SAMPLES, replace=False)
            for j, idx in enumerate(random_idx):
                inp = dataset[idx][0].unsqueeze(0).to(self.device)
                convert_to_img(inp).save(path / f'random{j}_raw.png')
                if not self.model.transformer.is_identity:
                    convert_to_img(self.model.transform(inp)[0, k]).save(
                        path / f'random{j}_tsf.png')
                    convert_to_img(
                        self.model.transform(inp, inverse=True)[0, k]).save(
                            path / f'random{j}_tsf_inp.png')
            try:
                convert_to_img(averages[k].avg).save(path / 'avg.png')
            except AssertionError:
                print_warning(f'no image found in cluster {k}')

    @torch.no_grad()
    def quantitative_eval(self):
        """Routine to save quantitative results: loss + scores"""
        loss = AverageMeter()
        scores_path = self.run_dir / FINAL_SCORES_FILE
        scores = Scores(self.n_classes, self.n_prototypes)
        with open(scores_path, mode="w") as f:
            f.write("loss\t" + "\t".join(scores.names) + "\n")

        # Re-instantiate the train split in eval mode (no augmentation).
        dataset = get_dataset(self.dataset_name)("train",
                                                 eval_mode=True,
                                                 **self.dataset_kwargs)
        loader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            num_workers=self.n_workers)
        for images, labels in loader:
            images = images.to(self.device)
            distances = self.model(images)[1]
            dist_min_by_sample, argmin_idx = distances.min(1)
            loss.update(dist_min_by_sample.mean(), n=len(dist_min_by_sample))
            scores.update(labels.long().numpy(), argmin_idx.cpu().numpy())

        scores = scores.compute()
        self.print_and_log_info("final_loss: {:.5}".format(loss.avg))
        self.print_and_log_info(
            "final_scores: " +
            ", ".join(["{}={:.4f}".format(k, v) for k, v in scores.items()]))
        with open(scores_path, mode="a") as f:
            f.write("{:.5}\t".format(loss.avg) +
                    "\t".join(map("{:.4f}".format, scores.values())) + "\n")
class VisdomPlotter:
    """
    A Visdom based plotter, to plot aggregated metrics.

    How to use:
    ------------
    (1) Start the server with: python -m visdom.server
    (2) Then, in your browser, you can go to: http://localhost:8097
    """

    def __init__(self, experiment_env, server='http://localhost', port=8097):
        # Connection target; a Visdom client is created immediately and the
        # server is auto-started if it is not reachable.
        self.server = server
        self.port = port
        self.viz = Visdom(
            server=server,
            port=port)  # Connect to Visdom server on server / port
        if not self.start_visdom_server():
            raise ValueError('Failed to launch Visdom server at %r:%r' %
                             (server, port))
        if experiment_env in self.viz.get_env_list():
            self.viz.delete_env(
                experiment_env)  # Clear previous runs with same id
        self.experiment_env = experiment_env
        # Maps metric name -> Visdom window handle, so later updates append
        # to the existing plot instead of creating a new one.
        self.plots = {}

    def start_visdom_server(self):
        """Ensure a Visdom server is reachable; spawn one in the background if not.

        Returns True when a connection could be established.
        """
        is_visdom_server_connected = self.viz.check_connection(
            timeout_seconds=1)  # Ping if it's already on..
        if not is_visdom_server_connected:
            # Launch a detached server with the current interpreter and wait
            # (up to 35 s) for it to come up.
            interpreter_path = sys.executable
            os.system(interpreter_path + ' -m visdom.server &')
            is_visdom_server_connected = self.viz.check_connection(
                timeout_seconds=35)
        return is_visdom_server_connected

    def plot_single_metric(self, metric, line_id, title, epoch, value):
        """Plot one scalar point for `metric`; creates the line window on first
        use, then appends to the existing line named `line_id`."""
        if metric not in self.plots:
            # First point is duplicated so the initial window shows a segment.
            self.plots[metric] = self.viz.line(X=np.array([epoch, epoch]),
                                               Y=np.array([value, value]),
                                               env=self.experiment_env,
                                               opts=dict(legend=[line_id],
                                                         title=title,
                                                         xlabel='Epochs',
                                                         ylabel=metric))
        else:
            self.viz.line(X=np.array([epoch]),
                          Y=np.array([value]),
                          env=self.experiment_env,
                          win=self.plots[metric],
                          name=line_id,
                          update='append')

    def plot_confusion_matrix(self, metric, matrix, label_classes):
        """Render `matrix` as a heatmap labeled with `label_classes`; reuses the
        same window across calls."""
        if metric not in self.plots:
            self.plots[metric] = self.viz.heatmap(
                X=matrix,
                env=self.experiment_env,
                opts=dict(columnnames=label_classes, rownames=label_classes))
        else:
            self.viz.heatmap(X=matrix,
                             env=self.experiment_env,
                             win=self.plots[metric],
                             opts=dict(columnnames=label_classes,
                                       rownames=label_classes))

    def plot_images(self, images_bchw):
        """Display a batch of images (expected layout B x C x H x W, per the name)."""
        self.viz.images(images_bchw)

    def plot_aggregated_metrics(self, metrics, epoch):
        """Plot every metric stored in `metrics` for the given `epoch`.

        NOTE(review): `metrics` is assumed to expose `.metrics`,
        `.metric_to_title`, `.label_classes`, `.data_type` and per-epoch
        indexing (`metrics[epoch][metric]`) — confirm against the project's
        metrics container.
        """
        for metric in metrics.metrics:
            title = metrics.metric_to_title[metric]
            value = metrics[epoch][metric]
            if metric == 'confusion_matrix':
                label_classes = metrics.label_classes
                self.plot_confusion_matrix(metric, value, label_classes)
            else:
                if hasattr(value, 'shape') and value.size > 1:
                    # Vector-valued metric: one line per class.
                    for idx, dim_val in enumerate(value):
                        line_id = metrics.label_classes[idx]
                        self.plot_single_metric(metric, line_id, title, epoch,
                                                dim_val)
                else:
                    # Scalar metric: a single line identified by the data split.
                    line_id = metrics.data_type
                    self.plot_single_metric(metric, line_id, title, epoch, value)
'num_epochs': 1000, 'steps_per_epoch': 20, 'optimizer': 'SGD', 'train_identifier': 'DEBUG'}) device = 'cuda:0' """Visdom stuff""" from visdom import Visdom visdom_log_path = os.path.join("/tmp") #visdom_log_path = outdir print("Saving visdom logs to", visdom_log_path) viz = Visdom(port=6065, log_to_filename=visdom_log_path) # for env in viz.get_env_list(): viz.delete_env(params.train_identifier) viz.log_to_filename = os.path.join(visdom_log_path,params.train_identifier+".visdom") plotter = VisdomLinePlotter(env_name=params.train_identifier, plot_path=visdom_log_path) """ Load the model and make sure it works """ #inputs = torch.zeros([2, 3, 1920, 1184]) model = smp.Unet('resnet34', classes=params.numclasses+2, encoder_weights='imagenet') model = model.to(device) #output = model.forward(inputs) #print("output-shape:", output.shape) preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder_name = 'resnet34', pretrained='imagenet')