def make_predictions_and_evaluate_gpu(conf, shot_list, loader, custom_path=None):
    """Predict on ``shot_list`` via ``make_predictions_gpu`` and score the result.

    Side effect: ``shot_list.set_weights`` is called with the per-shot
    difficulty reported by ``PerformanceAnalyzer.get_shot_difficulty``.

    NOTE(review): another ``def make_predictions_and_evaluate_gpu`` appears
    later in this file and rebinds the name at import time, so callers do
    not reach this definition -- confirm which one is intended.

    Returns:
        Tuple ``(y_prime, y_gold, disruptive, roc_area, loss)``.
    """
    predictions = make_predictions_gpu(conf, shot_list, loader, custom_path)
    y_prime, y_gold, disruptive = predictions
    perf = PerformanceAnalyzer(conf=conf)
    area = perf.get_roc_area(y_prime, y_gold, disruptive)
    difficulty = perf.get_shot_difficulty(y_prime, y_gold, disruptive)
    shot_list.set_weights(difficulty)
    target = conf['data']['target']
    return (
        y_prime,
        y_gold,
        disruptive,
        area,
        get_loss_from_list(y_prime, y_gold, target),
    )
def mpi_make_predictions_and_evaluate(conf, shot_list, loader, custom_path=None):
    """Predict on ``shot_list`` via ``mpi_make_predictions`` and score the result.

    Side effect: ``shot_list.set_weights`` is called with the per-shot
    difficulty reported by ``PerformanceAnalyzer.get_shot_difficulty``.

    Returns:
        Tuple ``(y_prime, y_gold, disruptive, roc_area, loss)``.
    """
    predictions = mpi_make_predictions(conf, shot_list, loader, custom_path)
    y_prime, y_gold, disruptive = predictions
    perf = PerformanceAnalyzer(conf=conf)
    area = perf.get_roc_area(y_prime, y_gold, disruptive)
    difficulty = perf.get_shot_difficulty(y_prime, y_gold, disruptive)
    shot_list.set_weights(difficulty)
    target = conf['data']['target']
    return (
        y_prime,
        y_gold,
        disruptive,
        area,
        get_loss_from_list(y_prime, y_gold, target),
    )
def make_predictions_and_evaluate_multiple_times(conf, shot_list, loader, times,
                                                 custom_path=None):
    """Predict once, then score the predictions under several ``T_min_warn`` values.

    Predictions are produced a single time by ``make_predictions``. For each
    value in ``times``, a deep copy of ``conf`` gets
    ``['data']['T_min_warn']`` overridden with that value. A fresh
    ``PerformanceAnalyzer`` built from the copy then computes the ROC area.
    Because the override is applied to a copy, it never leaks into ``conf``.

    Args:
        conf: configuration mapping; ``conf['data']['target']`` is passed to
            ``get_loss_from_list``.
        shot_list, loader, custom_path: forwarded to ``make_predictions``.
        times: iterable of ``T_min_warn`` values to evaluate.

    Returns:
        ``(areas, losses)``: two lists with one entry per element of
        ``times``. The loss is computed from ``conf`` alone and does not
        depend on ``T_min_warn``, so every entry of ``losses`` is the same
        value.
    """
    y_prime, y_gold, disruptive = make_predictions(conf, shot_list, loader,
                                                   custom_path)
    areas = []
    for t_min_warn in times:
        # Deep copy so overriding the threshold never mutates the caller's conf.
        conf_curr = deepcopy(conf)
        conf_curr['data']['T_min_warn'] = t_min_warn
        analyzer = PerformanceAnalyzer(conf=conf_curr)
        areas.append(analyzer.get_roc_area(y_prime, y_gold, disruptive))

    # The loss uses the original conf and is independent of T_min_warn, so it
    # is computed once instead of once per threshold. It is only computed
    # when there is at least one threshold to report on.
    if not areas:
        return areas, []
    loss = get_loss_from_list(y_prime, y_gold, conf['data']['target'])
    return areas, [loss] * len(areas)
def make_predictions_and_evaluate_gpu(conf, shot_list, loader, custom_path=None):
    """Predict on ``shot_list`` and return the predictions with ROC area and loss.

    This ``def`` rebinds the earlier ``make_predictions_and_evaluate_gpu`` in
    this file, which takes a ``custom_path`` argument. That parameter was
    previously missing here, so callers passing it raised ``TypeError``. It is
    now accepted (default ``None``) and forwarded to ``make_predictions``, so
    callers written against either signature keep working.

    NOTE(review): unlike the earlier definition, this one calls
    ``make_predictions`` rather than ``make_predictions_gpu``. It also does
    not call ``shot_list.set_weights``. Confirm which behavior is intended
    and remove the duplicate definition.

    Returns:
        Tuple ``(y_prime, y_gold, disruptive, roc_area, loss)``.
    """
    y_prime, y_gold, disruptive = make_predictions(conf, shot_list, loader,
                                                   custom_path)
    analyzer = PerformanceAnalyzer(conf=conf)
    roc_area = analyzer.get_roc_area(y_prime, y_gold, disruptive)
    loss = get_loss_from_list(y_prime, y_gold, conf['data']['target'])
    return y_prime, y_gold, disruptive, roc_area, loss