class VisdomLinePlotter(object):
    """Plots to Visdom."""

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, split_name, x, y, env=None):
        print_env = env if env is not None else self.env
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(
                X=np.array([x, x]), Y=np.array([y, y]), env=print_env,
                opts=dict(legend=[split_name], title=var_name,
                          xlabel='Epochs', ylabel=var_name))
        else:
            # updateTrace was removed from recent visdom releases;
            # line(update='append') is the equivalent call.
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=print_env,
                          win=self.plots[var_name], name=split_name,
                          update='append')

    def plot_mask(self, masks, epoch):
        self.viz.bar(X=masks, env=self.env,
                     opts=dict(stacked=True, title=epoch))
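# A minimal usage sketch for the VisdomLinePlotter above (assumes a visdom
# server is running on the default port; the loss values are made up).
plotter = VisdomLinePlotter(env_name='demo')
for epoch, loss in enumerate([0.9, 0.6, 0.45, 0.38], start=1):
    plotter.plot('loss', 'train', epoch, loss)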
class Vis:
    def __init__(self, env='default'):
        self.vis = Visdom(env=env)
        self.lines = {}
        self.win_action_bar = None

    def update(self, x, y, line_name, **kwargs):
        X = np.array([x])
        Y = np.array([y])
        if line_name not in self.lines:
            self.lines[line_name] = self.vis.line(
                X=X, Y=Y, opts=dict(title=line_name, **kwargs))
        else:
            self.vis.line(X=X, Y=Y, win=self.lines[line_name],
                          update='append', opts=dict(title=line_name, **kwargs))

    def show_action(self, action):
        # visdom's bar plot expects the option key 'rownames'
        if self.win_action_bar is None:
            self.win_action_bar = self.vis.bar(
                X=action, opts=dict(rownames=['line_vel', 'angle_vel']))
        else:
            self.vis.bar(X=action, win=self.win_action_bar,
                         opts=dict(rownames=['line_vel', 'angle_vel']))
class VisdomDictPlotter(object):
    """Plots to Visdom."""

    def __init__(self, env_name='main', port=8097):
        self.vis = Visdom(port=port)
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, keys_name, title_name, dct):
        keys = list(dct.keys())
        values = list(dct.values())
        opts = dict(rownames=keys, title=title_name,
                    xlabel=keys_name, ylabel=var_name)
        if var_name not in self.plots:
            self.plots[var_name] = self.vis.bar(X=np.array(values),
                                                env=self.env, opts=opts)
        else:
            self.vis.bar(X=np.array(values), env=self.env,
                         win=self.plots[var_name], opts=opts)
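# Hypothetical example feeding a metrics dict to VisdomDictPlotter; the
# variable and category names are made up for illustration.
plotter = VisdomDictPlotter(env_name='demo', port=8097)
per_class_acc = {'cat': 0.91, 'dog': 0.84, 'bird': 0.77}
plotter.plot('accuracy', 'class', 'Per-class accuracy', per_class_acc)
# A second call with the same var_name redraws into the existing window.
per_class_acc['bird'] = 0.81
plotter.plot('accuracy', 'class', 'Per-class accuracy', per_class_acc)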
def main():
    viz = Visdom(port=2337, env='different classifiers')
    acc_list = []
    model_name_list = list(available_models_input_size.keys())
    for model_name in model_name_list:
        _, *acc = validate_model('../datasets/data0229', model_name=model_name,
                                 viz=viz, num_epochs=80, model_dir=None,
                                 script_dir=None, feature_extract=False,
                                 learning_rates=[1e-5], weight_decays=[1e-4])
        acc_list.append(acc)
    viz.bar(X=acc_list,
            opts=dict(rownames=model_name_list, legend=['val acc', 'test acc']))
def plot(self, viz: visdom.Visdom):

    def strongest_correlation(coef_vars_lags: dict):
        values = list(coef_vars_lags.values())
        keys = list(coef_vars_lags.keys())
        accumulated_per_variable = np.sum(np.abs(values), axis=1)
        strongest_id = np.argmax(accumulated_per_variable)
        return keys[strongest_id], values[strongest_id]

    acf_variables = {}
    ccf_variable_pairs = {}
    for name, samples in self.samples.items():
        if len(samples) < self.n_lags + 1:
            continue
        observations = torch.stack(samples, dim=0)
        observations.t_()
        observations = observations.numpy()
        active_rows_mask = list(map(np.any, np.diff(observations, axis=1)))
        active_rows = np.where(active_rows_mask)[0]
        for i, active_row in enumerate(active_rows):
            acf_lags = acf(observations[active_row], unbiased=False,
                           nlags=self.n_lags, fft=True, missing='raise')
            acf_variables[f'{name}.{active_row}'] = acf_lags
            if self.with_cross_correlation:
                for paired_row in active_rows[i + 1:]:
                    ccf_lags = ccf(observations[active_row],
                                   observations[paired_row], unbiased=False)
                    ccf_variable_pairs[(active_row, paired_row)] = ccf_lags

    if len(acf_variables) > 0:
        acf_mean = np.mean(list(acf_variables.values()), axis=0)
        viz.bar(X=acf_mean, win='autocorr',
                opts=dict(xlabel='Lag', ylabel='ACF',
                          title='mean Autocorrelation'))

    if len(ccf_variable_pairs) > 0:
        shortest_length = min(map(len, ccf_variable_pairs.values()))
        for key, values in ccf_variable_pairs.items():
            ccf_variable_pairs[key] = values[:shortest_length]
        ccf_mean = np.mean(list(ccf_variable_pairs.values()), axis=0)
        viz.bar(X=ccf_mean, win='crosscorr',
                opts=dict(xlabel='Lag', ylabel='CCF', ytickmin=0.,
                          ytickmax=1., title='mean Cross-Correlation'))
class VisdomLinePlotter(object):
    """Plots to Visdom and mirrors each plotted point to a CSV file."""

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, split_name, x, y, exp_name='test', env=None):
        print_env = env if env is not None else self.env
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(
                X=np.array([x, x]), Y=np.array([y, y]), env=print_env,
                opts=dict(legend=[split_name], title=var_name,
                          xlabel='Epochs', ylabel=var_name))
        else:
            # line(update='append') replaces the removed updateTrace call
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=print_env,
                          win=self.plots[var_name], name=split_name,
                          update='append')
        if not os.path.exists('runs/%s/data/' % exp_name):
            os.makedirs('runs/%s/data/' % exp_name)
        with open('runs/%s/data/%s_%s_data.csv'
                  % (exp_name, split_name, var_name), 'a') as f:
            f.write('%d, %f\n' % (x, y))

    def plot_mask(self, masks, epoch):
        self.viz.bar(X=masks, env=self.env,
                     opts=dict(stacked=True, title=epoch))

    def plot_image(self, image, epoch, exp_name='test'):
        self.viz.image(image, env=exp_name + '_img',
                       opts=dict(caption=epoch))

    def plot_images(self, images, run_split, epoch, nrow, padding=2,
                    exp_name='test'):
        self.viz.images(images, env=exp_name + '_img', nrow=nrow,
                        padding=padding,
                        opts=dict(caption='%s_%d' % (run_split, epoch),
                                  # title='Random images',
                                  jpgquality=100))
class VisdomSummary(object):
    def __init__(self, port=None, env=None):
        self.vis = Visdom(port=port, env=env)
        self.opts = dict()
        # cache helper objects explicitly; hasattr(self, '__scalar') would
        # look for the un-mangled name and always fail inside the class
        self._scalar = None
        self._image2d = None

    def scalar(self, win, name, x, y, remove=False):
        if self._scalar is None:
            self._scalar = Scalar()
        opts = dict(title=win)
        self._scalar.update(self.vis, win, name, x, y, opts=opts, remove=remove)

    def bar(self, win, x, rownames=None):
        if rownames is None:
            rownames = ['{}'.format(i) for i in range(x.size(0))]
        opts = dict(title=win, rownames=rownames)
        self.vis.bar(X=x, win=win, opts=opts)

    def image2d(self, win, name, img, caption=None, nrow=3):
        if self._image2d is None:
            self._image2d = Image2D()
        self._image2d.update(self.vis, win, name, img, caption, nrow)

    def image3d(self, win, name, img):
        raise NotImplementedError

    def text(self, win, text):
        self.vis.text(text, win=win)

    def close(self, win=None):
        self.vis.close(win=win)

    def save(self):
        self.vis.save([self.vis.env])
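# Sketch of driving VisdomSummary by hand; Scalar and Image2D are helper
# classes from the surrounding project, so only the generic text/bar/save
# calls are exercised here (assumes numpy imported as np).
summary = VisdomSummary(port=8097, env='demo')
summary.text('notes', 'training started')
summary.bar('class_counts', np.array([12., 40., 31.]),
            rownames=['a', 'b', 'c'])
summary.save()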
count += 1
# -------------------------------------------------------------
score_text = vis.text("Model scoring")
if is_labeled_data and (alg.type == 'classification' or alg.type == 'clustering'):
    if model_fpr is not None and model_tpr is not None:
        vis.line(X=model_fpr, Y=model_tpr,
                 opts=dict(xlabel='False Positive Rate',
                           ylabel='True Positive Rate', title='ROC Curve'))
    vis.bar(X=model_cm,
            opts=dict(stacked=True, legend=classes, rownames=classes,
                      title='Predictive model performance'))
    vis.text("Classification scores", win=score_text, append=True)
    vis.text("Accuracy score:", win=score_text, append=True)
    vis.text(str(model_as), win=score_text, append=True)
    vis.text("Precision score:", win=score_text, append=True)
    vis.text(str(model_ps), win=score_text, append=True)
    vis.text("Confusion matrix:", win=score_text, append=True)
    vis.text(str(model_cm), win=score_text, append=True)
if is_labeled_data and alg.type == 'regression':
    vis.text("Regression scores", win=score_text, append=True)
    vis.text("Mean squared error:", win=score_text, append=True)
    vis.text(str(model_mse), win=score_text, append=True)
    vis.text("Mean absolute error:", win=score_text, append=True)
class Trainer:
    """Pipeline to train a NN model using a certain dataset, both specified
    by a YML config."""

    @use_seed()
    def __init__(self, config_path, run_dir):
        self.config_path = coerce_to_path_and_check_exist(config_path)
        self.run_dir = coerce_to_path_and_create_dir(run_dir)
        self.logger = get_logger(self.run_dir, name="trainer")
        self.print_and_log_info("Trainer initialisation: run directory is {}".format(run_dir))
        shutil.copy(self.config_path, self.run_dir)
        self.print_and_log_info("Config {} copied to run directory".format(self.config_path))

        with open(self.config_path) as fp:
            cfg = yaml.load(fp, Loader=yaml.FullLoader)

        if torch.cuda.is_available():
            type_device = "cuda"
            nb_device = torch.cuda.device_count()
        else:
            type_device = "cpu"
            nb_device = None
        self.device = torch.device(type_device)
        self.print_and_log_info("Using {} device, nb_device is {}".format(type_device, nb_device))

        # Datasets and dataloaders
        self.dataset_kwargs = cfg["dataset"]
        self.dataset_name = self.dataset_kwargs.pop("name")
        train_dataset = get_dataset(self.dataset_name)("train", **self.dataset_kwargs)
        val_dataset = get_dataset(self.dataset_name)("val", **self.dataset_kwargs)
        self.n_classes = train_dataset.n_classes
        self.is_val_empty = len(val_dataset) == 0
        self.print_and_log_info("Dataset {} instantiated with {}".format(
            self.dataset_name, self.dataset_kwargs))
        self.print_and_log_info("Found {} classes, {} train samples, {} val samples".format(
            self.n_classes, len(train_dataset), len(val_dataset)))

        self.img_size = train_dataset.img_size
        self.batch_size = cfg["training"]["batch_size"]
        self.n_workers = cfg["training"].get("n_workers", 4)
        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size,
                                       num_workers=self.n_workers, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=self.batch_size,
                                     num_workers=self.n_workers)
        self.print_and_log_info("Dataloaders instantiated with batch_size={} and n_workers={}"
                                .format(self.batch_size, self.n_workers))

        self.n_batches = len(self.train_loader)
        self.n_iterations = cfg["training"].get("n_iterations")
        self.n_epoches = cfg["training"].get("n_epoches")
        assert not (self.n_iterations is not None and self.n_epoches is not None)
        if self.n_iterations is not None:
            self.n_epoches = max(self.n_iterations // self.n_batches, 1)
        else:
            self.n_iterations = self.n_epoches * len(self.train_loader)

        # Model
        self.model_kwargs = cfg["model"]
        self.model_name = self.model_kwargs.pop("name")
        self.is_gmm = 'gmm' in self.model_name
        self.model = get_model(self.model_name)(self.train_loader.dataset,
                                                **self.model_kwargs).to(self.device)
        self.print_and_log_info("Using model {} with kwargs {}".format(
            self.model_name, self.model_kwargs))
        self.print_and_log_info('Number of trainable parameters: {}'.format(
            f'{count_parameters(self.model):,}'))
        self.n_prototypes = self.model.n_prototypes

        # Optimizer
        opt_params = cfg["training"]["optimizer"] or {}
        optimizer_name = opt_params.pop("name")
        cluster_kwargs = opt_params.pop('cluster', {})
        tsf_kwargs = opt_params.pop('transformer', {})
        self.optimizer = get_optimizer(optimizer_name)([
            dict(params=self.model.cluster_parameters(), **cluster_kwargs),
            dict(params=self.model.transformer_parameters(), **tsf_kwargs),
        ], **opt_params)
        self.model.set_optimizer(self.optimizer)
        self.print_and_log_info("Using optimizer {} with kwargs {}".format(optimizer_name, opt_params))
        self.print_and_log_info("cluster kwargs {}".format(cluster_kwargs))
        self.print_and_log_info("transformer kwargs {}".format(tsf_kwargs))

        # Scheduler
        scheduler_params = cfg["training"].get("scheduler", {}) or {}
        scheduler_name = scheduler_params.pop("name", None)
        self.scheduler_update_range = scheduler_params.pop("update_range", "epoch")
        assert self.scheduler_update_range in ["epoch", "batch"]
        if scheduler_name == "multi_step" and isinstance(scheduler_params["milestones"][0], float):
            n_tot = (self.n_epoches if self.scheduler_update_range == "epoch"
                     else self.n_iterations)
            scheduler_params["milestones"] = [round(m * n_tot)
                                              for m in scheduler_params["milestones"]]
        self.scheduler = get_scheduler(scheduler_name)(self.optimizer, **scheduler_params)
        self.cur_lr = self.scheduler.get_last_lr()[0]
        self.print_and_log_info("Using scheduler {} with parameters {}".format(
            scheduler_name, scheduler_params))

        # Pretrained / Resume
        checkpoint_path = cfg["training"].get("pretrained")
        checkpoint_path_resume = cfg["training"].get("resume")
        assert not (checkpoint_path is not None and checkpoint_path_resume is not None)
        if checkpoint_path is not None:
            self.load_from_tag(checkpoint_path)
        elif checkpoint_path_resume is not None:
            self.load_from_tag(checkpoint_path_resume, resume=True)
        else:
            self.start_epoch, self.start_batch = 1, 1

        # Train metrics & check_cluster interval
        metric_names = ['time/img', 'loss']
        metric_names += [f'prop_clus{i}' for i in range(self.n_prototypes)]
        train_iter_interval = cfg["training"]["train_stat_interval"]
        self.train_stat_interval = train_iter_interval
        self.train_metrics = Metrics(*metric_names)
        self.train_metrics_path = self.run_dir / TRAIN_METRICS_FILE
        with open(self.train_metrics_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\t" + "\t".join(self.train_metrics.names) + "\n")
        self.check_cluster_interval = cfg["training"]["check_cluster_interval"]

        # Val metrics & scores
        val_iter_interval = cfg["training"]["val_stat_interval"]
        self.val_stat_interval = val_iter_interval
        self.val_metrics = Metrics('loss_val')
        self.val_metrics_path = self.run_dir / VAL_METRICS_FILE
        with open(self.val_metrics_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\t" + "\t".join(self.val_metrics.names) + "\n")
        self.val_scores = Scores(self.n_classes, self.n_prototypes)
        self.val_scores_path = self.run_dir / VAL_SCORES_FILE
        with open(self.val_scores_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\t" + "\t".join(self.val_scores.names) + "\n")

        # Prototypes & Variances
        self.prototypes_path = coerce_to_path_and_create_dir(self.run_dir / 'prototypes')
        [coerce_to_path_and_create_dir(self.prototypes_path / f'proto{k}')
         for k in range(self.n_prototypes)]
        if self.is_gmm:
            self.variances_path = coerce_to_path_and_create_dir(self.run_dir / 'variances')
            [coerce_to_path_and_create_dir(self.variances_path / f'var{k}')
             for k in range(self.n_prototypes)]

        # Transformation predictions
        self.transformation_path = coerce_to_path_and_create_dir(self.run_dir / 'transformations')
        self.images_to_tsf = next(iter(self.train_loader))[0][:N_TRANSFORMATION_PREDICTIONS].to(self.device)
        for k in range(self.images_to_tsf.size(0)):
            out = coerce_to_path_and_create_dir(self.transformation_path / f'img{k}')
            convert_to_img(self.images_to_tsf[k]).save(out / 'input.png')
            [coerce_to_path_and_create_dir(out / f'tsf{k}') for k in range(self.n_prototypes)]

        # Visdom
        viz_port = cfg["training"].get("visualizer_port")
        if viz_port is not None:
            from visdom import Visdom
            os.environ["http_proxy"] = ""
            self.visualizer = Visdom(port=viz_port,
                                     env=f'{self.run_dir.parent.name}_{self.run_dir.name}')
            self.visualizer.delete_env(self.visualizer.env)  # clean env before plotting
            self.print_and_log_info(f"Visualizer initialised at {viz_port}")
        else:
            self.visualizer = None
            self.print_and_log_info("No visualizer initialized")

    def print_and_log_info(self, string):
        print_info(string)
        self.logger.info(string)

    def load_from_tag(self, tag, resume=False):
        self.print_and_log_info("Loading model from run {}".format(tag))
        path = coerce_to_path_and_check_exist(RUNS_PATH / self.dataset_name / tag / MODEL_FILE)
        checkpoint = torch.load(path, map_location=self.device)
        try:
            self.model.load_state_dict(checkpoint["model_state"])
        except RuntimeError:
            state = safe_model_state_dict(checkpoint["model_state"])
            self.model.module.load_state_dict(state)
        self.start_epoch, self.start_batch = 1, 1
        if resume:
            self.start_epoch = checkpoint["epoch"]
            self.start_batch = checkpoint.get("batch", 0) + 1
            self.optimizer.load_state_dict(checkpoint["optimizer_state"])
            self.scheduler.load_state_dict(checkpoint["scheduler_state"])
            self.cur_lr = self.scheduler.get_last_lr()[0]
        self.print_and_log_info("Checkpoint loaded at epoch {}, batch {}".format(
            self.start_epoch, self.start_batch - 1))

    @property
    def score_name(self):
        return self.val_scores.score_name

    def print_memory_usage(self, prefix):
        usage = {}
        for attr in ["memory_allocated", "max_memory_allocated",
                     "memory_cached", "max_memory_cached"]:
            usage[attr] = getattr(torch.cuda, attr)() * 0.000001
        self.print_and_log_info("{}:\t{}".format(prefix, " / ".join(
            ["{}: {:.0f}MiB".format(k, v) for k, v in usage.items()])))

    @use_seed()
    def run(self):
        cur_iter = (self.start_epoch - 1) * self.n_batches + self.start_batch - 1
        prev_train_stat_iter, prev_val_stat_iter = cur_iter, cur_iter
        prev_check_cluster_iter = cur_iter
        if self.start_epoch == self.n_epoches:
            self.print_and_log_info("No training, only evaluating")
            self.evaluate()
            self.print_and_log_info("Training run is over")
            return None

        for epoch in range(self.start_epoch, self.n_epoches + 1):
            batch_start = self.start_batch if epoch == self.start_epoch else 1
            for batch, (images, labels) in enumerate(self.train_loader, start=1):
                if batch < batch_start:
                    continue
                cur_iter += 1
                if cur_iter > self.n_iterations:
                    break

                self.single_train_batch_run(images)
                if self.scheduler_update_range == "batch":
                    self.update_scheduler(epoch, batch=batch)

                if (cur_iter - prev_train_stat_iter) >= self.train_stat_interval:
                    prev_train_stat_iter = cur_iter
                    self.log_train_metrics(cur_iter, epoch, batch)

                if (cur_iter - prev_check_cluster_iter) >= self.check_cluster_interval:
                    prev_check_cluster_iter = cur_iter
                    self.check_cluster(cur_iter, epoch, batch)

                if (cur_iter - prev_val_stat_iter) >= self.val_stat_interval:
                    prev_val_stat_iter = cur_iter
                    if not self.is_val_empty:
                        self.run_val()
                        self.log_val_metrics(cur_iter, epoch, batch)
                    self.log_images(cur_iter)
                    self.save(epoch=epoch, batch=batch)

            self.model.step()
            if self.scheduler_update_range == "epoch" and batch_start == 1:
                self.update_scheduler(epoch + 1, batch=1)

        self.save_training_metrics()
        self.evaluate()
        self.print_and_log_info("Training run is over")

    def update_scheduler(self, epoch, batch):
        self.scheduler.step()
        lr = self.scheduler.get_last_lr()[0]
        if lr != self.cur_lr:
            self.cur_lr = lr
            self.print_and_log_info(PRINT_LR_UPD_FMT(epoch, self.n_epoches,
                                                     batch, self.n_batches, lr))

    def single_train_batch_run(self, images):
        start_time = time.time()
        B = images.size(0)
        self.model.train()
        images = images.to(self.device)

        self.optimizer.zero_grad()
        loss, distances = self.model(images)
        loss.backward()
        self.optimizer.step()

        with torch.no_grad():
            argmin_idx = distances.min(1)[1]
            mask = torch.zeros(B, self.n_prototypes,
                               device=self.device).scatter(1, argmin_idx[:, None], 1)
            proportions = mask.sum(0).cpu().numpy() / B

        self.train_metrics.update({
            'time/img': (time.time() - start_time) / B,
            'loss': loss.item(),
        })
        self.train_metrics.update({f'prop_clus{i}': p for i, p in enumerate(proportions)})

    @torch.no_grad()
    def log_images(self, cur_iter):
        self.save_prototypes(cur_iter)
        self.update_visualizer_images(self.model.prototypes, 'prototypes', nrow=5)
        if self.is_gmm:
            self.save_variances(cur_iter)
            variances = self.model.variances
            M = variances.flatten(1).max(1)[0][:, None, None, None]
            variances = (variances - self.model.var_min) / (M - self.model.var_min + 1e-7)
            self.update_visualizer_images(variances, 'variances', nrow=5)

        tsf_imgs = self.save_transformed_images(cur_iter)
        C, H, W = tsf_imgs.shape[2:]
        self.update_visualizer_images(tsf_imgs.view(-1, C, H, W), 'transformations',
                                      nrow=self.n_prototypes + 1)

    def save_prototypes(self, cur_iter=None):
        prototypes = self.model.prototypes
        for k in range(self.n_prototypes):
            img = convert_to_img(prototypes[k])
            if cur_iter is not None:
                img.save(self.prototypes_path / f'proto{k}' / f'{cur_iter}.jpg')
            else:
                img.save(self.prototypes_path / f'prototype{k}.png')

    def save_variances(self, cur_iter=None):
        variances = self.model.variances
        for k in range(self.n_prototypes):
            img = convert_to_img(variances[k])
            if cur_iter is not None:
                img.save(self.variances_path / f'var{k}' / f'{cur_iter}.jpg')
            else:
                img.save(self.variances_path / f'variance{k}.png')
    @torch.no_grad()
    def save_transformed_images(self, cur_iter=None):
        self.model.eval()
        output = self.model.transform(self.images_to_tsf)
        transformed_imgs = torch.cat([self.images_to_tsf.unsqueeze(1), output], 1)
        for k in range(transformed_imgs.size(0)):
            for j, img in enumerate(transformed_imgs[k][1:]):
                if cur_iter is not None:
                    convert_to_img(img).save(self.transformation_path / f'img{k}'
                                             / f'tsf{j}' / f'{cur_iter}.jpg')
                else:
                    convert_to_img(img).save(self.transformation_path / f'img{k}'
                                             / f'tsf{j}.png')
        return transformed_imgs

    def update_visualizer_images(self, images, title, nrow):
        if self.visualizer is None:
            return None
        if max(images.shape[1:]) > VIZ_MAX_IMG_SIZE:
            images = torch.nn.functional.interpolate(images, size=VIZ_MAX_IMG_SIZE,
                                                     mode='bilinear')
        self.visualizer.images(images.clamp(0, 1), win=title, nrow=nrow,
                               opts=dict(title=title, store_history=True,
                                         width=VIZ_WIDTH, height=VIZ_HEIGHT))

    def check_cluster(self, cur_iter, epoch, batch):
        proportions = [self.train_metrics[f'prop_clus{i}'].avg
                       for i in range(self.n_prototypes)]
        reassigned, idx = self.model.reassign_empty_clusters(proportions)
        msg = PRINT_CHECK_CLUSTERS_FMT(epoch, self.n_epoches, batch,
                                       self.n_batches, reassigned, idx)
        self.print_and_log_info(msg)
        self.train_metrics.reset(*[f'prop_clus{i}' for i in range(self.n_prototypes)])

    def log_train_metrics(self, cur_iter, epoch, batch):
        # Print & write metrics to file
        stat = PRINT_TRAIN_STAT_FMT(epoch, self.n_epoches, batch,
                                    self.n_batches, self.train_metrics)
        self.print_and_log_info(stat)
        with open(self.train_metrics_path, mode="a") as f:
            f.write("{}\t{}\t{}\t".format(cur_iter, epoch, batch) +
                    "\t".join(map("{:.4f}".format, self.train_metrics.avg_values)) + "\n")
        self.update_visualizer_metrics(cur_iter, train=True)
        self.train_metrics.reset('time/img', 'loss')

    def update_visualizer_metrics(self, cur_iter, train):
        if self.visualizer is None:
            return None

        split, metrics = ('train', self.train_metrics) if train else ('val', self.val_metrics)
        loss_names = [n for n in metrics.names if 'loss' in n]
        y, x = [[metrics[n].avg for n in loss_names]], [[cur_iter] * len(loss_names)]
        self.visualizer.line(y, x, win=f'{split}_loss', update='append',
                             opts=dict(title=f'{split}_loss', legend=loss_names,
                                       width=VIZ_WIDTH, height=VIZ_HEIGHT))

        if train:
            if self.n_prototypes > 1:
                # Cluster proportions
                proportions = [metrics[f'prop_clus{i}'].avg
                               for i in range(self.n_prototypes)]
                self.visualizer.bar(proportions, win='train_cluster_prop',
                                    opts=dict(title='train_cluster_proportions',
                                              width=VIZ_HEIGHT, height=VIZ_HEIGHT))
        else:
            # Scores
            names = list(filter(lambda name: 'cls' not in name, self.val_scores.names))
            y, x = [[self.val_scores[n] for n in names]], [[cur_iter] * len(names)]
            self.visualizer.line(y, x, win='global_scores', update='append',
                                 opts=dict(title='global_scores', legend=names,
                                           width=VIZ_WIDTH, height=VIZ_HEIGHT))

            y = [[self.val_scores[f'acc_cls{i}'] for i in range(self.n_classes)]]
            x = [[cur_iter] * self.n_classes]
            self.visualizer.line(y, x, win='acc_by_cls', update='append',
                                 opts=dict(title='acc_by_cls',
                                           legend=[f'cls{i}' for i in range(self.n_classes)],
                                           width=VIZ_WIDTH, height=VIZ_HEIGHT))

    @torch.no_grad()
    def run_val(self):
        self.model.eval()
        for images, labels in self.val_loader:
            images = images.to(self.device)
            distances = self.model(images)[1]
            dist_min_by_sample, argmin_idx = distances.min(1)
            loss_val = dist_min_by_sample.mean()
            self.val_metrics.update({'loss_val': loss_val.item()})
            self.val_scores.update(labels.long().numpy(), argmin_idx.cpu().numpy())

    def log_val_metrics(self, cur_iter, epoch, batch):
        stat = PRINT_VAL_STAT_FMT(epoch, self.n_epoches, batch,
                                  self.n_batches, self.val_metrics)
        self.print_and_log_info(stat)
        with open(self.val_metrics_path, mode="a") as f:
            f.write("{}\t{}\t{}\t".format(cur_iter, epoch, batch) +
                    "\t".join(map("{:.4f}".format, self.val_metrics.avg_values)) + "\n")

        scores = self.val_scores.compute()
        self.print_and_log_info("val_scores: " + ", ".join(
            ["{}={:.4f}".format(k, v) for k, v in scores.items()]))
        with open(self.val_scores_path, mode="a") as f:
            f.write("{}\t{}\t{}\t".format(cur_iter, epoch, batch) +
                    "\t".join(map("{:.4f}".format, scores.values())) + "\n")

        self.update_visualizer_metrics(cur_iter, train=False)
        self.val_scores.reset()
        self.val_metrics.reset()

    def save(self, epoch, batch):
        state = {
            "epoch": epoch,
            "batch": batch,
            "model_name": self.model_name,
            "model_kwargs": self.model_kwargs,
            "model_state": self.model.state_dict(),
            "n_prototypes": self.n_prototypes,
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
        }
        save_path = self.run_dir / MODEL_FILE
        torch.save(state, save_path)
        self.print_and_log_info("Model saved at {}".format(save_path))

    def save_training_metrics(self):
        df_train = pd.read_csv(self.train_metrics_path, sep="\t", index_col=0)
        df_val = pd.read_csv(self.val_metrics_path, sep="\t", index_col=0)
        df_scores = pd.read_csv(self.val_scores_path, sep="\t", index_col=0)
        if len(df_train) == 0:
            self.print_and_log_info("No metrics or plots to save")
            return

        # Losses
        losses = list(filter(lambda s: s.startswith('loss'), self.train_metrics.names))
        df = df_train.join(df_val[['loss_val']], how="outer")
        fig = plot_lines(df, losses + ['loss_val'], title="Loss")
        fig.savefig(self.run_dir / "loss.pdf")

        # Cluster proportions
        names = list(filter(lambda s: s.startswith('prop_'), self.train_metrics.names))
        fig = plot_lines(df, names, title="Cluster proportions")
        fig.savefig(self.run_dir / "cluster_proportions.pdf")
        s = df[names].iloc[-1]
        s.index = list(map(lambda n: n.replace('prop_clus', ''), names))
        fig = plot_bar(s, title="Final cluster proportions")
        fig.savefig(self.run_dir / "cluster_proportions_final.pdf")

        # Validation
        if not self.is_val_empty:
            names = list(filter(lambda name: 'cls' not in name, self.val_scores.names))
            fig = plot_lines(df_scores, names, title="Global scores", unit_yaxis=True)
            fig.savefig(self.run_dir / 'global_scores.pdf')
            fig = plot_lines(df_scores, [f'acc_cls{i}' for i in range(self.n_classes)],
                             title="Scores by cls", unit_yaxis=True)
            fig.savefig(self.run_dir / "scores_by_cls.pdf")

        # Prototypes & Variances
        size = MAX_GIF_SIZE if MAX_GIF_SIZE < max(self.img_size) else self.img_size
        self.save_prototypes()
        if self.is_gmm:
            self.save_variances()
        for k in range(self.n_prototypes):
            save_gif(self.prototypes_path / f'proto{k}', f'prototype{k}.gif', size=size)
            shutil.rmtree(str(self.prototypes_path / f'proto{k}'))
            if self.is_gmm:
                save_gif(self.variances_path / f'var{k}', f'variance{k}.gif', size=size)
                shutil.rmtree(str(self.variances_path / f'var{k}'))

        # Transformation predictions
        if self.model.transformer.is_identity:
            # no need to keep transformation predictions
            shutil.rmtree(str(self.transformation_path))
            coerce_to_path_and_create_dir(self.transformation_path)
        else:
            self.save_transformed_images()
            for i in range(self.images_to_tsf.size(0)):
                for k in range(self.n_prototypes):
                    save_gif(self.transformation_path / f'img{i}' / f'tsf{k}',
                             f'tsf{k}.gif', size=size)
                    shutil.rmtree(str(self.transformation_path / f'img{i}' / f'tsf{k}'))

        self.print_and_log_info("Training metrics and visuals saved")

    def evaluate(self):
        self.model.eval()
        no_label = self.train_loader.dataset[0][1] == -1
        if no_label:
            self.qualitative_eval()
        else:
            self.quantitative_eval()
        self.print_and_log_info("Evaluation is over")
    @torch.no_grad()
    def qualitative_eval(self):
        """Routine to save qualitative results"""
        loss = AverageMeter()
        scores_path = self.run_dir / FINAL_SCORES_FILE
        with open(scores_path, mode="w") as f:
            f.write("loss\n")
        cluster_path = coerce_to_path_and_create_dir(self.run_dir / 'clusters')
        dataset = self.train_loader.dataset
        train_loader = DataLoader(dataset, batch_size=self.batch_size,
                                  num_workers=self.n_workers, shuffle=False)

        # Compute results
        distances, cluster_idx = np.array([]), np.array([], dtype=np.int32)
        averages = {k: AverageTensorMeter() for k in range(self.n_prototypes)}
        for images, labels in train_loader:
            images = images.to(self.device)
            dist = self.model(images)[1]
            dist_min_by_sample, argmin_idx = map(lambda t: t.cpu().numpy(), dist.min(1))
            loss.update(dist_min_by_sample.mean(), n=len(dist_min_by_sample))
            argmin_idx = argmin_idx.astype(np.int32)
            distances = np.hstack([distances, dist_min_by_sample])
            cluster_idx = np.hstack([cluster_idx, argmin_idx])

            transformed_imgs = self.model.transform(images).cpu()
            for k in range(self.n_prototypes):
                imgs = transformed_imgs[argmin_idx == k, k]
                averages[k].update(imgs)

        self.print_and_log_info("final_loss: {:.5}".format(loss.avg))
        with open(scores_path, mode="a") as f:
            f.write("{:.5}\n".format(loss.avg))

        # Save results
        with open(cluster_path / 'cluster_counts.tsv', mode='w') as f:
            f.write('\t'.join([str(k) for k in range(self.n_prototypes)]) + '\n')
            f.write('\t'.join([str(averages[k].count)
                               for k in range(self.n_prototypes)]) + '\n')
        for k in range(self.n_prototypes):
            path = coerce_to_path_and_create_dir(cluster_path / f'cluster{k}')
            indices = np.where(cluster_idx == k)[0]
            top_idx = np.argsort(distances[indices])[:N_CLUSTER_SAMPLES]
            for j, idx in enumerate(top_idx):
                inp = dataset[indices[idx]][0].unsqueeze(0).to(self.device)
                convert_to_img(inp).save(path / f'top{j}_raw.png')
                if not self.model.transformer.is_identity:
                    convert_to_img(self.model.transform(inp)[0, k]).save(
                        path / f'top{j}_tsf.png')
                    convert_to_img(self.model.transform(inp, inverse=True)[0, k]).save(
                        path / f'top{j}_tsf_inp.png')
            if len(indices) <= N_CLUSTER_SAMPLES:
                random_idx = indices
            else:
                random_idx = np.random.choice(indices, N_CLUSTER_SAMPLES, replace=False)
            for j, idx in enumerate(random_idx):
                inp = dataset[idx][0].unsqueeze(0).to(self.device)
                convert_to_img(inp).save(path / f'random{j}_raw.png')
                if not self.model.transformer.is_identity:
                    convert_to_img(self.model.transform(inp)[0, k]).save(
                        path / f'random{j}_tsf.png')
                    convert_to_img(self.model.transform(inp, inverse=True)[0, k]).save(
                        path / f'random{j}_tsf_inp.png')
            try:
                convert_to_img(averages[k].avg).save(path / 'avg.png')
            except AssertionError:
                print_warning(f'no image found in cluster {k}')

    @torch.no_grad()
    def quantitative_eval(self):
        """Routine to save quantitative results: loss + scores"""
        loss = AverageMeter()
        scores_path = self.run_dir / FINAL_SCORES_FILE
        scores = Scores(self.n_classes, self.n_prototypes)
        with open(scores_path, mode="w") as f:
            f.write("loss\t" + "\t".join(scores.names) + "\n")
        dataset = get_dataset(self.dataset_name)("train", eval_mode=True,
                                                 **self.dataset_kwargs)
        loader = DataLoader(dataset, batch_size=self.batch_size,
                            num_workers=self.n_workers)
        for images, labels in loader:
            images = images.to(self.device)
            distances = self.model(images)[1]
            dist_min_by_sample, argmin_idx = distances.min(1)
            loss.update(dist_min_by_sample.mean(), n=len(dist_min_by_sample))
            scores.update(labels.long().numpy(), argmin_idx.cpu().numpy())

        scores = scores.compute()
        self.print_and_log_info("final_loss: {:.5}".format(loss.avg))
        self.print_and_log_info("final_scores: " + ", ".join(
            ["{}={:.4f}".format(k, v) for k, v in scores.items()]))
        with open(scores_path, mode="a") as f:
            f.write("{:.5}\t".format(loss.avg) +
                    "\t".join(map("{:.4f}".format, scores.values())) + "\n")
class Visualizer:
    def __init__(self, env="main", server="http://localhost", port=8097,
                 base_url="/", http_proxy_host=None, http_proxy_port=None,
                 log_to_filename=None):
        self._viz = Visdom(env=env, server=server, port=port,
                           http_proxy_host=http_proxy_host,
                           http_proxy_port=http_proxy_port,
                           log_to_filename=log_to_filename,
                           use_incoming_socket=False)
        self._viz.close(env=env)
        self.plots = {}

    def plot_line(self, name, tag, title, value, step=None):
        if name not in self.plots:
            y = numpy.array([value, value])
            if step is not None:
                x = numpy.array([step, step])
            else:
                x = numpy.array([1, 1])
            opts = dict(title=title, xlabel='steps', ylabel=name)
            if tag is not None:
                opts["legend"] = [tag]
            self.plots[name] = self._viz.line(X=x, Y=y, opts=opts)
        else:
            y = numpy.array([value])
            x = numpy.array([step])
            self._viz.line(X=x, Y=y, win=self.plots[name], name=tag,
                           update='append')

    def plot_text(self, text, title, pre=True):
        _width = max([len(x) for x in text.split("\n")]) * 10
        _height = len(text.split("\n")) * 20
        _height = max(_height, 120)
        if pre:
            text = "<pre>{}</pre>".format(text)
        self._viz.text(text, win=title,
                       opts=dict(title=title, width=min(_width, 450),
                                 height=min(_height, 300)))

    def plot_bar(self, data, labels, title):
        self._viz.bar(win=title, X=data,
                      opts=dict(legend=labels, stacked=False, title=title))

    def plot_hist(self, data, title, numbins=20):
        self._viz.histogram(win=title, X=data,
                            opts=dict(numbins=numbins, title=title))

    def plot_scatter(self, data, title, targets=None, labels=None):
        self._viz.scatter(win=title, X=data, Y=targets,
                          opts=dict(
                              # legend=labels,
                              title=title,
                              markersize=5,
                              webgl=True,
                              width=400,
                              height=400,
                              markeropacity=0.5))

    def plot_heatmap(self, data, labels, title):
        height = min(data.shape[0] * 20, 600)
        width = min(data.shape[1] * 25, 600)
        self._viz.heatmap(
            win=title, X=data,
            opts=dict(
                # title=title,
                columnnames=labels[1],
                rownames=labels[0],
                width=width,
                height=height,
                layoutopts={
                    'plotly': {
                        'showscale': False,
                        'showticksuffix': False,
                        'showtickprefix': False,
                        'xaxis': {
                            'side': 'top',
                            'tickangle': -60,
                            # 'autorange': "reversed"
                        },
                        'yaxis': {'autorange': "reversed"},
                    }
                }))
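# Minimal usage sketch for the Visualizer above (assumes a visdom server on
# localhost:8097; the metric values are fabricated for illustration).
vis = Visualizer(env='demo')
for step, (tr, va) in enumerate([(0.9, 1.0), (0.6, 0.8), (0.5, 0.75)], start=1):
    vis.plot_line('loss', 'train', 'Loss', tr, step=step)
    vis.plot_line('loss', 'val', 'Loss', va, step=step)
vis.plot_bar(numpy.array([3, 1, 4]), ['a', 'b', 'c'], 'counts')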
temp = np.zeros((len(Configure.errorParameter), 3))
temp[:, 0] = np.array(errorRateInformationConvex[id - len(Configure.errorParameter):id])[:, 1]
temp[:, 1] = np.array(errorRateInformationGreedy[id - len(Configure.errorParameter):id])[:, 1]
temp[:, 2] = np.array(errorRateInformationLinear[id - len(Configure.errorParameter):id])[:, 1]
viz.bar(X=temp,
        opts=dict(stacked=False,
                  legend=['SLSQP', 'Greedy', 'Ours'],
                  rownames=[str(i) for i in Configure.errorParameter]))
viz.text(str(_dataParameter))

if Configure.mulThread > 0:
    requestList = threadpool.makeRequests(experiment, allArgList)
    [pool.putRequest(req) for req in requestList]
    pool.wait()

# sort results by run id; list.sort() sorts in place, whereas a bare
# sorted() call would discard its return value
if Configure.mulThread > 0:
    errorRateInformationLinear.sort(key=lambda x: x[0])
    errorRateInformationGreedy.sort(key=lambda x: x[0])
    errorRateInformationConvex.sort(key=lambda x: x[0])
    timeLinear.sort(key=lambda x: x[0])
class VisdomController():
    def __init__(self):
        self.vis = Visdom()
        self.connected = self.vis.check_connection()
        self.plots = {}

    def ClearPlots(self):
        self.plots = {}

    def IsConnected(self):
        return self.connected

    def CreateLinePlot(self, x, y, title, xlabel, ylabel, win, key, env="main"):
        self.plots[win] = self.vis.line(X=np.array([x, x]), Y=np.array([y, y]), env=env,
                                        opts=dict(title=title, xlabel=xlabel,
                                                  ylabel=ylabel, win=win,
                                                  legend=[key], showlegend=True))

    def CreateStaticLinePlot(self, x, y, title, xlabel, ylabel, win, key, env="main"):
        self.plots[win] = self.vis.line(X=x, Y=y, env=env,
                                        opts=dict(title=title, xlabel=xlabel,
                                                  ylabel=ylabel, win=win,
                                                  legend=[key], showlegend=True))

    def CreateScatterPlot(self, data, title, xlabel, ylabel, win, env="main"):
        self.plots[win] = self.vis.scatter(data, env=env,
                                           opts=dict(title=title, xlabel=xlabel,
                                                     ylabel=ylabel, win=win))

    def CreateStaticBarPlot(self, x, y, title, xlabel, ylabel, win, env="main"):
        self.plots[win] = self.vis.bar(X=x, Y=y, env=env,
                                       opts=dict(title=title, xlabel=xlabel,
                                                 ylabel=ylabel, win=win))

    def UpdateLinePlot(self, x, y, win, key, env="main"):
        self.vis.line(np.array([y]), X=np.array([x]), env=env,
                      win=self.plots[win], name=key, update="append")

    def UpdateScatterPlot(self, data, win, env="main"):
        self.vis.scatter(data, env=env, win=self.plots[win], update="replace")

    # Custom Plots
    def PlotLoss(self, key, loss):
        if not self.IsConnected():
            return
        plot_win = "loss_window"
        if plot_win not in self.plots:
            self.loss_axis = 0
            self.CreateLinePlot(self.loss_axis, loss, "Loss Graph", "T", "Loss",
                                plot_win, key)
        else:
            self.loss_axis += 1  # advance the time axis for each new point
            self.UpdateLinePlot(self.loss_axis, loss, plot_win, key)

    def PlotFakeFeatureDistributionComparison(self, f_idx_0, f_idx_1, gen_nn,
                                              batch_size, noise_function,
                                              class_to_mimic):
        if not self.IsConnected():
            return
        fake_data_0 = synthesize_data_from_label(
            gen_nn, batch_size, noise_function,
            class_to_mimic).view(batch_size, -1).detach().cpu().numpy()[:, f_idx_0]
        fake_data_1 = synthesize_data_from_label(
            gen_nn, batch_size, noise_function,
            class_to_mimic).view(batch_size, -1).detach().cpu().numpy()[:, f_idx_1]
        data = np.array([fake_data_0, fake_data_1]).T
        plot_win = str(f_idx_0) + str(f_idx_1) + "_fake_comp_window"
        title = "fake features : " + str(f_idx_0) + " vs " + str(f_idx_1)
        env = "feature_comparison"
        if plot_win not in self.plots:
            self.CreateScatterPlot(data, title, str(f_idx_0), str(f_idx_1), plot_win, env)
        else:
            self.UpdateScatterPlot(data, plot_win, env)

    def PlotRealFeatureDistributionComparison(self, f_idx_0, f_idx_1,
                                              real_data, num_samples):
        if not self.IsConnected():
            return
        rows = np.random.choice(np.arange(0, real_data.shape[0]),
                                size=num_samples, replace=False)
        data = real_data.detach().cpu().numpy()[:, [f_idx_0, f_idx_1]][rows, :]
        plot_win = str(f_idx_0) + str(f_idx_1) + "_real_comp_window"
        title = "real features : " + str(f_idx_0) + " vs " + str(f_idx_1)
        env = "feature_comparison"
        if plot_win not in self.plots:
            self.CreateScatterPlot(data, title, str(f_idx_0), str(f_idx_1), plot_win, env)
        else:
            self.UpdateScatterPlot(data, plot_win, env)

    def PlotHeatMap(self, matrix, key, make_lower_triange):
        print(matrix.shape)
        if not self.IsConnected():
            return
        plot_win = key
        if make_lower_triange:
            matrix[np.tril_indices_from(matrix)] = 0
        if plot_win not in self.plots:
            self.plots[plot_win] = self.vis.heatmap(X=matrix)
        else:
            self.vis.heatmap(X=matrix, win=self.plots[plot_win])

    def ShowImages(self, imgs, caption):
        self.vis.images(imgs.unsqueeze(1), opts=dict(nrow=5, caption=caption))
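# Sketch of driving VisdomController from a training loop; the loss values
# here are fabricated for illustration.
controller = VisdomController()
for loss in [2.3, 1.7, 1.2, 0.9]:
    controller.PlotLoss("train", loss)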
class VisdomLogger():
    def __init__(self, **kwargs):
        if Visdom is None:
            self.viz = None  # do nothing
            return
        self.connected = True
        try:
            self.viz = Visdom(raise_exceptions=True, **kwargs)
        except Exception:
            print("Could not reach visdom server...")
            self.connected = False
        self.windows = dict()
        r = np.random.RandomState(1)
        self.colors = r.randint(0, 255, size=(255, 3))
        self.colors[0] = np.array([1., 1., 1.])
        self.colors[1] = np.array([0., 0.18431373, 0.65490196])  # IKB blue

    def update(self, data):
        if self.connected:
            self.plot_epochs(data)

    def bar(self, X, name="barplot"):
        if self.connected:
            X[np.isnan(X)] = 0
            win = name.replace(" ", "_")
            opts = dict(title=name, xlabel='t', ylabel="P(t)", width=600,
                        height=200, marginleft=20, marginright=20,
                        marginbottom=20, margintop=30)
            self.viz.bar(X, win=win, opts=opts)

    def plot(self, X, name="plot", **kwargs):
        if self.connected:
            X[np.isnan(X)] = 0
            win = "pl_" + name.replace(" ", "_")
            opts = dict(title=name, xlabel='t', ylabel="P(t)", width=600,
                        height=200, marginleft=20, marginright=20,
                        marginbottom=20, margintop=30, **kwargs)
            self.viz.line(X, win=win, opts=opts)

    def confusion_matrix(self, cm, title="Confusion Matrix", norm=None):
        if self.connected:
            plt.clf()
            if norm is not None:
                cm /= np.expand_dims(cm.sum(norm), axis=norm)
                cm[np.isnan(cm)] = 0
                cm[np.isinf(cm)] = 0
                vmin, vmax = 0, 1
            else:
                vmin, vmax = None, None
            name = title
            plt.rcParams['figure.figsize'] = (9, 9)
            ax = sn.heatmap(cm, annot=True, annot_kws={"size": 11},
                            vmin=vmin, vmax=vmax)
            ax.set(xlabel='ground truth', ylabel='predicted', title=title)
            plt.tight_layout()
            self.viz.matplot(plt, win=name, opts=dict(resizeable=True))

    def plot_class_p(self, X):
        # plots the class probabilities of the first sample over time; the
        # original body contained stray confusion-matrix code referencing an
        # undefined variable, which has been dropped
        if self.connected:
            plt.clf()
            x = X.detach().cpu().numpy()
            plt.plot(x[0, :])
            name = "class probabilities"
            plt.rcParams['figure.figsize'] = (6, 6)
            plt.tight_layout()
            self.viz.matplot(plt, win=name, opts=dict(resizeable=True))

    def plot_boxplot(self, labels, t_stops, tmin=None, tmax=None):
        if self.connected:
            grouped = [t_stops[labels == i] for i in np.unique(labels)]
            plt.clf()
            name = "boxplot"
            plt.rcParams['figure.figsize'] = (9, 9)
            ax = sn.boxplot(data=grouped, orient="h")
            ax.set_xlabel("t_stop")
            ax.set_ylabel("class")
            ax.set_xlim(tmin, tmax)
            plt.tight_layout()
            self.viz.matplot(plt, win=name, opts=dict(resizeable=True))

    def plot_epochs(self, data):
        """Plots the per-epoch mean of each logged metric."""
        if self.connected:
            data_mean_per_epoch = data.groupby(["mode", "epoch"]).mean()
            cols = data_mean_per_epoch.columns
            modes = data_mean_per_epoch.index.levels[0]
            for name in cols:
                if name in self.windows.keys():
                    win = self.windows[name]
                    update = 'new'
                else:
                    win = name  # first log -> new window
                    update = None
                opts = dict(title=name, showlegend=True, xlabel='epochs', ylabel=name)
                for mode in modes:
                    epochs = data_mean_per_epoch[name].loc[mode].index
                    values = data_mean_per_epoch[name].loc[mode]
                    win = self.viz.line(X=epochs, Y=values, name=mode, win=win,
                                        opts=opts, update=update)
                    update = 'insert'
                self.windows[name] = win
class SingleCkptAnalysis():
    def __init__(self, ckpt_dir, analysis_dir, check_frequency=120,
                 delete_old_analyses=True, ckpt_freq=4000, weight_gif=True,
                 recon_gif=True, model_layer_name='S1', probe_file_dir=None):
        self.vis = Visdom()
        self.check_frequency = check_frequency
        self.ckpt_dir = ckpt_dir
        self.analysis_dir = analysis_dir
        self.latest_analysis = self.get_latest_analyses(self.analysis_dir)
        self.delete_old_analyses = delete_old_analyses
        self.ckpt_freq = ckpt_freq
        self.weight_gif = weight_gif
        self.recon_gif = recon_gif
        self.model_layer_name = model_layer_name
        self.probe_file_dir = probe_file_dir
        print('[INFO] CHECKPOINT DIR: {}'.format(self.ckpt_dir))
        print('[INFO] ANALYSIS DIR: {}'.format(self.analysis_dir))
        print('[INFO] LATEST EXISTING ANALYSIS: {}'.format(self.latest_analysis))

    def get_latest_analyses(self, dir):
        if not os.path.isdir(dir):
            os.mkdir(dir)
        existing = get_sorted_files(dir, 'analysis*', add_parent=False)
        return existing[-1] if any(existing) else None

    def get_current_ckpt(self):
        current_ckpts = get_sorted_files(self.ckpt_dir, add_parent=False)
        latest_ckpt = current_ckpts[-1]
        latest_ckpt_path = os.path.join(self.ckpt_dir, latest_ckpt)
        latest_ckpt_num = latest_ckpt.split('Checkpoint')[-1]
        return latest_ckpt, latest_ckpt_path, latest_ckpt_num

    def montage_weights(self, ckpt_dir, save_dir, sorted_indices):
        weight_filenames = get_sorted_files(ckpt_dir, keyword='*_W.pvp',
                                            add_parent=True)
        save_dir = os.path.join(save_dir, 'Weights')
        os.mkdir(save_dir)
        gif = [] if self.weight_gif else None
        for i_filename, weight_filename in enumerate(weight_filenames):
            data = pv.readpvpfile(weight_filename)
            weights = data['values']
            weights = weights[0, 0, :, :, :, 0]
            f, h, w = weights.shape
            weights = weights[sorted_indices, ...]
            gridh = h * int(np.ceil(np.sqrt(f)))
            gridw = w * int(np.ceil(np.sqrt(f)))
            grid = np.zeros([gridh, gridw])
            count = 0
            for i_h in range(0, gridh, h):
                for i_w in range(0, gridw, w):
                    if count < f:
                        grid[i_h:i_h + h, i_w:i_w + w] = bytescale_patch_np(weights[count, ...])
                        count += 1
            grid[::h, :] = 255.
            grid[:, ::w] = 255.
            if not self.weight_gif:
                fig_name = os.path.split(weight_filename)[1][:-4]
                imwrite(os.path.join(save_dir, fig_name + '.png'), np.uint8(grid))
            else:
                gif.append(np.uint8(grid))
        if self.weight_gif:
            mimsave(os.path.join(save_dir, 'weights.gif'), gif, fps=5)

    def plot_recs(self, ckpt_dir, save_dir):
        rec_paths = get_sorted_files(ckpt_dir, keyword='Frame*Recon_A.pvp',
                                     add_parent=True)
        if rec_paths == []:
            return
        save_dir = os.path.join(save_dir, 'Recons')
        gifs = {} if self.recon_gif else None
        if not os.path.isdir(save_dir):
            os.mkdir(save_dir)
        for i_rec_path, rec_path in enumerate(rec_paths):
            i_input_frame = int(''.join(filter(str.isdigit, os.path.split(rec_path)[1])))
            input_filename = 'Frame{}_A.pvp'.format(i_input_frame)
            input_path = os.path.join(ckpt_dir, input_filename)
            input_batch = pv.readpvpfile(input_path)['values']
            rec_batch = pv.readpvpfile(rec_path)['values']
            n = input_batch.shape[0]
            frame_save_dir = (os.path.join(save_dir, 'Frame{}'.format(i_input_frame))
                              if self.recon_gif else save_dir)
            if not os.path.isdir(frame_save_dir) and not self.recon_gif:
                os.mkdir(frame_save_dir)
            for i_example, (input_ex, rec_ex) in enumerate(zip(input_batch, rec_batch)):
                if self.recon_gif and i_example not in gifs:
                    gifs[i_example] = []
                input_ex, rec_ex = input_ex[..., 0], rec_ex[..., 0]
                input_scaled = bytescale_patch_np(input_ex)
                rec_scaled = bytescale_patch_np(rec_ex)
                if (np.sum(rec_scaled) == 0 and
                        int(''.join([c for c in self.latest_analysis if c.isdigit()])) != 0):
                    print('[WARNING] BATCH {} EXPLODED'.format(os.path.split(rec_path)[1]))
                divider = np.zeros([input_scaled.shape[0],
                                    int(input_scaled.shape[1] * 0.05)])
                pane = np.uint8(np.concatenate((input_scaled, divider, rec_scaled), 1))
                if not self.recon_gif:
                    imwrite(os.path.join(frame_save_dir,
                                         'Example{}Input.png'.format(i_example)), pane)
                else:
                    gifs[i_example].append(pane)
        if self.recon_gif:
            for k in gifs:
                mimsave(os.path.join(save_dir, 'Recon_{}.gif'.format(k)),
                        gifs[k], fps=15)

    def plot_fraction_active(self, ckpt_dir, save_dir):
        filename = os.path.join(ckpt_dir, self.model_layer_name + '_A.pvp')
        active_sorted, feat_indices_sorted = get_fraction_active(filename)
        opts = dict(xlabel='Feature Number', ylabel='Fraction Active',
                    title='Activations' + self.latest_analysis.split('-')[1])
        self.vis.bar(active_sorted, win='frac_act_total', opts=opts)
        return feat_indices_sorted, active_sorted

    def plot_energy(self, save_dir):
        if not self.probe_file_dir:
            return
        files = get_sorted_files(self.probe_file_dir, keyword='EnergyProbe_*',
                                 add_parent=True)
        if files == []:
            return
        save_dir = os.path.join(save_dir, 'Energy')
        if not os.path.isdir(save_dir):
            os.mkdir(save_dir)
        for i_file, file in enumerate(files):
            name = os.path.split(file)[1].split('.')[0] + '.png'
            data = np.genfromtxt(file, delimiter=',', skip_header=1)
            end_x = [t for t in data[:, 0] if t % self.ckpt_freq == 0]
            end_x = int(max(end_x)) if end_x != [] else 0
            start_x = end_x - self.ckpt_freq if end_x != 0 else 0
            end_x = data.shape[0] if end_x == 0 else end_x
            data = data[start_x:end_x, :]
            fig = plt.figure(figsize=(20, 15))
            subplot = fig.add_subplot(111)
            subplot.set_ylabel('Energy')
            subplot.set_xlabel('Timestep')
            subplot.plot(data[:, 0], data[:, -1])
            plt.savefig(os.path.join(save_dir, name))
            plt.close()

    def plot_adaptivetimescales(self, save_dir):
        if not self.probe_file_dir:
            return
        files = get_sorted_files(self.probe_file_dir,
                                 keyword='AdaptiveTimeScales*', add_parent=True)
        if files == []:
            return
        save_dir = os.path.join(save_dir, 'AdaptiveTimeScales')
        if not os.path.isdir(save_dir):
            os.mkdir(save_dir)
        for i_file, file in enumerate(files):
            name = os.path.split(file)[1].split('.')[0]
            with open(file, 'r+') as txtfile:
                reader = csv.reader(txtfile, delimiter=',')
                timescale_data = {'Timescale': [], 'TimescaleTrue': [],
                                  'TimescaleMax': [], 'Time': []}
                for i_row, row in enumerate(reader):
                    if len(row) == 1:
                        timescale_data['Time'].append(float(row[0].split(' = ')[-1]))
                    else:
                        timescale_data['Timescale'].append(float(row[1].split(' = ')[-1]))
                        timescale_data['TimescaleTrue'].append(float(row[2].split(' = ')[-1]))
                        timescale_data['TimescaleMax'].append(float(row[3].split(' = ')[-1]))
            end_x = [t for t in timescale_data['Time'] if t % self.ckpt_freq == 0]
            end_x = int(max(end_x)) if end_x != [] else 0
            start_x = end_x - self.ckpt_freq if end_x != 0 else 0
            end_x = len(timescale_data['Time']) if end_x == 0 else end_x
            for key in list(timescale_data.keys())[:-1]:
                fig = plt.figure(figsize=(20, 15))
                subplot = fig.add_subplot(111)
                subplot.set_ylabel(key)
                subplot.set_xlabel('Time')
                subplot.plot(timescale_data['Time'][start_x:end_x],
                             timescale_data[key][start_x:end_x])
                plt.savefig(os.path.join(save_dir, name + '_' + key + '.png'))
                plt.close()

    def analyze(self):
        while True:
            _, current_ckpt_dir, current_ckpt_num = self.get_current_ckpt()
            if 'analysis-' + current_ckpt_num != self.latest_analysis:
                year, month, day, hour, min, sec = get_current_time()
                print('[INFO] FOUND A NEW CHECKPOINT: {} ({}/{}/{} {}:{}:{} EST)'
                      .format(current_ckpt_num, month, day, year, hour, min, sec))
                self.latest_analysis = 'analysis-' + current_ckpt_num
                save_dir = os.path.join(self.analysis_dir, self.latest_analysis)
                os.mkdir(save_dir)
                act_fname = os.path.join(current_ckpt_dir,
                                         self.model_layer_name + '_A.pvp')
                sorted_feat_indices, sorted_activations = \
                    self.plot_fraction_active(current_ckpt_dir, save_dir)
                self.montage_weights(current_ckpt_dir, save_dir, sorted_feat_indices)
                self.plot_recs(current_ckpt_dir, save_dir)
                self.plot_energy(save_dir)
                self.plot_adaptivetimescales(save_dir)
                print('[INFO] ANALYSIS {} WRITE COMPLETE.'.format(current_ckpt_num))
                if (self.delete_old_analyses and
                        len(glob(os.path.join(self.analysis_dir, 'analysis-*'))) > 1):
                    print('[INFO] REMOVING OLD ANALYSIS FILES.')
                    [rmtree(f) for f in glob(os.path.join(self.analysis_dir, 'analysis-*'))
                     if f != save_dir]
            sleep(self.check_frequency)
class VisdomLogger(Logger):
    """Logs attack results to Visdom."""

    def __init__(self, env="main", port=8097, hostname="localhost"):
        if not port_is_open(port, hostname=hostname):
            raise socket.error(f"Visdom not running on {hostname}:{port}")
        self.vis = Visdom(port=port, server=hostname, env=env)
        self.env = env
        self.port = port
        self.hostname = hostname
        self.windows = {}
        self.sample_rows = []

    def __getstate__(self):
        # Drop the live Visdom connection so the logger can be pickled.
        state = {i: self.__dict__[i] for i in self.__dict__ if i != "vis"}
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        self.vis = Visdom(port=self.port, server=self.hostname, env=self.env)

    def log_attack_result(self, result):
        text_a, text_b = result.diff_color(color_method="html")
        result_str = result.goal_function_result_str(color_method="html")
        self.sample_rows.append([result_str, text_a, text_b])

    def log_summary_rows(self, rows, title, window_id):
        self.table(rows, title=title, window_id=window_id)

    def flush(self):
        self.table(self.sample_rows, title="Sample-Level Results",
                   window_id="sample_level_results")

    def log_hist(self, arr, numbins, title, window_id):
        self.bar(arr, numbins=numbins, title=title, window_id=window_id)

    def text(self, text_data, title=None, window_id="default"):
        if window_id and window_id in self.windows:
            window = self.windows[window_id]
            self.vis.text(text_data, win=window)
        else:
            new_window = self.vis.text(text_data, opts=dict(title=title))
            self.windows[window_id] = new_window

    def table(self, rows, window_id=None, title=None, header=None, style=None):
        """Generates an HTML table."""
        if not window_id:
            window_id = title  # Can provide either of these,
        if not title:
            title = window_id  # or both.
        table = html_table_from_rows(rows, title=title, header=header,
                                     style_dict=style)
        self.text(table, title=title, window_id=window_id)

    def bar(self, X_data, numbins=10, title=None, window_id=None):
        if window_id and window_id in self.windows:
            window = self.windows[window_id]
            self.vis.bar(X=X_data, win=window,
                         opts=dict(title=title, numbins=numbins))
        else:
            new_window = self.vis.bar(X=X_data,
                                      opts=dict(title=title, numbins=numbins))
            if window_id:
                self.windows[window_id] = new_window

    def hist(self, X_data, numbins=10, title=None, window_id=None):
        if window_id and window_id in self.windows:
            window = self.windows[window_id]
            self.vis.histogram(X=X_data, win=window,
                               opts=dict(title=title, numbins=numbins))
        else:
            new_window = self.vis.histogram(X=X_data,
                                            opts=dict(title=title, numbins=numbins))
            if window_id:
                self.windows[window_id] = new_window
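# Hedged sketch of standalone use; in practice the attack framework calls the
# logging hooks itself, so only the generic text/bar helpers are shown here
# (assumes numpy imported as np and a visdom server on localhost:8097).
logger = VisdomLogger(env="attacks", port=8097, hostname="localhost")
logger.text("attack started", title="status", window_id="status")
logger.bar(np.array([5, 9, 2]), title="queries per result", window_id="queries")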
from visdom import Visdom
import numpy as np

vis = Visdom()

# single bar plot
vis.bar(X=np.random.rand(20))

# stacked bar plot
vis.bar(
    X=np.abs(np.random.rand(5, 3)),
    opts=dict(
        title='Stacked bar plot',
        stacked=True,
        legend=['Sina', '163', 'AliBaBa'],
        rownames=['2013', '2014', '2015', '2016', '2017']
    )
)

# grouped bar plot
vis.bar(
    X=np.random.rand(20, 3),
    opts=dict(
        title='Grouped bar plot',
        stacked=False,
        legend=['A', 'B', 'C']
    )
)
        markersize=10,
        markercolor=np.random.randint(0, 255, (255, 3,)),
    ),
)

# add a new trace to the scatter plot (updateTrace is the old visdom API)
viz.updateTrace(
    X=np.random.rand(255),
    Y=np.random.rand(255),
    win=win,
    name='new_trace',
)

# bar plots
viz.bar(X=np.random.rand(20))
viz.bar(
    X=np.abs(np.random.rand(5, 3)),
    opts=dict(
        stacked=True,
        legend=['Facebook', 'Google', 'Twitter'],
        rownames=['2012', '2013', '2014', '2015', '2016']
    )
)
viz.bar(
    X=np.random.rand(20, 3),
    opts=dict(
        stacked=False,
        legend=['The Netherlands', 'France', 'United States']
    )
)
optimizer = torch.optim.Adam(model.parameters(), lr=settings.LEARNING_RATE)

# Log file is namespaced with the current model
log_file = "logs/{}_{}.csv".format(
    model.get_name(),
    settings.args.data_path.split("/")[-1].split(".json")[0])

if settings.VISUALIZE:
    # Visualization through visdom
    viz = Visdom()
    loss_plot = viz.line(X=np.array([0]), Y=np.array([0]),
                         opts=dict(showlegend=True, title="Loss"))
    hist_opts = settings.HIST_OPTS
    hist_opts["title"] = "Predicted star distribution"
    dist_hist = viz.bar(X=np.array([0, 0, 0]),
                        opts=dict(title="Predicted stars"))
    real_dist_hist = viz.bar(X=np.array([0, 0, 0]))

# Move stuff to GPU
if settings.GPU:
    data_loader.pin_memory = True
    model.cuda()

if settings.VISUALIZE:
    smooth_loss = 7  # approx 2.5^2
    decay_rate = 0.99
    smooth_real_dist = np.array([0, 0, 0, 0, 0], dtype=float)
    smooth_pred_dist = np.array([0, 0, 0, 0, 0], dtype=float)
    counter = 0
class Visualizer(object):
    def __init__(self, config: Config):
        # logging_level = logging._checkLevel("INFO")
        # logging.getLogger().setLevel(logging_level)
        # VisdomServer.start_server(port=VisdomServer.DEFAULT_PORT, env_path=config.vis_env_path)
        self.reinit(config)

    def reinit(self, config):
        self.config = config
        try:
            self.visdom = Visdom(env=config.visdom_env)
            self.connected = self.visdom.check_connection()
            if not self.connected:
                print("Visdom server hasn't started, please run command "
                      "'python -m visdom.server' in terminal.")
        except ConnectionError as e:
            warn("Can't open Visdom because " + e.strerror)
        with open(self.config.log_file, 'a') as f:
            info = "[{time}]Initialize Visdom\n".format(time=timestr('%m-%d %H:%M:%S'))
            info += str(self.config)
            f.write(info + '\n')

    def save(self, save_path: str = None) -> str:
        # visdom.save returns the current environment's name as JSON
        retstr = self.visdom.save([self.config.visdom_env])
        try:
            ret = json.loads(retstr)[0]
            if ret == self.config.visdom_env:
                if isinstance(save_path, str):
                    from shutil import copy
                    copy(self.config.vis_env_path, save_path)
                    print('Visdom Environment has saved into ' + save_path)
                else:
                    print('Visdom Environment has saved into ' + self.config.vis_env_path)
                with open(self.config.vis_env_path, 'r') as fp:
                    env_str = json.load(fp)
                return env_str
        except Exception as e:
            warn(e)
        return None

    def clear(self):
        self.visdom.close()

    @staticmethod
    def _to_numpy(value):
        if isinstance(value, t.Tensor):
            value = value.cpu().detach().numpy()
        elif isinstance(value, np.ndarray):
            pass
        else:
            value = np.array(value)
        if value.ndim == 0:
            value = value[np.newaxis]
        return value

    def plot(self, y, x, line_name, win, legend=None):
        # type:(float,float,str,str,list)->bool
        """Plot a (sequence of) y point(s) against x value(s); loop this
        method to draw the whole curve."""
        update = None if not self.visdom.win_exists(win) else 'append'
        opts = dict(title=win)
        if legend is not None:
            opts["legend"] = legend
        y = Visualizer._to_numpy(y)
        x = Visualizer._to_numpy(x)
        return win == self.visdom.line(y, x, win=win, env=self.config.visdom_env,
                                       update=update, name=line_name, opts=opts)

    def bar(self, y, win, rowindices=None):
        opts = dict(title=win)
        y = Visualizer._to_numpy(y)
        if isinstance(rowindices, list) and len(rowindices) == len(y):
            opts["rownames"] = rowindices
        return win == self.visdom.bar(y, win=win, env=self.config.visdom_env,
                                      opts=opts)

    def log(self, msg, name, append=True, log_file=None):
        # type:(Visualizer,str,str,bool,str)->bool
        if log_file is None:
            log_file = self.config.log_file
        info = "[{time}]{msg}".format(time=timestr('%m-%d %H:%M:%S'), msg=msg)
        append = append and self.visdom.win_exists(name)
        ret = self.visdom.text(info, win=name, env=self.config.visdom_env,
                               opts=dict(title=name), append=append)
        mode = 'a+' if append else 'w+'
        with open(log_file, mode) as f:
            f.write(info + '\n')
        return ret == name

    def log_process(self, num, total, msg, name, append=True):
        # type:(Visualizer,int,int,str,str,bool)->bool
        info = "[{time}]{msg}".format(time=timestr('%m-%d %H:%M:%S'), msg=msg)
        append = append and self.visdom.win_exists(name)
        ret = self.visdom.text(info, win=name, env=self.config.visdom_env,
                               opts=dict(title=name), append=append)
        with open(self.config.log_file, 'a') as f:
            f.write(info + '\n')
        self.process_bar(num, total, msg)
        return ret == name

    def process_bar(self, num, total, msg='', length=50):
        rate = num / total
        rate_num = int(rate * 100)
        clth = int(rate * length)
        if len(msg) > 0:
            msg += ':'
        if rate_num == 100:
            r = '\r%s[%s%d%%]\n' % (msg, '*' * length, rate_num)
        else:
            r = '\r%s[%s%s%d%%]' % (msg, '*' * clth, '-' * (length - clth), rate_num)
        sys.stdout.write(r)
        sys.stdout.flush()
        return r.replace('\r', ':')
            opts={
                'title': 'New Scatter',
                'legend': ['Apple', 'Banana'],
                'markersymbol': 'dot'
            })

# 3D scatter plot
viz.scatter(X=np.random.rand(100, 3),
            Y=Y,
            opts={
                'title': '3D Scatter',
                'legend': ['Men', 'Women'],
                'markersize': 5
            })

# Bar plots
viz.bar(X=np.random.rand(20))
viz.bar(
    X=np.abs(np.random.rand(5, 3)),  # 5 bars, each made of 3 stacked parts
    opts={
        'stacked': True,
        'legend': ['A', 'B', 'C'],
        'rownames': ['2012', '2013', '2014', '2015', '2016']
    })
viz.bar(X=np.random.rand(20, 3),
        opts={
            'stacked': False,
            'legend': ['America', 'British', 'China']
        })

# Heatmaps, geographic maps, surface plots
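# --- Sketch of the truncated section (an assumption, not original code) ---
# The snippet above breaks off right after the heatmap/surface comment; the
# two calls below use only documented visdom methods (viz.heatmap, viz.surf)
# with illustrative data and labels.
viz.heatmap(X=np.random.rand(10, 5),
            opts={'columnnames': ['c1', 'c2', 'c3', 'c4', 'c5'],
                  'rownames': ['r%d' % i for i in range(1, 11)],
                  'colormap': 'Viridis'})
x = np.tile(np.arange(1, 101), (100, 1))
y = x.transpose()
X = np.exp((((x - 50) ** 2) + ((y - 50) ** 2)) / -(20.0 ** 2))
viz.surf(X=X, opts={'colormap': 'Hot'})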
class Visualizer:
    def __init__(self, env="main", server="http://localhost", port=8097,
                 base_url="/", http_proxy_host=None, http_proxy_port=None):
        self._viz = Visdom(env=env, server=server, port=port,
                           base_url=base_url,
                           http_proxy_host=http_proxy_host,
                           http_proxy_port=http_proxy_port,
                           use_incoming_socket=False)
        self._viz.close(env=env)

    def plot_line(self, values, steps, name, legend=None):
        if legend is None:
            opts = dict(title=name)
        else:
            opts = dict(title=name, legend=legend)
        self._viz.line(X=numpy.column_stack(steps),
                       Y=numpy.column_stack(values),
                       win=name, update='append', opts=opts)

    def plot_text(self, text, title, pre=True):
        _width = max(len(x) for x in text.split("\n")) * 10
        _height = len(text.split("\n")) * 20
        _height = max(_height, 120)
        if pre:
            text = "<pre>{}</pre>".format(text)
        self._viz.text(text, win=title,
                       opts=dict(title=title,
                                 width=min(_width, 400),
                                 height=min(_height, 400)))

    def plot_bar(self, data, labels, title):
        self._viz.bar(win=title, X=data,
                      opts=dict(legend=labels, stacked=False, title=title))

    def plot_scatter(self, data, labels, title):
        X = numpy.concatenate(data, axis=0)
        Y = numpy.concatenate(
            [numpy.full(len(d), i) for i, d in enumerate(data, 1)], axis=0)
        self._viz.scatter(win=title, X=X, Y=Y,
                          opts=dict(legend=labels, title=title, markersize=5,
                                    webgl=True, width=400, height=400,
                                    markeropacity=0.5))

    def plot_heatmap(self, data, labels, title):
        self._viz.heatmap(
            win=title,
            X=data,
            opts=dict(
                title=title,
                columnnames=labels[1],
                rownames=labels[0],
                width=700,
                height=700,
                layoutopts={
                    'plotly': {
                        'xaxis': {
                            'side': 'top',
                            'tickangle': -60,
                            # 'autorange': "reversed"
                        },
                        'yaxis': {
                            'autorange': "reversed"
                        },
                    }
                }))
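# --- Usage sketch (not from the original source) --------------------------
# Illustrative calls against the wrapper above; window names, legends and
# data are placeholders, and a visdom server is assumed to be reachable on
# localhost:8097.
viz = Visualizer(env='demo')
for step in range(1, 6):
    viz.plot_line(values=[1.0 / step, 0.5 / step], steps=[step, step],
                  name='loss', legend=['train', 'val'])
viz.plot_bar(data=numpy.random.rand(5, 2), labels=['train', 'val'],
             title='accuracy per fold')
viz.plot_heatmap(data=numpy.random.rand(4, 6),
                 labels=(['r%d' % i for i in range(4)],
                         ['c%d' % i for i in range(6)]),
                 title='attention')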
class VisdomLogger(Logger):
    def __init__(self, env='main', port=8097, hostname='localhost'):
        if not port_is_open(port, hostname=hostname):
            raise socket.error(f'Visdom not running on {hostname}:{port}')
        self.vis = Visdom(port=port, server=hostname, env=env)
        self.windows = {}
        self.sample_rows = []

    def log_attack_result(self, result):
        text_a, text_b = result.diff_color(color_method='html')
        result_str = result.goal_function_result_str(color_method='html')
        self.sample_rows.append([result_str, text_a, text_b])

    def log_summary_rows(self, rows, title, window_id):
        self.table(rows, title=title, window_id=window_id)

    def flush(self):
        self.table(self.sample_rows, title='Sample-Level Results',
                   window_id='sample_level_results')

    def log_hist(self, arr, numbins, title, window_id):
        self.bar(arr, numbins=numbins, title=title, window_id=window_id)

    def text(self, text_data, title=None, window_id='default'):
        if window_id and window_id in self.windows:
            window = self.windows[window_id]
            self.vis.text(text_data, win=window)
        else:
            new_window = self.vis.text(text_data, opts=dict(title=title))
            self.windows[window_id] = new_window

    def table(self, rows, window_id=None, title=None, header=None,
              style=None):
        """Generates an HTML table."""
        if not window_id:
            window_id = title  # Can provide either of these,
        if not title:
            title = window_id  # or both.
        table = html_table_from_rows(rows, title=title, header=header,
                                     style_dict=style)
        self.text(table, title=title, window_id=window_id)

    def bar(self, X_data, numbins=10, title=None, window_id=None):
        if window_id and window_id in self.windows:
            window = self.windows[window_id]
            self.vis.bar(X=X_data, win=window,
                         opts=dict(title=title, numbins=numbins))
        else:
            new_window = self.vis.bar(X=X_data,
                                      opts=dict(title=title,
                                                numbins=numbins))
            if window_id:
                self.windows[window_id] = new_window

    def hist(self, X_data, numbins=10, title=None, window_id=None):
        if window_id and window_id in self.windows:
            window = self.windows[window_id]
            self.vis.histogram(X=X_data, win=window,
                               opts=dict(title=title, numbins=numbins))
        else:
            new_window = self.vis.histogram(X=X_data,
                                            opts=dict(title=title,
                                                      numbins=numbins))
            if window_id:
                self.windows[window_id] = new_window
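# --- Usage sketch (not from the original source) --------------------------
# Hypothetical wiring of the VisdomLogger above; it presumes the helpers it
# depends on (Logger, port_is_open, html_table_from_rows) are importable and
# that a visdom server is already listening on localhost:8097. Rows and data
# are illustrative placeholders.
logger = VisdomLogger(env='attacks')
logger.log_summary_rows([['accuracy under attack', '61.2%'],
                         ['avg. queries', '1543']],
                        title='Attack Summary', window_id='summary')
logger.log_hist(np.random.randn(1000), numbins=20,
                title='Perturbation sizes', window_id='perturbation_hist')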
         win=win,
         name='new_trace',
         update='new')

# 2D scatter plot with text labels:
viz.scatter(X=np.random.rand(10, 2),
            opts=dict(textlabels=['Label %d' % (i + 1) for i in range(10)]))
viz.scatter(X=np.random.rand(10, 2),
            Y=[1] * 5 + [2] * 3 + [3] * 2,
            opts=dict(legend=['A', 'B', 'C'],
                      textlabels=['Label %d' % (i + 1) for i in range(10)]))

# bar plots
viz.bar(X=np.random.rand(20))
viz.bar(X=np.abs(np.random.rand(5, 3)),
        opts=dict(stacked=True,
                  legend=['Facebook', 'Google', 'Twitter'],
                  rownames=['2012', '2013', '2014', '2015', '2016']))
viz.bar(X=np.random.rand(20, 3),
        opts=dict(stacked=False,
                  legend=['The Netherlands', 'France', 'United States']))

# histogram
viz.histogram(X=np.random.rand(10000), opts=dict(numbins=20))

# heatmap
viz.heatmap(
    X=np.outer(np.arange(1, 6), np.arange(1, 11)),
    opts=dict(
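# --- Sketch: visdom line-update modes (not from the original source) ------
# For context on the dangling update='new' call above: 'append' extends an
# existing trace, while 'new' starts an additional trace in the same window.
# Window and trace names here are illustrative.
win = viz.line(X=np.arange(10), Y=np.random.rand(10),
               opts=dict(title='update modes'))
viz.line(X=np.arange(10, 20), Y=np.random.rand(10),
         win=win, update='append')                   # extend the first trace
viz.line(X=np.arange(10), Y=np.random.rand(10) + 1,
         win=win, name='new_trace', update='new')    # add a second trace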
class VisdomImageHandler(logging.Handler):
    """
    Logging handler to show images and metric plots with visdom

    .. deprecated:: 0.1
        :class:`VisdomImageHandler` will be removed in the next release and
        is deprecated in favor of the ``trixi.logging`` modules

    .. warning::
        :class:`VisdomImageHandler` will be removed in the next release

    See Also
    --------
    `Visdom`
    :class:`TrixiHandler`
    """

    def __init__(self, port, prefix, log_freq_train, log_freq_val=1e10,
                 level=logging.NOTSET, log_freq_img=1, **kwargs):
        """
        Parameters
        ----------
        port : int
            port of the visdom server
        prefix : str
            prefix of environment names
        log_freq_train : int
            logging frequency for scores in train mode
        log_freq_val : int
            logging frequency for scores in validation mode
        level : int (default: logging.NOTSET)
            logging level
        log_freq_img : int
            logging frequency for images
        **kwargs :
            additional keyword arguments which are passed directly to visdom
        """
        super().__init__(level=level)
        self.viz = Visdom(port=port, env=prefix, **kwargs)
        self.env_prefix = prefix
        self.log_freq_train = log_freq_train
        self.log_freq_val = log_freq_val
        self.curr_batch_train = 1
        self.curr_batch_val = 1
        self.curr_epoch_train = 1
        self.curr_epoch_val = 1
        self.metrics = {}
        self.val_metrics = {}
        self.plot_windows = {}
        self.image_windows = {}
        self.heatmap_windows = {}
        self.text_windows = {}
        self.bar_windows = {}
        self.curr_env_name = prefix
        self.curr_fold = None
        self.img_count = 0
        self.log_freq_img = log_freq_img

    def emit(self, record):
        """
        Shows images and metric plots in visdom

        Parameters
        ----------
        record : LogRecord
            entities to log

        Returns
        -------
        None
            * if no connection to visdom could be found
            * if ``record.msg`` is not a dict
        """
        # messages that can't be sent fill (GPU-)RAM, so return if there is
        # no connection
        if not self.viz.check_connection():
            return
        if not isinstance(record.msg, dict):
            return

        scores = record.msg.get("scores", {})
        images = record.msg.get("images", {})
        heatmaps = record.msg.get("heatmaps", {})
        scalars = record.msg.get("scalars", {})
        bars = record.msg.get("bars", {})
        fold = record.msg.get("fold", "")
        text = record.msg.get("text", {})
        plots = record.msg.get("plots", {})

        if fold != self.curr_fold:
            self.curr_batch_train = 1
            self.curr_batch_val = 1
            self.curr_epoch_train = 1
            self.curr_epoch_val = 1
            self.metrics = {}
            self.val_metrics = {}
            self.plot_windows = {}
            self.image_windows = {}
            self.heatmap_windows = {}
            self.text_windows = {}
            self.bar_windows = {}
            if not fold and isinstance(fold, str):
                fold_name = self.env_prefix
            else:
                fold_name = self.env_prefix + "_fold_%02d_%s" % (fold, now())
        else:
            fold_name = self.curr_env_name
        self.curr_fold = fold
        self.curr_env_name = fold_name

        # log losses and metrics
        for metric_name in scores.keys():
            # handle validation scores
            if metric_name.startswith("val_"):
                metric_name = metric_name.split("_", maxsplit=1)[-1]
                if metric_name not in self.val_metrics:
                    self.val_metrics[metric_name] = self._to_scalar(
                        scores["val_" + metric_name])
                else:
                    self.val_metrics[metric_name] += self._to_scalar(
                        scores["val_" + metric_name])
            # handle train scores
            else:
                if metric_name not in self.metrics:
                    self.metrics[metric_name] = self._to_scalar(
                        scores[metric_name])
                else:
                    self.metrics[metric_name] += self._to_scalar(
                        scores[metric_name])

        # draw images
        self.img_count += 1
        if (self.img_count % self.log_freq_img) == 0:
            for image_name, tensor in images.items():
                if image_name not in self.image_windows:
                    self.image_windows[image_name] = self.viz.image(
                        self._to_image(tensor),
                        opts={'title': image_name},
                        env=fold_name)
                else:
                    self.viz.image(self._to_image(tensor.data),
                                   win=self.image_windows[image_name],
                                   opts={'title': image_name},
                                   env=fold_name)
            self.img_count = 0

        # draw heatmaps
        for heatmap_name, tensor in heatmaps.items():
            heatmap = tensor[0].cpu().numpy()
            if heatmap_name not in self.heatmap_windows:
                self.heatmap_windows[heatmap_name] = self.viz.heatmap(
                    heatmap,
                    opts=dict(title=heatmap_name, colormap='hot'),
                    env=fold_name)
            else:
                self.viz.heatmap(heatmap,
                                 win=self.heatmap_windows[heatmap_name],
                                 opts=dict(title=heatmap_name,
                                           colormap='hot'),
                                 env=fold_name)

        # visualize scalars
        for scalar_name, scalar_val in scalars.items():
            text_str = "<font face = 'Arial' size = '4'>%s</font>" % \
                str(self._to_scalar(scalar_val))
            if scalar_name not in self.text_windows:
                self.text_windows[scalar_name] = self.viz.text(
                    text_str, env=fold_name)
            else:
                self.viz.text(text_str,
                              win=self.text_windows[scalar_name],
                              env=fold_name)

        # draw bars
        for bar_name, bar_vals in bars.items():
            if bar_name not in self.bar_windows:
                self.bar_windows[bar_name] = self.viz.bar(
                    bar_vals, opts={"title": bar_name}, env=fold_name)
            else:
                self.viz.bar(bar_vals,
                             win=self.bar_windows[bar_name],
                             opts={"title": bar_name},
                             env=fold_name)

        # visualize text
        for text_name, val_str in text.items():
            text_str = "<font face = 'Arial' size = '4'>%s</font>" % val_str
            if text_name not in self.text_windows:
                self.text_windows[text_name] = self.viz.text(
                    text_str, env=fold_name)
            else:
                self.viz.text(text_str,
                              win=self.text_windows[text_name],
                              env=fold_name)

        # visualize plots
        for plot_name, plot_vals in plots.items():
            if isinstance(plot_vals, dict):
                x_vals = plot_vals["x"]
                y_vals = plot_vals["y"]
                xlabel = plot_vals.get("xlabel", "")
                ylabel = plot_vals.get("ylabel", "")
            else:
                x_vals = np.array(plot_vals[0])
                y_vals = np.array(plot_vals[1])
                xlabel = ""
                ylabel = ""
            if plot_name not in self.plot_windows:
                self.plot_windows[plot_name] = self.viz.line(
                    X=x_vals, Y=y_vals,
                    opts={'xlabel': xlabel,
                          'ylabel': ylabel,
                          'title': plot_name},
                    env=fold_name)
            else:
                self.viz.line(X=x_vals, Y=y_vals,
                              win=self.plot_windows[plot_name],
                              opts={'xlabel': xlabel,
                                    'ylabel': ylabel,
                                    'title': plot_name},
                              env=fold_name)

        # end of epoch: decide which dict to log
        # (only one epoch type can end at the same time)
        # train epoch ended
        if (self.curr_batch_train % self.log_freq_train) == 0:
            score_dict = self.metrics
            curr_batch = self.curr_batch_train
            curr_epoch = self.curr_epoch_train
            name = "train"
            self.curr_epoch_train += 1
            self.curr_batch_train = 1
            self.metrics = {}
        # validation epoch ended
        elif (self.curr_batch_val % self.log_freq_val) == 0:
            score_dict = self.val_metrics
            curr_batch = self.curr_batch_val
            curr_epoch = self.curr_epoch_val
            name = "val"
            self.curr_epoch_val += 1
            self.curr_batch_val = 1
            self.val_metrics = {}
        # no epoch ended
        else:
            score_dict = {}
            curr_epoch = 1
            curr_batch = 1

        if score_dict:
            # plot losses
            for metric_name, metric in score_dict.items():
                if metric_name not in self.plot_windows:
                    self.plot_windows[metric_name] = self.viz.line(
                        X=np.array([curr_epoch]),
                        Y=np.array([metric / curr_batch]),
                        opts={'xlabel': 'iterations',
                              'ylabel': metric_name,
                              'title': metric_name},
                        name=name,
                        env=fold_name)
                else:
                    self.viz.line(X=np.array([curr_epoch]),
                                  Y=np.array([metric / curr_batch]),
                                  win=self.plot_windows[metric_name],
                                  update='append',
                                  name=name,
                                  env=fold_name)
        else:
            is_val = False
            is_train = False
            for key in scores.keys():
                if key.startswith("val_"):
                    is_val = True
                else:
                    is_train = True
            if is_val:
                self.curr_batch_val += 1
            if is_train:
                self.curr_batch_train += 1

    @staticmethod
    def _to_scalar(val):
        """
        Convert a scalar wrapped in a tensor or numpy array to float

        Parameters
        ----------
        val : torch.Tensor or np.ndarray
            value to be converted

        Returns
        -------
        float
            converted value
        """
        if isinstance(val, np.ndarray):
            # np.asscalar is deprecated and removed in recent numpy versions
            return val.item()
        elif isinstance(val, torch.Tensor):
            return val.item()
        else:
            return float(val)

    @staticmethod
    def _to_image(tensor: torch.Tensor):
        """
        Convert an image tensor to a numpy array

        Parameters
        ----------
        tensor : torch.Tensor
            image tensor (the first element of the batch is used)

        Returns
        -------
        np.ndarray
            converted image
        """
        img = tensor[0].cpu().numpy()
        if img.shape[0] == 1:
            img = np.tile(img, (3, 1, 1))
        img -= img.min()
        if img.max():
            img *= 255 / img.max()
        return img.astype(np.uint8)
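# --- Usage sketch (not from the original source) --------------------------
# A hedged example of attaching VisdomImageHandler to Python's standard
# logging machinery; the dict keys follow what emit() reads, and all values
# are illustrative. It assumes a visdom server is reachable on the given
# port.
import logging
import torch

logger = logging.getLogger('training')
logger.setLevel(logging.INFO)
logger.addHandler(VisdomImageHandler(port=8097, prefix='experiment1',
                                     log_freq_train=10))

logger.info({
    'scores': {'loss': torch.tensor(0.42),
               'val_loss': torch.tensor(0.48)},
    'images': {'input': torch.rand(1, 3, 64, 64)},
    'scalars': {'lr': 1e-3},
})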