def evm_1d_norm(cr_target, cr_test, object_mask, noise_mask):
    """Average relative error (EVM) between peak-normalized signals at object cells.

    Both signals are first divided by their own maximum absolute value. At the
    cells selected by ``object_mask`` the relative error
    ``|target - test| / |target|`` is computed and averaged. ``noise_mask`` is
    not used.

    Returns:
        The mean relative error, or ``np.nan`` (after printing a warning) when
        the mask selects no cells.
    """
    # Normalize each signal by its own peak magnitude (target first, then test).
    normalized = [signal / np.amax(np.abs(signal)) for signal in (cr_target, cr_test)]
    reference, estimate = (values[object_mask] for values in normalized)
    if len(reference) == 0:
        print_('WARNING: no obj peak targets found in evm_1d_norm!')
        return np.nan
    relative_errors = np.abs(reference - estimate) / np.abs(reference)
    return np.average(relative_errors)
def rd_obj_peak_log_mag_mse(rd_target, rd_test, object_mask, noise_mask):
    """Mean squared error of the peak magnitudes at object cells.

    ``|rd_target|`` and ``|rd_test|`` are compared at the cells selected by
    ``object_mask``. ``noise_mask`` is not used.

    NOTE(review): despite the name, no logarithm is applied here. Linear
    magnitudes are compared unless the inputs are already log-scaled. Confirm
    that this is intended.

    Returns:
        The MSE, or ``np.nan`` (after printing a warning) when the mask selects
        no cells.
    """
    target_peaks, test_peaks = rd_target[object_mask], rd_test[object_mask]
    if len(target_peaks) == 0:
        print_('WARNING: no peaks found for evaluation metric.')
        return np.nan
    return mean_squared_error(np.abs(target_peaks), np.abs(test_peaks))
def forward(self, output_re_im, target_re_im, object_mask, noise_mask):
    """Return the negative SINR averaged over the packets of a batch.

    Args:
        output_re_im: model output. ``output_re_im[p]`` is passed to
            ``sinr_from_re_im_format`` for every packet ``p``.
        target_re_im: only its first dimension (the number of packets) is used.
        object_mask: mask forwarded to ``sinr_from_re_im_format`` (moved to ``device``).
        noise_mask: mask forwarded to ``sinr_from_re_im_format`` (moved to ``device``).

    Returns:
        ``-(1 / P) * sum_p sinr(output_re_im[p])``.

    Raises:
        NotImplementedError: if ``self.data_content`` is not
            ``DataContent.COMPLEX_PACKET_RD``.
    """
    object_mask = object_mask.to(device)
    noise_mask = noise_mask.to(device)
    num_packets = target_re_im.shape[0]
    if self.data_content is not DataContent.COMPLEX_PACKET_RD:
        print_('WARNING: Not implemented yet.')
        # An explicit raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would silently turn this loss into a constant 0.
        raise NotImplementedError(
            'SINR loss is only implemented for DataContent.COMPLEX_PACKET_RD')

    neg_sinr_mean = 0
    for p in range(num_packets):
        neg_sinr_mean -= sinr_from_re_im_format(output_re_im[p], object_mask, noise_mask)
    neg_sinr_mean /= num_packets
    return neg_sinr_mean
def forward(self, output_re_im, target_re_im, object_mask, noise_mask):
    """Weighted re/im, peak log-magnitude and peak phase loss for complex RD packets.

    Along axis 2 of ``output_re_im`` / ``target_re_im``, the first half holds the
    real part and the second half the imaginary part. For every packet, two
    quantities are compared at the cells selected by ``object_mask``:

    * the log-magnitude: ``10 * log10`` of the magnitude normalized by that
      packet's maximum magnitude;
    * the phase: ``atan(im / re)``, which folds the angle into (-pi/2, pi/2).

    The re/im term compares the full batch tensors, not the single packet.

    Args:
        output_re_im: model output. Assumed to be 3-D (packet, ., 2 * num_re);
            this layout is inferred from ``shape[2]`` and the per-packet slicing.
        target_re_im: target in the same layout as ``output_re_im``.
        object_mask: boolean mask selecting object cells. It must be usable with
            ``torch.masked_select`` on one packet's real/imag half.
        noise_mask: unused; accepted for the common criterion signature.

    Returns:
        ``w_re_im * mse(output_re_im, target_re_im)`` plus the per-packet average
        of ``w_mag * mse(log-mag) + w_phase * mse(phase)``.

    Raises:
        NotImplementedError: if ``self.data_content`` is not
            ``DataContent.COMPLEX_PACKET_RD``.
    """
    object_mask = object_mask.to(device)
    num_packets = target_re_im.shape[0]
    num_re = int(target_re_im.shape[2] / 2)
    if self.data_content is not DataContent.COMPLEX_PACKET_RD:
        print_('WARNING: Not implemented yet.')
        # Explicit raise: `assert False` is stripped under `python -O`.
        raise NotImplementedError(
            'Loss is only implemented for DataContent.COMPLEX_PACKET_RD')

    def masked_log_mag(re, im):
        # 10 * log10(|z| / max|z|) over the packet, restricted to object cells.
        max_mag = torch.sqrt(re ** 2 + im ** 2).view(-1).max()
        re_norm = re / max_mag
        im_norm = im / max_mag
        log_mag = 10 * torch.log10(torch.sqrt(re_norm ** 2 + im_norm ** 2))
        return torch.masked_select(log_mag, object_mask)

    def masked_phase(re, im):
        # atan(im / re) at object cells (angle folded into (-pi/2, pi/2)).
        return torch.atan(torch.masked_select(im, object_mask) / torch.masked_select(re, object_mask))

    # The re/im term uses the whole batch and is therefore the same for every
    # packet. The original code added it num_packets times and then divided by
    # num_packets. Computing it once and adding it after the per-packet average
    # gives the same value without the repeated work.
    re_im_term = self.w_re_im * self.mse(output_re_im, target_re_im)

    loss = 0
    for p in range(num_packets):
        output_packet = output_re_im[p]
        target_packet = target_re_im[p]
        output_re, output_im = output_packet[:, :num_re], output_packet[:, num_re:]
        target_re, target_im = target_packet[:, :num_re], target_packet[:, num_re:]

        loss += self.w_mag * self.mse(masked_log_mag(output_re, output_im),
                                      masked_log_mag(target_re, target_im)) +\
            self.w_phase * self.mse(masked_phase(output_re, output_im),
                                    masked_phase(target_re, target_im))
    # With an empty batch, `loss` is still the int 0, so this raises
    # ZeroDivisionError, exactly as the original did.
    loss /= num_packets
    return loss + re_im_term
def rd_obj_peak_phase_mse(rd_target, rd_test, object_mask, noise_mask):
    """Mean squared error of the peak phases at object cells.

    The phase of each selected cell is ``arctan(imag / real)``, so it is folded
    into (-pi/2, pi/2). ``noise_mask`` is not used.

    Returns:
        The MSE, or ``np.nan`` (after printing a warning) when the mask selects
        no cells.
    """
    def folded_phase(values):
        numerator = np.imag(values).astype('float')
        denominator = np.real(values).astype('float')
        return np.arctan(numerator / denominator)

    target_peaks = rd_target[object_mask]
    test_peaks = rd_test[object_mask]
    if len(target_peaks) == 0:
        print_('WARNING: no peaks found for evaluation metric.')
        return np.nan
    target_phase = folded_phase(target_peaks)
    test_phase = folded_phase(test_peaks)
    return mean_squared_error(target_phase, test_phase)
def evaluate_rd_log_mag(model, dataset, phase):
    """Evaluate ``model`` packet by packet on RD log-magnitude data and report metrics.

    Each dataloader batch holds exactly one packet. For each packet, the model
    prediction is converted back to RD format. It is then compared with the
    clean RD map using SINR_RD_LOG_MAG and PEAK_MAG_MSE. The same comparison is
    made for several reference signals (interfered, clean + noise,
    zero-substitution baseline, ...).

    Packets whose filter mask selects no sample are skipped. They are excluded
    from the averages so they do not enter the means as 0.

    Args:
        model: network that maps the filtered input samples of a packet to predictions.
        dataset: evaluation dataset. It provides scaling, RD scenes and masks.
        phase: label used for log output and plot names (e.g. 'val', 'test').

    Returns:
        Mean SINR_RD_LOG_MAG of ``Signal.PREDICTION`` over the evaluated packets.
    """
    print_('# # Evaluation: {} # #'.format(phase))
    num_channels = dataset.get_num_channels()
    # dataloader with batch_size = number of samples per packet --> evaluate one packet at once for evaluation metrics
    num_samples_for_evaluation_metrics = DataContent.num_samples_for_rd_evaluation(
        dataset.data_content, dataset.num_ramps_per_packet)
    dataloader = data_loader_for_dataset(dataset, num_samples_for_evaluation_metrics, shuffle=False)

    since_test = time.time()
    model.eval()
    packet_idx = 0
    signals_for_quality_measures = {
        Signal.PREDICTION: None,
        Signal.PREDICTION_SUBSTITUDE: None,
        Signal.CLEAN: None,
        Signal.INTERFERED: None,
        Signal.CLEAN_NOISE: None,
        Signal.BASELINE_ZERO_SUB: None
    }
    functions_for_quality_measures = [
        EvaluationFunction.SINR_RD_LOG_MAG,
        EvaluationFunction.PEAK_MAG_MSE
    ]
    num_packets = len(dataloader)
    quality_measures = np.zeros(
        (len(functions_for_quality_measures), len(signals_for_quality_measures), num_packets))
    # True for packets whose metrics were actually computed (skipped packets stay False).
    evaluated_packets = np.zeros(num_packets, dtype=bool)

    for inputs, _labels, filter_mask, _, _ in dataloader:
        inputs = inputs[filter_mask].to(device)
        interference_mask = torch.tensor(
            dataset.get_sample_interference_mask(packet_idx, num_samples_for_evaluation_metrics),
            dtype=torch.uint8)[filter_mask]
        if inputs.shape[0] == 0:
            packet_idx += 1
            continue

        prediction_interf_substi = np.array(
            dataset.get_target_original_scaled_re_im(packet_idx, num_samples_for_evaluation_metrics),
            copy=True)
        prediction = np.zeros(prediction_interf_substi.shape)
        original_indices = np.nonzero(filter_mask.numpy())[0]
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
        if RESIDUAL_LEARNING:
            packet_prediction = inputs.cpu().numpy() - outputs.cpu().numpy()
        else:
            packet_prediction = outputs.cpu().numpy()

        # Scatter predictions back to their original sample positions.
        # `prediction` is zero at filtered-out samples. `prediction_interf_substi`
        # starts from dataset.get_target_original_scaled_re_im and is only
        # overwritten at interfered samples.
        for i in range(len(outputs)):
            original_index = original_indices[i]
            prediction[original_index] = packet_prediction[i]
            if interference_mask[i]:
                prediction_interf_substi[original_index] = packet_prediction[i]

        prediction_ri = dataset.packet_in_target_format_to_complex(prediction, packet_idx)
        prediction_original_scale_ri = dataset.inverse_scale(prediction_ri, is_y=True)
        prediction_interf_substi_ri = dataset.packet_in_target_format_to_complex(
            prediction_interf_substi, packet_idx)
        prediction_interf_substi_original_scale_ri = dataset.inverse_scale(
            prediction_interf_substi_ri, is_y=True)

        if dataset.data_source == DataSource.DENOISE_REAL_IMAG_RAMP:
            prediction_rd = calculate_velocity_fft(prediction_original_scale_ri[:, 0, :])
            prediction_substi_rd = calculate_velocity_fft(
                prediction_interf_substi_original_scale_ri[:, 0, :])
        else:
            prediction_rd = prediction_original_scale_ri[0]
            prediction_substi_rd = prediction_interf_substi_original_scale_ri[0]

        clean_rd = dataset.get_scene_rd_clean(packet_idx)
        interf_rd = dataset.get_scene_rd_interf(packet_idx)
        clean_noise_rd = dataset.get_scene_rd_original(packet_idx)
        zero_substi_baseline_rd = dataset.get_scene_rd_zero_substitude_in_time_domain(packet_idx)
        target_object_mask, target_noise_mask = dataset.get_scene_rd_object_and_noise_masks(packet_idx)

        signals_for_quality_measures[Signal.PREDICTION] = prediction_rd
        signals_for_quality_measures[Signal.PREDICTION_SUBSTITUDE] = prediction_substi_rd
        signals_for_quality_measures[Signal.CLEAN] = clean_rd
        signals_for_quality_measures[Signal.INTERFERED] = interf_rd
        signals_for_quality_measures[Signal.CLEAN_NOISE] = clean_noise_rd
        signals_for_quality_measures[Signal.BASELINE_ZERO_SUB] = zero_substi_baseline_rd

        for i, func in enumerate(functions_for_quality_measures):
            for j, signal_name in enumerate(signals_for_quality_measures):
                signal = signals_for_quality_measures[signal_name]
                quality_measures[i, j, packet_idx] = func(clean_rd, signal, target_object_mask,
                                                          target_noise_mask)
        evaluated_packets[packet_idx] = True

        scene_idx = int(packet_idx / num_channels)
        if visualize and scene_idx in [0, 1] and packet_idx % num_channels == 0:
            plot_rd_matrix_for_packet(
                clean_rd, prediction_rd, prediction_substi_rd, interf_rd, zero_substi_baseline_rd,
                phase, packet_idx, dataset.data_source == DataSource.DENOISE_REAL_IMAG_RAMP,
                is_log_mag=True)
            object_mask, noise_mask = dataset.get_scene_rd_object_and_noise_masks(packet_idx)
            plot_rd_noise_mask(noise_mask, 'RD Noise mask',
                               'eval_{}_rd_noise_mask_p{}'.format(phase, packet_idx))
            plot_object_mag_cuts(prediction_rd, prediction_substi_rd, clean_rd, clean_noise_rd,
                                 interf_rd, zero_substi_baseline_rd, object_mask, packet_idx, phase,
                                 is_rd=True, is_log_mag=True)

        packet_idx += 1

    time_test = time.time() - since_test

    metrics = []
    metric_labels = []
    main_metric = None
    for i, func in enumerate(functions_for_quality_measures):
        func_metrics = []
        func_metric_labels = []
        sig_labels = []
        for j, signal in enumerate(signals_for_quality_measures):
            # Only packets that were actually evaluated take part in the mean.
            quality_measures_per_func_sign = quality_measures[i, j][evaluated_packets]
            count_nans = np.count_nonzero(np.isnan(quality_measures_per_func_sign))
            if count_nans > 0:
                quality_measures_per_func_sign = quality_measures_per_func_sign[
                    np.logical_not(np.isnan(quality_measures_per_func_sign))]
                print_('WARNING: quality measure "{}" produces {} nans!!'.format(
                    func.label(), count_nans))
            func_metrics.append(np.mean(quality_measures_per_func_sign))
            metric_label = '{} {}'.format(func.label(), signal.label())
            func_metric_labels.append(metric_label)
            sig_labels.append(signal.label())
            if func is EvaluationFunction.SINR_RD_LOG_MAG and signal is Signal.PREDICTION:
                main_metric = func_metrics[-1]
        plot_values(quality_measures[i], sig_labels, func.label(), phase)
        metrics.extend(func_metrics)
        metric_labels.extend(func_metric_labels)

    if verbose:
        print_evaluation_summary(time_test, phase, metrics, metric_labels)
    return main_metric
def print_evaluation_summary(time_elapsed, phase, snrs=None, snr_labels=None, accuracy=None):
    """Print a formatted summary block for one evaluation run.

    Args:
        time_elapsed: duration in seconds; printed as whole minutes and seconds.
        phase: data partition label printed in the 'data' row.
        snrs: optional sequence of metric values, printed next to the matching
            entry of ``snr_labels``.
        snr_labels: labels for ``snrs`` (indexed in parallel).
        accuracy: optional accuracy value, printed in its own row.
    """
    value_row = '{:<24}: {:>22.10f}'
    minutes = time_elapsed // 60
    seconds = time_elapsed % 60

    print_()
    print_('{:<24}: {}'.format('data', phase))
    print_('{:<24}: {:.0f}m {:.0f}s'.format('duration', minutes, seconds))
    if snrs is not None:
        for position, value in enumerate(snrs):
            print_(value_row.format(snr_labels[position], value))
    if accuracy is not None:
        print_(value_row.format('accuracy', accuracy))
    print_()
def train_with_hyperparameter_config(dataset, hyperparameters, task, is_classification=False, shuffle_data=True):
    """Train ``hyperparameters.model`` on ``dataset``, restore the best state and evaluate it.

    The model is trained ``num_model_initializations`` times, with ``model.reset()``
    before each run. The state with the lowest validation loss is kept; if there
    is no validation data, the lowest training loss is used instead. The best
    state is compared across all epochs and initializations. After training, that
    state is loaded and RD evaluation is run on the validation and test partitions.

    Args:
        dataset: training dataset; validation and test partitions are cloned from it.
        hyperparameters: provides optimizer, criterion, scheduler factory, batch
            size, learning rate, epochs, number of initializations, output size
            and the model.
        task: task description (printed and stored in the result).
        is_classification: if True, accuracy is tracked from arg-max predictions.
        shuffle_data: shuffle the training loader.

    Returns:
        ``(model, hyperparameters, evaluation_result)``. ``evaluation_result.snr``
        is the mean validation metric, or -1 without validation data.
    """
    import copy  # local import: used only to snapshot the best model state

    optimization_algo = hyperparameters.optimization_algo
    criterion = hyperparameters.criterion
    scheduler_partial = hyperparameters.scheduler_partial
    batch_size = hyperparameters.batch_size
    learning_rate = hyperparameters.learning_rate
    num_epochs = hyperparameters.num_epochs
    num_model_initializations = hyperparameters.num_model_initializations
    output_size = hyperparameters.output_size
    model = hyperparameters.model
    # Number of predicted values per sample, used to normalize the accuracy.
    prediction_size = 1 if is_classification else output_size

    if verbose:
        print_('Running task: {}'.format(task))
        print_('Max Epochs: {}'.format(num_epochs))
        print_()
        print_(hyperparameters)
        print_()
        print_('# Training #')

    train_loader = data_loader_for_dataset(dataset, batch_size=batch_size, shuffle=shuffle_data)
    val_loader = data_loader_for_dataset(
        dataset.clone_for_new_active_partition(DatasetPartition.VALIDATION),
        batch_size=batch_size, shuffle=False)
    dataloaders = {'train': train_loader, 'val': val_loader}
    dataset_sizes = {'train': len(train_loader.dataset), 'val': len(val_loader.dataset)}
    count_batch_iterations = {'train': int(dataset_sizes['train'] / batch_size),
                              'val': int(dataset_sizes['val'] / batch_size)}

    # Move criterion weights (if the criterion has any) to the training device.
    try:
        criterion.weight = criterion.weight.to(device)
    except AttributeError:
        pass

    since = time.time()
    best_val_loss = float('inf')
    best_train_loss = float('inf')
    best_model_state_dict = None
    last_acc = {'train': None, 'val': None}
    use_validation = dataset_sizes['val'] > 0
    # NOTE(review): one EarlyStopping instance is shared across all model
    # initializations; confirm that its state should carry over between them.
    stopping_strategy = EarlyStopping(steps_to_wait=50)
    model = model.to(device)

    # # # # # train and evaluate the configuration # # # # #
    for model_initialization in range(num_model_initializations):
        model.reset()
        try:
            if tensorboard_logging:
                model.set_tensorboardx_logging_active(tensorboard_logging)
        except AttributeError:
            warnings.warn('Model does not support tensorboard logging!')
        optimizer = optimization_algo(params=model.parameters(), lr=learning_rate)
        if verbose:
            print_('Model initialization {}/{}'.format(model_initialization + 1, num_model_initializations))
            line = construct_formatted_values_headline()
            print_(line)

        losses = {'train': np.zeros(num_epochs), 'val': np.zeros(num_epochs)}
        # Explicit None instead of relying on UnboundLocalError. This also keeps
        # an UnboundLocalError raised inside scheduler.step() from being swallowed.
        scheduler = scheduler_partial(optimizer) if scheduler_partial is not None else None
        batch_iteration = {'train': 0, 'val': 0}

        for epoch in range(num_epochs):
            epoch_start_time = time.time()
            if scheduler is not None:
                scheduler.step()
            if tensorboard_logging:
                for param_group in optimizer.param_groups:
                    tensorboard_writer.add_scalar('data/learning_rate', param_group['lr'], epoch)

            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()
                running_loss = 0.0
                running_corrects = 0
                curr_count_sample = 0
                mem_usage = print_torch_mem_usage(None, print_mem=False)

                for inputs_batch, targets_batch, filter_mask_batch, object_masks, noise_masks in dataloaders[phase]:
                    inputs_for_training = inputs_batch[filter_mask_batch].to(device)
                    targets_for_learning = targets_batch[filter_mask_batch].to(device)
                    sample_batch_size = inputs_for_training.shape[0]
                    if sample_batch_size <= 1:
                        # sample_batch_size == 1: does not work with batch_norm
                        warnings.warn('Skipping batch with size {}. Batch norm requires a batch size >= 2.'.format(
                            sample_batch_size))
                        continue
                    try:
                        model.init_hidden_state()  # required by LSTM
                    except AttributeError:
                        pass

                    # track history only if in training phase
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs_for_training)
                        if RESIDUAL_LEARNING:
                            targets_for_loss = inputs_for_training - targets_for_learning
                        else:
                            targets_for_loss = targets_for_learning
                        loss = criterion(outputs, targets_for_loss, object_masks, noise_masks)
                        # backward + optimize only if in training phase
                        if phase == 'train':
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()

                    running_loss += ObjectiveFunction.loss_to_running_loss(loss.item(), sample_batch_size)
                    if is_classification:
                        _, class_predictions = torch.max(outputs.detach(), 1)
                        running_corrects += torch.sum(class_predictions == targets_for_learning.detach()).double()
                    if tensorboard_logging:
                        tensorboard_writer.add_scalar('data/{}_batch_loss'.format(phase), loss.item(),
                                                      batch_iteration[phase])
                        if phase == 'train' and batch_iteration[phase] == 0:
                            try:
                                tensorboard_writer.add_graph(model, inputs_for_training)
                            except AssertionError:
                                warnings.warn('Model does not support tensorboard graphs!')

                    curr_count_sample += sample_batch_size
                    current_loss = ObjectiveFunction.loss_from_running_loss(running_loss, curr_count_sample)
                    if verbose and not is_cluster_run and \
                            batch_iteration[phase] % max(int(count_batch_iterations[phase] / 100), 1) == 0:
                        if phase == 'train':
                            line = '\r' + construct_formatted_values_line(
                                epoch, current_loss, best_train_loss, None, best_val_loss,
                                time.time() - epoch_start_time, is_classification,
                                running_corrects / (curr_count_sample * prediction_size), None)
                        else:
                            line = '\r' + construct_formatted_values_line(
                                epoch, losses['train'][epoch], best_train_loss, current_loss, best_val_loss,
                                time.time() - epoch_start_time, is_classification, last_acc['train'],
                                running_corrects / (curr_count_sample * prediction_size))
                        sys.stdout.write(line)
                        sys.stdout.flush()
                    batch_iteration[phase] += 1

                if dataset_sizes[phase] > 0:
                    last_acc[phase] = running_corrects / (dataset_sizes[phase] * prediction_size)
                    losses[phase][epoch] = ObjectiveFunction.loss_from_running_loss(running_loss,
                                                                                    dataset_sizes[phase])
                    if tensorboard_logging:
                        tensorboard_writer.add_scalar('data/{}_loss'.format(phase), losses[phase][epoch], epoch)
                if tensorboard_logging:
                    for name, param in model.named_parameters():
                        tensorboard_writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)

                if (use_validation and phase == 'val' and losses['val'][epoch] < best_val_loss) \
                        or (not use_validation and phase == 'train' and losses['train'][epoch] < best_train_loss):
                    # state_dict() returns tensors that share storage with the live
                    # parameters. Without a deep copy, later optimizer steps would
                    # overwrite this "best" snapshot and the last model would be
                    # restored instead of the best one.
                    best_model_state_dict = copy.deepcopy(model.state_dict())
                    best_val_loss = losses['val'][epoch]
                    best_train_loss = losses['train'][epoch]

            epoch_duration = time.time() - epoch_start_time
            try:
                criterion.next_epoch()
            except AttributeError:
                pass
            if verbose:
                line = '\r' + construct_formatted_values_line(
                    epoch, losses['train'][epoch], best_train_loss, losses['val'][epoch], best_val_loss,
                    epoch_duration, is_classification, last_acc['train'], last_acc['val'],
                    mem_usage=mem_usage)
                print_(line)

            monitored_loss = losses['val'][epoch] if use_validation else losses['train'][epoch]
            if stopping_strategy.should_stop(monitored_loss, epoch):
                losses['val'] = losses['val'][:epoch + 1]
                losses['train'] = losses['train'][:epoch + 1]
                break

        if visualize:
            plot_losses(losses)
        if verbose:
            print_()

    time_elapsed = time.time() - since
    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)
    else:
        # No epoch improved on the initial inf loss (e.g. zero epochs or nan
        # losses). Previously load_state_dict(None) crashed here.
        warnings.warn('No best model state was recorded; keeping the current model state.')

    snr = -1
    if verbose:
        print_()
    if dataset_sizes['val'] > 0:
        snrs = []
        if dataset.data_content in [DataContent.COMPLEX_PACKET_RD, DataContent.COMPLEX_RAMP]:
            snrs.append(evaluate_rd(model, dataloaders['val'].dataset, 'val'))
        if dataset.data_content in [DataContent.REAL_PACKET_RD]:
            snrs.append(evaluate_rd_log_mag(model, dataloaders['val'].dataset, 'val'))
        snr = np.mean(snrs)

    test_dataset = dataset.clone_for_new_active_partition(DatasetPartition.TEST)
    if len(test_dataset) > 0:
        if dataset.data_content in [DataContent.COMPLEX_PACKET_RD, DataContent.COMPLEX_RAMP]:
            evaluate_rd(model, test_dataset, 'test')
        if dataset.data_content in [DataContent.REAL_PACKET_RD]:
            evaluate_rd_log_mag(model, test_dataset, 'test')
        print_()

    evaluation_result = EvaluationResult(
        hyperparameters, task_id, task, best_train_loss, best_val_loss, time_elapsed, snr, JOB_DIR + '/model'
    )
    if tensorboard_logging:
        tensorboard_writer.close()
    return model, hyperparameters, evaluation_result
def evaluate_rd(model, dataset, phase):
    """Evaluate ``model`` packet by packet on complex RD data, including cross-range metrics.

    Each dataloader batch holds one packet, i.e. one channel of one scene.

    RD metrics: for every evaluated packet, the prediction is converted back to
    RD format. The configured RD quality functions then compare each configured
    signal against the clean RD map.

    Cross-range metrics: the RD maps of all channels of a scene are collected.
    After the last channel of the scene, a cross-range FFT is taken across the
    channel axis at every object cell of the RD object mask. The cross-range
    quality functions are then averaged over those objects.

    Packets whose filter mask selects no sample are skipped. Scenes without any
    cross-range result (no objects, or last channel skipped) are excluded from
    the respective averages instead of entering them as 0.

    Args:
        model: network that maps the filtered input samples of a packet to predictions.
        dataset: evaluation dataset. It provides scaling, RD scenes and masks.
        phase: label used for log output and plot names (e.g. 'val', 'test').

    Returns:
        Mean SINR_RD of ``Signal.PREDICTION`` over the evaluated packets.
    """
    # # config signals and evaluation functions # #
    # Further RD signals: PREDICTION_SUBSTITUDE, CLEAN, BASELINE_ZERO_SUB.
    signals_for_quality_measures = {
        Signal.PREDICTION: None,
        Signal.INTERFERED: None,
        Signal.CLEAN_NOISE: None
    }
    # Further RD functions: EVM_RD, PHASE_MSE_RD, LOG_MAG_MSE_RD, EVM_RD_NORM.
    functions_for_quality_measures = [EvaluationFunction.SINR_RD]
    # Further cross-range signals: CLEAN, BASELINE_ZERO_SUB.
    signals_for_cr_quality_measures = {
        Signal.PREDICTION: None,
        Signal.INTERFERED: None,
        Signal.CLEAN_NOISE: None
    }
    # Further cross-range functions: EVM_CR, EVM_CR_NORM.
    functions_for_cr_quality_measures = [EvaluationFunction.SINR_CR]
    # # end config

    print_('# Evaluation: {} #'.format(phase))
    num_channels = dataset.get_num_channels()
    # dataloader with batch_size = number of samples per packet --> evaluate one packet at once for evaluation metrics
    num_samples_for_evaluation_metrics = DataContent.num_samples_for_rd_evaluation(dataset.data_content,
                                                                                   dataset.num_ramps_per_packet)
    dataloader = data_loader_for_dataset(dataset, num_samples_for_evaluation_metrics, shuffle=False)

    since_test = time.time()
    model.eval()
    packet_idx = 0
    num_packets = len(dataloader)
    num_scenes = int(num_packets / num_channels)
    quality_measures = np.zeros((len(functions_for_quality_measures), len(signals_for_quality_measures),
                                 num_packets))
    # Sized by the cross-range signal config, which may differ from the RD one.
    quality_measures_cr = np.zeros((len(functions_for_cr_quality_measures), len(signals_for_cr_quality_measures),
                                    num_scenes))
    # Track which packets / scenes produced metrics, so skipped ones are not averaged in as 0.
    evaluated_packets = np.zeros(num_packets, dtype=bool)
    evaluated_scenes = np.zeros(num_scenes, dtype=bool)

    scene_size = (dataset.num_fast_time_samples, dataset.num_ramps_per_packet, num_channels)
    angular_spectrum = np.zeros(scene_size, dtype=np.complex128)
    angular_spectrum_clean = np.zeros(scene_size, dtype=np.complex128)
    angular_spectrum_interf = np.zeros(scene_size, dtype=np.complex128)
    angular_spectrum_original = np.zeros(scene_size, dtype=np.complex128)
    angular_spectrum_zero_substi = np.zeros(scene_size, dtype=np.complex128)

    for inputs, _labels, filter_mask, _, _ in dataloader:
        inputs = inputs[filter_mask].to(device)
        interference_mask = torch.tensor(dataset.get_sample_interference_mask(packet_idx,
                                                                              num_samples_for_evaluation_metrics),
                                         dtype=torch.uint8)[filter_mask]
        if inputs.shape[0] == 0:
            packet_idx += 1
            continue

        # Derive channel and scene from the packet index. A counter that is only
        # advanced for processed packets drifts out of sync after a skipped packet.
        # NOTE(review): a skipped packet leaves its channel slice of the angular
        # spectra holding data from an earlier scene.
        channel_idx = packet_idx % num_channels
        scene_idx = packet_idx // num_channels

        prediction_interf_substi = np.array(dataset.get_target_original_scaled_re_im(
            packet_idx, num_samples_for_evaluation_metrics), copy=True)
        prediction = np.zeros(prediction_interf_substi.shape)
        original_indices = np.nonzero(filter_mask.numpy())[0]
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
        if RESIDUAL_LEARNING:
            packet_prediction = inputs.cpu().numpy() - outputs.cpu().numpy()
        else:
            packet_prediction = outputs.cpu().numpy()

        # Scatter predictions back to their original sample positions.
        # `prediction` is zero at filtered-out samples. `prediction_interf_substi`
        # starts from dataset.get_target_original_scaled_re_im and is only
        # overwritten at interfered samples.
        for i in range(len(outputs)):
            original_index = original_indices[i]
            prediction[original_index] = packet_prediction[i]
            if interference_mask[i]:
                prediction_interf_substi[original_index] = packet_prediction[i]

        prediction_ri = dataset.packet_in_target_format_to_complex(prediction, packet_idx)
        prediction_original_scale_ri = dataset.inverse_scale(prediction_ri, is_y=True)
        prediction_interf_substi_ri = dataset.packet_in_target_format_to_complex(prediction_interf_substi,
                                                                                 packet_idx)
        prediction_interf_substi_original_scale_ri = dataset.inverse_scale(prediction_interf_substi_ri,
                                                                           is_y=True)

        if dataset.data_source == DataSource.DENOISE_REAL_IMAG_RAMP:
            prediction_rd = calculate_velocity_fft(prediction_original_scale_ri[:, 0, :])
            prediction_substi_rd = calculate_velocity_fft(prediction_interf_substi_original_scale_ri[:, 0, :])
        else:
            prediction_rd = prediction_original_scale_ri[0]
            prediction_substi_rd = prediction_interf_substi_original_scale_ri[0]

        clean_rd = dataset.get_scene_rd_clean(packet_idx)
        interf_rd = dataset.get_scene_rd_interf(packet_idx)
        clean_noise_rd = dataset.get_scene_rd_original(packet_idx)
        zero_substi_baseline_rd = dataset.get_scene_rd_zero_substitude_in_time_domain(packet_idx)
        target_object_mask, target_noise_mask = dataset.get_scene_rd_object_and_noise_masks(packet_idx)

        rd_signals = {
            Signal.PREDICTION: prediction_rd,
            Signal.PREDICTION_SUBSTITUDE: prediction_substi_rd,
            Signal.CLEAN: clean_rd,
            Signal.INTERFERED: interf_rd,
            Signal.CLEAN_NOISE: clean_noise_rd,
            Signal.BASELINE_ZERO_SUB: zero_substi_baseline_rd,
        }
        # Fill only the configured signals, keeping their configured order.
        for signal_name in signals_for_quality_measures:
            signals_for_quality_measures[signal_name] = rd_signals[signal_name]

        for i, func in enumerate(functions_for_quality_measures):
            for j, signal_name in enumerate(signals_for_quality_measures):
                signal = signals_for_quality_measures[signal_name]
                quality_measures[i, j, packet_idx] = func(clean_rd, signal, target_object_mask, target_noise_mask)
        evaluated_packets[packet_idx] = True

        if visualize and scene_idx in [0, 1] and channel_idx == 0:
            plot_rd_matrix_for_packet(
                clean_rd, prediction_rd, prediction_substi_rd, interf_rd, zero_substi_baseline_rd,
                phase, packet_idx, dataset.data_source == DataSource.DENOISE_REAL_IMAG_RAMP)
            object_mask, noise_mask = dataset.get_scene_rd_object_and_noise_masks(packet_idx)
            plot_rd_noise_mask(noise_mask, 'RD Noise mask', 'eval_{}_rd_noise_mask_p{}'.format(phase, packet_idx))
            plot_object_mag_cuts(prediction_rd, prediction_substi_rd, clean_rd, clean_noise_rd, interf_rd,
                                 zero_substi_baseline_rd, object_mask, packet_idx, phase, is_rd=True)
            plot_object_phase_cuts(prediction_rd, prediction_substi_rd, clean_rd, clean_noise_rd, interf_rd,
                                   zero_substi_baseline_rd, object_mask, packet_idx, phase, is_rd=True)

        angular_spectrum[:, :, channel_idx] = prediction_rd
        angular_spectrum_clean[:, :, channel_idx] = clean_rd
        angular_spectrum_interf[:, :, channel_idx] = interf_rd
        angular_spectrum_original[:, :, channel_idx] = clean_noise_rd
        angular_spectrum_zero_substi[:, :, channel_idx] = zero_substi_baseline_rd

        # After the last channel of a scene, evaluate cross-range at every RD object cell.
        if channel_idx == num_channels - 1:
            target_object_mask, _ = dataset.get_scene_rd_object_and_noise_masks(packet_idx)
            rows, columns = np.nonzero(target_object_mask)
            for obj_idx in range(len(rows)):
                row, column = rows[obj_idx], columns[obj_idx]
                cr_prediction = calculate_cross_range_fft(angular_spectrum[row, column])
                cr_clean = calculate_cross_range_fft(angular_spectrum_clean[row, column])
                cr_interf = calculate_cross_range_fft(angular_spectrum_interf[row, column])
                cr_original = calculate_cross_range_fft(angular_spectrum_original[row, column])
                cr_zero_substi = calculate_cross_range_fft(angular_spectrum_zero_substi[row, column])

                cr_signals = {
                    Signal.PREDICTION: cr_prediction,
                    Signal.CLEAN: cr_clean,
                    Signal.INTERFERED: cr_interf,
                    Signal.CLEAN_NOISE: cr_original,
                    Signal.BASELINE_ZERO_SUB: cr_zero_substi,
                }
                for signal_name in signals_for_cr_quality_measures:
                    if signal_name in cr_signals:
                        signals_for_cr_quality_measures[signal_name] = cr_signals[signal_name]

                target_object_mask_cr, target_noise_mask_cr = dataset.get_scene_cr_object_and_noise_masks(
                    packet_idx, row, column)
                for i, func in enumerate(functions_for_cr_quality_measures):
                    for j, signal_name in enumerate(signals_for_cr_quality_measures):
                        signal = signals_for_cr_quality_measures[signal_name]
                        quality_measures_cr[i, j, scene_idx] += func(cr_clean, signal, target_object_mask_cr,
                                                                     target_noise_mask_cr)

                if visualize and scene_idx in [0, 1, 2] and obj_idx in [0, 1]:
                    x_vec = basis_vec_cross_range(row)
                    plot_cross_ranges(obj_idx, rows, columns, phase, scene_idx, x_vec, angular_spectrum,
                                      angular_spectrum_clean, angular_spectrum_interf, angular_spectrum_original,
                                      angular_spectrum_zero_substi, target_object_mask_cr)

            if len(rows) > 0:
                # Average the per-object sums over the objects of this scene.
                quality_measures_cr[:, :, scene_idx] /= len(rows)
                evaluated_scenes[scene_idx] = True

        packet_idx += 1

    time_test = time.time() - since_test

    metrics = []
    metric_labels = []
    main_metric = None
    for i, func in enumerate(functions_for_quality_measures):
        func_metrics = []
        func_metric_labels = []
        sig_labels = []
        for j, signal in enumerate(signals_for_quality_measures):
            # Only packets that were actually evaluated take part in the mean.
            quality_measures_per_func_sign = quality_measures[i, j][evaluated_packets]
            count_nans = np.count_nonzero(np.isnan(quality_measures_per_func_sign))
            if count_nans > 0:
                quality_measures_per_func_sign = quality_measures_per_func_sign[
                    ~np.isnan(quality_measures_per_func_sign)]
                print_('WARNING: quality measure "{}" produces {} nans!!'.format(func.label(), count_nans))
            func_metrics.append(np.mean(quality_measures_per_func_sign))
            metric_label = '{} {}'.format(func.label(), signal.label())
            func_metric_labels.append(metric_label)
            sig_labels.append(signal.label())
            if func is EvaluationFunction.SINR_RD and signal is Signal.PREDICTION:
                main_metric = func_metrics[-1]
        plot_values(quality_measures[i], sig_labels, func.label(), phase)
        metrics.extend(func_metrics)
        metric_labels.extend(func_metric_labels)

    for i, func in enumerate(functions_for_cr_quality_measures):
        func_metrics = []
        func_metric_labels = []
        sig_labels = []
        for j, signal in enumerate(signals_for_cr_quality_measures):
            # Only scenes with at least one cross-range result take part in the mean.
            func_metrics.append(np.mean(quality_measures_cr[i, j][evaluated_scenes]))
            metric_label = '{} {}'.format(func.label(), signal.label())
            func_metric_labels.append(metric_label)
            sig_labels.append(signal.label())
        plot_values(quality_measures_cr[i], sig_labels, func.label(), phase)
        metrics.extend(func_metrics)
        metric_labels.extend(func_metric_labels)

    if verbose:
        print_evaluation_summary(time_test, phase, metrics, metric_labels)
    return main_metric