def main(
        out_dir=None, data_dir=None, cv_data_dir=None, score_dirs=[],
        fusion_method='sum', prune_imu=None, standardize=None, decode=None,
        plot_predictions=None, results_file=None, sweep_param_name=None,
        gpu_dev_id=None, model_params={}, cv_params={}, train_params={},
        viz_params={}):
    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    score_dirs = tuple(map(os.path.expanduser, score_dirs))
    if cv_data_dir is not None:
        cv_data_dir = os.path.expanduser(cv_data_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)
        return tuple(map(loadOne, seq_ids))

    device = torchutils.selectDevice(gpu_dev_id)

    # Load data: keep only trials present in every score dir and in data_dir
    dir_trial_ids = tuple(
        set(utils.getUniqueIds(d, prefix='trial=', to_array=True))
        for d in score_dirs)
    dir_trial_ids += (set(utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)),)
    trial_ids = np.array(sorted(set.intersection(*dir_trial_ids)))

    for dir_name, t_ids in zip(score_dirs + (data_dir,), dir_trial_ids):
        logger.info(f"{len(t_ids)} trial ids from {dir_name}:")
        logger.info(f"  {t_ids}")
    logger.info(f"{len(trial_ids)} trials in intersection: {trial_ids}")

    assembly_seqs = loadAll(trial_ids, 'assembly-seq.pkl', data_dir)
    feature_seqs = tuple(loadAll(trial_ids, 'data-scores.pkl', d) for d in score_dirs)
    feature_seqs = tuple(zip(*feature_seqs))

    # Combine feature seqs, excluding trials whose per-modality score arrays
    # disagree in shape
    include_indices = []
    for i, seq_feats in enumerate(feature_seqs):
        feat_shapes = tuple(f.shape for f in seq_feats)
        include_seq = all(f == feat_shapes[0] for f in feat_shapes)
        if include_seq:
            include_indices.append(i)
        else:
            warn_str = (
                f'Excluding trial {trial_ids[i]} with mismatched feature shapes: '
                f'{feat_shapes}')
            logger.warning(warn_str)

    trial_ids = trial_ids[include_indices]
    assembly_seqs = tuple(assembly_seqs[i] for i in include_indices)
    feature_seqs = tuple(feature_seqs[i] for i in include_indices)
    feature_seqs = tuple(np.stack(f) for f in feature_seqs)

    # Define cross-validation folds
    if cv_data_dir is None:
        dataset_size = len(trial_ids)
        cv_folds = utils.makeDataSplits(dataset_size, **cv_params)
        cv_fold_trial_ids = tuple(
            tuple(map(lambda x: trial_ids[x], splits)) for splits in cv_folds)
    else:
        fn = os.path.join(cv_data_dir, 'cv-fold-trial-ids.pkl')
        cv_fold_trial_ids = joblib.load(fn)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, assembly_seqs, trial_ids))
        return split_data

    gt_scores = []
    all_scores = []

    num_keyframes_total = 0
    num_rgb_errors_total = 0
    num_correctable_errors_total = 0
    num_oov_total = 0
    num_changed_total = 0

    for cv_index, (train_ids, test_ids) in enumerate(cv_fold_trial_ids):
        try:
            test_idxs = np.array([trial_ids.tolist().index(i) for i in test_ids])
        except ValueError:
            logger.info(f"  Skipping fold {cv_index}: missing test data {test_ids}")
            continue

        logger.info(
            f'CV fold {cv_index + 1}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(test_ids)} test)')

        # TRAIN PHASE
        if cv_data_dir is None:
            train_idxs = np.array([trial_ids.tolist().index(i) for i in train_ids])
            train_assembly_seqs = tuple(assembly_seqs[i] for i in train_idxs)
            train_assemblies = []
            for seq in train_assembly_seqs:
                # gen_eq_classes extends train_assemblies as a side effect
                list(labels.gen_eq_classes(seq, train_assemblies, equivalent=None))
            model = None
        else:
            fn = f'cvfold={cv_index}_train-assemblies.pkl'
            train_assemblies = joblib.load(os.path.join(cv_data_dir, fn))
            train_idxs = [i for i in range(len(trial_ids)) if i not in test_idxs]
            fn = f'cvfold={cv_index}_model.pkl'
            model = joblib.load(os.path.join(cv_data_dir, fn))

        train_features, train_assembly_seqs, train_ids = getSplit(train_idxs)

        if False:  # Fusion-model training is disabled; scores are fused directly below
            train_labels = tuple(
                np.array(list(
                    labels.gen_eq_classes(assembly_seq, train_assemblies, equivalent=None)))
                for assembly_seq in train_assembly_seqs)

            train_set = torchutils.SequenceDataset(
                train_features, train_labels, seq_ids=train_ids, device=device)
            train_loader = torch.utils.data.DataLoader(
                train_set, batch_size=1, shuffle=True)

            train_epoch_log = collections.defaultdict(list)
            # val_epoch_log = collections.defaultdict(list)
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy()
            }

            # Construct the model before building its optimizer
            model = FusionClassifier(num_sources=train_features[0].shape[0])
            criterion = torch.nn.CrossEntropyLoss()
            optimizer_ft = torch.optim.Adam(
                model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08,
                weight_decay=0, amsgrad=False)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer_ft, step_size=1, gamma=1.00)

            model, last_model_wts = torchutils.trainModel(
                model, criterion, optimizer_ft, lr_scheduler, train_loader,
                # val_loader,
                device=device,
                metrics=metric_dict,
                train_epoch_log=train_epoch_log,
                # val_epoch_log=val_epoch_log,
                **train_params)

            torchutils.plotEpochLog(
                train_epoch_log, subfig_size=(10, 2.5),
                title='Training performance',
                fn=os.path.join(fig_dir, f'cvfold={cv_index}_train-plot.png'))

        # Enumerate test-set assemblies so OOV ground-truth states get indices
        test_assemblies = train_assemblies.copy()
        for feature_seq, gt_assembly_seq, trial_id in zip(*getSplit(test_idxs)):
            gt_seq = np.array(list(
                labels.gen_eq_classes(gt_assembly_seq, test_assemblies, equivalent=None)))

        if plot_predictions:
            assembly_fig_dir = os.path.join(fig_dir, 'assembly-imgs')
            if not os.path.exists(assembly_fig_dir):
                os.makedirs(assembly_fig_dir)
            for i, assembly in enumerate(test_assemblies):
                assembly.draw(assembly_fig_dir, i)

        # TEST PHASE
        accuracies = []
        for feature_seq, gt_assembly_seq, trial_id in zip(*getSplit(test_idxs)):
            gt_seq = np.array(list(
                labels.gen_eq_classes(gt_assembly_seq, test_assemblies, equivalent=None)))

            num_labels = gt_seq.shape[0]
            num_features = feature_seq.shape[-1]
            if num_labels != num_features:
                err_str = (
                    f"Skipping trial {trial_id}: "
                    f"{num_labels} labels != {num_features} features")
                logger.info(err_str)
                continue

            # Ignore OOV states in ground-truth
            sample_idxs = np.arange(feature_seq.shape[-1])
            score_idxs = gt_seq[gt_seq < feature_seq.shape[1]]
            sample_idxs = sample_idxs[gt_seq < feature_seq.shape[1]]
            gt_scores.append(feature_seq[:, score_idxs, sample_idxs])
            all_scores.append(feature_seq.reshape(feature_seq.shape[0], -1))

            if fusion_method == 'sum':
                score_seq = feature_seq.sum(axis=0)
            elif fusion_method == 'rgb_only':
                score_seq = feature_seq[1]
            elif fusion_method == 'imu_only':
                score_seq = feature_seq[0]
            else:
                raise NotImplementedError()

            if not decode:
                model = None

            if model is None:
                pred_seq = score_seq.argmax(axis=0)
            elif isinstance(model, torch.nn.Module):
                inputs = torch.tensor(
                    feature_seq[None, ...], dtype=torch.float, device=device)
                outputs = model.forward(inputs)
                pred_seq = model.predict(outputs)[0].cpu().numpy()
            else:
                dummy_samples = np.arange(score_seq.shape[1])
                pred_seq, _, _, _ = model.viterbi(
                    dummy_samples, log_likelihoods=score_seq,
                    ml_decode=(not decode))

            pred_assemblies = [train_assemblies[i] for i in pred_seq]
            gt_assemblies = [test_assemblies[i] for i in gt_seq]
            acc = metrics.accuracy_upto(
                pred_assemblies, gt_assemblies, equivalence=None)
            accuracies.append(acc)

            # Error analysis: how many RGB errors could the IMU scores correct?
            rgb_pred_seq = feature_seq[1].argmax(axis=0)
            num_changed = np.sum(rgb_pred_seq != pred_seq)
            rgb_is_wrong = rgb_pred_seq != gt_seq
            num_rgb_errors = np.sum(rgb_is_wrong)

            imu_scores = feature_seq[0]
            imu_scores_gt = np.array([
                imu_scores[s_idx, t] if s_idx < imu_scores.shape[0] else -np.inf
                for t, s_idx in enumerate(gt_seq)
            ])
            imu_scores_rgb = np.array([
                imu_scores[s_idx, t] if s_idx < imu_scores.shape[0] else -np.inf
                for t, s_idx in enumerate(rgb_pred_seq)
            ])
            # imu_scores_gt = imu_scores[gt_seq, range(len(rgb_pred_seq))]
            best_imu_scores = imu_scores.max(axis=0)
            imu_is_right = imu_scores_gt >= best_imu_scores
            rgb_pred_score_is_lower = imu_scores_gt > imu_scores_rgb
            is_correctable_error = rgb_is_wrong & imu_is_right & rgb_pred_score_is_lower
            num_correctable_errors = np.sum(is_correctable_error)
            prop_correctable = num_correctable_errors / num_rgb_errors

            num_oov = np.sum(gt_seq >= len(train_assemblies))
            num_states = len(gt_seq)

            num_keyframes_total += num_states
            num_rgb_errors_total += num_rgb_errors
            num_correctable_errors_total += num_correctable_errors
            num_oov_total += num_oov
            num_changed_total += num_changed

            logger.info(f"  trial {trial_id}: {num_states} keyframes")
            logger.info(f"    accuracy (fused): {acc * 100:.1f}%")
            logger.info(f"    {num_oov} OOV states ({num_oov / num_states * 100:.1f}%)")
            logger.info(
                f"    {num_rgb_errors} RGB errors; "
                f"{num_correctable_errors} correctable from IMU "
                f"({prop_correctable * 100:.1f}%)")

            saveVariable(score_seq, f'trial={trial_id}_data-scores')
            saveVariable(pred_assemblies, f'trial={trial_id}_pred-assembly-seq')
            saveVariable(gt_assemblies, f'trial={trial_id}_gt-assembly-seq')

            if plot_predictions:
                io_figs_dir = os.path.join(fig_dir, 'system-io')
                if not os.path.exists(io_figs_dir):
                    os.makedirs(io_figs_dir)
                fn = os.path.join(io_figs_dir, f'trial={trial_id:03}.png')
                utils.plot_array(
                    feature_seq, (gt_seq, pred_seq, score_seq),
                    ('gt', 'pred', 'scores'), fn=fn)

                score_figs_dir = os.path.join(fig_dir, 'modality-scores')
                if not os.path.exists(score_figs_dir):
                    os.makedirs(score_figs_dir)
                plot_scores(
                    feature_seq, k=25,
                    fn=os.path.join(score_figs_dir, f"trial={trial_id:03}.png"))

                paths_dir = os.path.join(fig_dir, 'path-imgs')
                if not os.path.exists(paths_dir):
                    os.makedirs(paths_dir)
                assemblystats.drawPath(
                    pred_seq, trial_id, f"trial={trial_id}_pred-seq",
                    paths_dir, assembly_fig_dir)
                assemblystats.drawPath(
                    gt_seq, trial_id, f"trial={trial_id}_gt-seq",
                    paths_dir, assembly_fig_dir)

                label_seqs = (gt_seq,) + tuple(
                    scores.argmax(axis=0) for scores in feature_seq)
                label_seqs = np.row_stack(label_seqs)
                k = 10
                for i, scores in enumerate(feature_seq):
                    label_score_seqs = tuple(
                        np.array([
                            scores[s_idx, t] if s_idx < scores.shape[0] else -np.inf
                            for t, s_idx in enumerate(label_seq)
                        ])
                        for label_seq in label_seqs)
                    label_score_seqs = np.row_stack(label_score_seqs)
                    drawPaths(
                        label_seqs, f"trial={trial_id}_pred-scores_modality={i}",
                        paths_dir, assembly_fig_dir, path_scores=label_score_seqs)

                    topk_seq = (-scores).argsort(axis=0)[:k, :]
                    path_scores = np.column_stack(tuple(
                        scores[idxs, i] for i, idxs in enumerate(topk_seq.T)))
                    drawPaths(
                        topk_seq, f"trial={trial_id}_topk_modality={i}",
                        paths_dir, assembly_fig_dir, path_scores=path_scores)

                label_score_seqs = tuple(
                    np.array([
                        score_seq[s_idx, t] if s_idx < score_seq.shape[0] else -np.inf
                        for t, s_idx in enumerate(label_seq)
                    ])
                    for label_seq in label_seqs)
                label_score_seqs = np.row_stack(label_score_seqs)
                drawPaths(
                    label_seqs, f"trial={trial_id}_pred-scores_fused",
                    paths_dir, assembly_fig_dir, path_scores=label_score_seqs)

                topk_seq = (-score_seq).argsort(axis=0)[:k, :]
                path_scores = np.column_stack(tuple(
                    score_seq[idxs, i] for i, idxs in enumerate(topk_seq.T)))
                drawPaths(
                    topk_seq, f"trial={trial_id}_topk_fused",
                    paths_dir, assembly_fig_dir, path_scores=path_scores)

        if accuracies:
            fold_accuracy = float(np.array(accuracies).mean())
            # logger.info(f'  acc: {fold_accuracy * 100:.1f}%')
            metric_dict = {'Accuracy': fold_accuracy}
            utils.writeResults(
                results_file, metric_dict, sweep_param_name, model_params)

    num_unexplained_errors = num_rgb_errors_total - (
        num_oov_total + num_correctable_errors_total)
    prop_correctable = num_correctable_errors_total / num_rgb_errors_total
    prop_oov = num_oov_total / num_rgb_errors_total
    prop_unexplained = num_unexplained_errors / num_rgb_errors_total
    prop_changed = num_changed_total / num_keyframes_total

    logger.info("PERFORMANCE ANALYSIS")
    logger.info(
        f"  {num_rgb_errors_total} / {num_keyframes_total} "
        f"RGB errors ({num_rgb_errors_total / num_keyframes_total * 100:.1f}%)")
    logger.info(
        f"  {num_oov_total} / {num_rgb_errors_total} "
        f"RGB errors are OOV ({prop_oov * 100:.1f}%)")
    logger.info(
        f"  {num_correctable_errors_total} / {num_rgb_errors_total} "
        f"RGB errors are correctable ({prop_correctable * 100:.1f}%)")
    logger.info(
        f"  {num_unexplained_errors} / {num_rgb_errors_total} "
        f"RGB errors are unexplained ({prop_unexplained * 100:.1f}%)")
    logger.info(
        f"  {num_changed_total} / {num_keyframes_total} "
        f"Predictions changed after fusion ({prop_changed * 100:.1f}%)")

    gt_scores = np.hstack(tuple(gt_scores))
    plot_hists(np.exp(gt_scores), fn=os.path.join(fig_dir, "score-hists_gt.png"))
    all_scores = np.hstack(tuple(all_scores))
    plot_hists(np.exp(all_scores), fn=os.path.join(fig_dir, "score-hists_all.png"))
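
# Illustrative sketch (not part of the original script): a minimal,
# self-contained example of the 'sum' fusion rule used in the test phase
# above. Per-modality log-score arrays of shape (num_states, num_samples) are
# summed elementwise, and the fused prediction is the frame-wise argmax over
# states (the `model is None` decoding path). All values are synthetic.
def _fusion_sketch():
    import numpy as np

    rng = np.random.default_rng(0)
    # Two modalities' log-scores over 4 states at 6 samples (synthetic)
    imu_scores = np.log(rng.dirichlet(np.ones(4), size=6).T)
    rgb_scores = np.log(rng.dirichlet(np.ones(4), size=6).T)

    fused = imu_scores + rgb_scores   # 'sum' fusion in log space
    pred_seq = fused.argmax(axis=0)   # frame-wise ML decoding
    assert pred_seq.shape == (6,)
    return pred_seq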
def main(
        out_dir=None, data_dir=None, attr_dir=None, model_name=None,
        gpu_dev_id=None, batch_size=None, learning_rate=None,
        model_params={}, cv_params={}, train_params={}, viz_params={},
        plot_predictions=None, results_file=None, sweep_param_name=None):
    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    attr_dir = os.path.expanduser(attr_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def loadData(seq_id):
        var_name = f"trial-{seq_id}_rgb-frame-seq"
        data = joblib.load(os.path.join(data_dir, f'{var_name}.pkl'))
        # Convert frames from channel-last to channel-first layout
        return data.swapaxes(1, 3)

    def loadLabels(seq_id):
        var_name = f"trial-{seq_id}_label-seq"
        return joblib.load(os.path.join(attr_dir, f'{var_name}.pkl'))

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    # Load data
    trial_ids = utils.getUniqueIds(data_dir)
    label_seqs = tuple(map(loadLabels, trial_ids))

    device = torchutils.selectDevice(gpu_dev_id)

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (label_seqs, trial_ids))
        return split_data

    for cv_index, cv_splits in enumerate(cv_folds):
        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        criterion = torch.nn.BCEWithLogitsLoss()
        labels_dtype = torch.float

        train_labels, train_ids = train_data
        train_set = torchutils.PickledVideoDataset(
            loadData, train_labels, device=device, labels_dtype=labels_dtype,
            seq_ids=train_ids, batch_size=batch_size)
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=1, shuffle=True)

        test_labels, test_ids = test_data
        test_set = torchutils.PickledVideoDataset(
            loadData, test_labels, device=device, labels_dtype=labels_dtype,
            seq_ids=test_ids, batch_size=batch_size)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=1, shuffle=False)

        val_labels, val_ids = val_data
        val_set = torchutils.PickledVideoDataset(
            loadData, val_labels, device=device, labels_dtype=labels_dtype,
            seq_ids=val_ids, batch_size=batch_size)
        val_loader = torch.utils.data.DataLoader(
            val_set, batch_size=1, shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)')

        if model_name == 'resnet':
            # input_dim = train_set.num_obsv_dims
            output_dim = train_set.num_label_types
            model = ImageClassifier(output_dim, **model_params).to(device=device)
        else:
            raise AssertionError()

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }

        optimizer_ft = torch.optim.Adam(
            model.parameters(), lr=learning_rate, betas=(0.9, 0.999),
            eps=1e-08, weight_decay=0, amsgrad=False)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer_ft, step_size=1, gamma=1.00)

        model, last_model_wts = torchutils.trainModel(
            model, criterion, optimizer_ft, lr_scheduler,
            train_loader, val_loader,
            device=device, metrics=metric_dict,
            train_epoch_log=train_epoch_log, val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }
        test_io_history = torchutils.predictSamples(
            model.to(device=device), test_loader,
            criterion=criterion, device=device,
            metrics=metric_dict, data_labeled=True, update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)

        metric_str = ' '.join(str(m) for m in metric_dict.values())
        logger.info('[TST] ' + metric_str)

        utils.writeResults(results_file, metric_dict, sweep_param_name, model_params)

        if plot_predictions:
            # imu.plot_prediction_eg(test_io_history, fig_dir, fig_type=fig_type, **viz_params)
            imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

        def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
            saveVariable(pred_seq.cpu().numpy(), f'trial={trial_id}_pred-label-seq')
            saveVariable(score_seq.cpu().numpy(), f'trial={trial_id}_score-seq')
            saveVariable(label_seq.cpu().numpy(), f'trial={trial_id}_true-label-seq')
        for io in test_io_history:
            saveTrialData(*io)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(train_epoch_log, f'cvfold={cv_index}_{model_name}-train-epoch-log')
        saveVariable(val_epoch_log, f'cvfold={cv_index}_{model_name}-val-epoch-log')
        saveVariable(metric_dict, f'cvfold={cv_index}_{model_name}-metric-dict')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        model.load_state_dict(last_model_wts)
        saveVariable(model, f'cvfold={cv_index}_{model_name}-last')

        torchutils.plotEpochLog(
            train_epoch_log, subfig_size=(10, 2.5),
            title='Training performance',
            fn=os.path.join(fig_dir, f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(
                val_epoch_log, subfig_size=(10, 2.5),
                title='Heldout performance',
                fn=os.path.join(fig_dir, f'cvfold={cv_index}_val-plot.png'))
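
# Illustrative sketch (not part of the original script): what loadData's
# swapaxes(1, 3) accomplishes, assuming the pickled RGB frame sequences are
# stored channel-last as (time, height, width, channel). PyTorch convolutions
# expect channel-first input, and swapping axes 1 and 3 yields (T, C, W, H).
# The shapes below are hypothetical.
def _frame_layout_sketch():
    import numpy as np

    frames = np.zeros((10, 240, 320, 3))   # (T, H, W, C), assumed storage layout
    chan_first = frames.swapaxes(1, 3)     # (T, C, W, H), as in loadData above
    assert chan_first.shape == (10, 3, 320, 240)
    return chan_first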
def main(
        out_dir=None, data_dir=None, model_name=None, pretrained_model_dir=None,
        gpu_dev_id=None, batch_size=None, learning_rate=None,
        independent_signals=None, active_only=None,
        model_params={}, cv_params={}, train_params={}, viz_params={},
        plot_predictions=None, results_file=None, sweep_param_name=None,
        label_mapping=None, eval_label_mapping=None):
    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)
        return tuple(map(loadOne, seq_ids))

    # Load data
    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)
    feature_seqs = loadAll(trial_ids, 'feature-seq.pkl', data_dir)
    label_seqs = loadAll(trial_ids, 'label-seq.pkl', data_dir)

    device = torchutils.selectDevice(gpu_dev_id)

    if label_mapping is not None:
        def map_labels(labels):
            for i, j in label_mapping.items():
                labels[labels == i] = j
            return labels
        label_seqs = tuple(map(map_labels, label_seqs))

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, label_seqs, trial_ids))
        return split_data

    for cv_index, cv_splits in enumerate(cv_folds):
        if pretrained_model_dir is not None:
            def loadFromPretrain(fn):
                return joblib.load(os.path.join(pretrained_model_dir, f"{fn}.pkl"))

            model = loadFromPretrain(f'cvfold={cv_index}_{model_name}-best')
            train_ids = loadFromPretrain(f'cvfold={cv_index}_train-ids')
            val_ids = loadFromPretrain(f'cvfold={cv_index}_val-ids')
            test_ids = tuple(i for i in trial_ids if i not in (train_ids + val_ids))
            test_idxs = tuple(trial_ids.tolist().index(i) for i in test_ids)
            test_data = getSplit(test_idxs)

            if independent_signals:
                criterion = torch.nn.CrossEntropyLoss()
                labels_dtype = torch.long
                test_data = splitSeqs(*test_data, active_only=False)
            else:
                # FIXME
                # criterion = torch.nn.BCEWithLogitsLoss()
                # labels_dtype = torch.float
                criterion = torch.nn.CrossEntropyLoss()
                labels_dtype = torch.long

            test_feats, test_labels, test_ids = test_data
            test_set = torchutils.SequenceDataset(
                test_feats, test_labels, device=device, labels_dtype=labels_dtype,
                seq_ids=test_ids, transpose_data=True)
            test_loader = torch.utils.data.DataLoader(
                test_set, batch_size=batch_size, shuffle=False)

            # Test model
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy(),
                'Precision': metrics.Precision(),
                'Recall': metrics.Recall(),
                'F1': metrics.Fmeasure()
            }
            test_io_history = torchutils.predictSamples(
                model.to(device=device), test_loader,
                criterion=criterion, device=device,
                metrics=metric_dict, data_labeled=True, update_model=False,
                seq_as_batch=train_params['seq_as_batch'],
                return_io_history=True)
            if independent_signals:
                test_io_history = tuple(joinSeqs(test_io_history))

            metric_str = ' '.join(str(m) for m in metric_dict.values())
            logger.info('[TST] ' + metric_str)

            d = {k: v.value for k, v in metric_dict.items()}
            utils.writeResults(results_file, d, sweep_param_name, model_params)

            if plot_predictions:
                imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

            def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
                if label_mapping is not None:
                    def dup_score_cols(scores):
                        num_cols = scores.shape[-1] + len(label_mapping)
                        col_idxs = torch.arange(num_cols)
                        for i, j in label_mapping.items():
                            col_idxs[i] = j
                        return scores[..., col_idxs]
                    score_seq = dup_score_cols(score_seq)
                saveVariable(pred_seq.cpu().numpy(), f'trial={trial_id}_pred-label-seq')
                saveVariable(score_seq.cpu().numpy(), f'trial={trial_id}_score-seq')
                saveVariable(label_seq.cpu().numpy(), f'trial={trial_id}_true-label-seq')
            for io in test_io_history:
                saveTrialData(*io)

            # Evaluation-only pass: skip the training phase below
            continue

        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        if independent_signals:
            criterion = torch.nn.CrossEntropyLoss()
            labels_dtype = torch.long
            split_ = functools.partial(splitSeqs, active_only=active_only)
            train_data = split_(*train_data)
            val_data = split_(*val_data)
            test_data = splitSeqs(*test_data, active_only=False)
        else:
            # FIXME
            # criterion = torch.nn.BCEWithLogitsLoss()
            # labels_dtype = torch.float
            criterion = torch.nn.CrossEntropyLoss()
            labels_dtype = torch.long

        train_feats, train_labels, train_ids = train_data
        train_set = torchutils.SequenceDataset(
            train_feats, train_labels, device=device, labels_dtype=labels_dtype,
            seq_ids=train_ids, transpose_data=True)
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True)

        test_feats, test_labels, test_ids = test_data
        test_set = torchutils.SequenceDataset(
            test_feats, test_labels, device=device, labels_dtype=labels_dtype,
            seq_ids=test_ids, transpose_data=True)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=batch_size, shuffle=False)

        val_feats, val_labels, val_ids = val_data
        val_set = torchutils.SequenceDataset(
            val_feats, val_labels, device=device, labels_dtype=labels_dtype,
            seq_ids=val_ids, transpose_data=True)
        val_loader = torch.utils.data.DataLoader(
            val_set, batch_size=batch_size, shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)')

        input_dim = train_set.num_obsv_dims
        output_dim = train_set.num_label_types
        if model_name == 'linear':
            model = torchutils.LinearClassifier(
                input_dim, output_dim, **model_params).to(device=device)
        elif model_name == 'conv':
            model = ConvClassifier(
                input_dim, output_dim, **model_params).to(device=device)
        elif model_name == 'TCN':
            model = TcnClassifier(input_dim, output_dim, **model_params)
        else:
            raise AssertionError()

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }

        optimizer_ft = torch.optim.Adam(
            model.parameters(), lr=learning_rate, betas=(0.9, 0.999),
            eps=1e-08, weight_decay=0, amsgrad=False)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer_ft, step_size=1, gamma=1.00)

        model, last_model_wts = torchutils.trainModel(
            model, criterion, optimizer_ft, lr_scheduler,
            train_loader, val_loader,
            device=device, metrics=metric_dict,
            train_epoch_log=train_epoch_log, val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }
        test_io_history = torchutils.predictSamples(
            model.to(device=device), test_loader,
            criterion=criterion, device=device,
            metrics=metric_dict, data_labeled=True, update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)
        if independent_signals:
            test_io_history = tuple(joinSeqs(test_io_history))

        metric_str = ' '.join(str(m) for m in metric_dict.values())
        logger.info('[TST] ' + metric_str)

        d = {k: v.value for k, v in metric_dict.items()}
        utils.writeResults(results_file, d, sweep_param_name, model_params)

        if plot_predictions:
            # imu.plot_prediction_eg(test_io_history, fig_dir, fig_type=fig_type, **viz_params)
            imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

        def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
            if label_mapping is not None:
                def dup_score_cols(scores):
                    num_cols = scores.shape[-1] + len(label_mapping)
                    col_idxs = torch.arange(num_cols)
                    for i, j in label_mapping.items():
                        col_idxs[i] = j
                    return scores[..., col_idxs]
                score_seq = dup_score_cols(score_seq)
            saveVariable(pred_seq.cpu().numpy(), f'trial={trial_id}_pred-label-seq')
            saveVariable(score_seq.cpu().numpy(), f'trial={trial_id}_score-seq')
            saveVariable(label_seq.cpu().numpy(), f'trial={trial_id}_true-label-seq')
        for io in test_io_history:
            saveTrialData(*io)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(train_epoch_log, f'cvfold={cv_index}_{model_name}-train-epoch-log')
        saveVariable(val_epoch_log, f'cvfold={cv_index}_{model_name}-val-epoch-log')
        saveVariable(metric_dict, f'cvfold={cv_index}_{model_name}-metric-dict')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        model.load_state_dict(last_model_wts)
        saveVariable(model, f'cvfold={cv_index}_{model_name}-last')

        torchutils.plotEpochLog(
            train_epoch_log, subfig_size=(10, 2.5),
            title='Training performance',
            fn=os.path.join(fig_dir, f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(
                val_epoch_log, subfig_size=(10, 2.5),
                title='Heldout performance',
                fn=os.path.join(fig_dir, f'cvfold={cv_index}_val-plot.png'))

        if eval_label_mapping is not None:
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy(),
                'Precision': metrics.Precision(),
                'Recall': metrics.Recall(),
                'F1': metrics.Fmeasure()
            }
            test_io_history = torchutils.predictSamples(
                model.to(device=device), test_loader,
                criterion=criterion, device=device,
                metrics=metric_dict, data_labeled=True, update_model=False,
                seq_as_batch=train_params['seq_as_batch'],
                return_io_history=True,
                label_mapping=eval_label_mapping)
            if independent_signals:
                test_io_history = joinSeqs(test_io_history)
            metric_str = ' '.join(str(m) for m in metric_dict.values())
            logger.info('[TST] ' + metric_str)
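
# Illustrative sketch (not part of the original script) of the label_mapping
# machinery above: map_labels collapses label ids in the ground truth, and
# saveTrialData's dup_score_cols re-expands saved score arrays so their
# columns line up with the original label vocabulary, by copying each
# surviving column back into the slot of the class it absorbed. The mapping
# {3: 1} and all array values are hypothetical.
def _label_mapping_sketch():
    import torch

    label_mapping = {3: 1}                 # hypothetical: relabel class 3 as class 1
    labels = torch.tensor([0, 3, 2, 3])
    for i, j in label_mapping.items():
        labels[labels == i] = j            # -> tensor([0, 1, 2, 1])

    scores = torch.zeros(4, 3)             # (T, num_collapsed_classes)
    num_cols = scores.shape[-1] + len(label_mapping)
    col_idxs = torch.arange(num_cols)
    for i, j in label_mapping.items():
        col_idxs[i] = j                    # column 3 becomes a copy of column 1
    expanded = scores[..., col_idxs]       # (T, 4), matching the original vocab
    assert expanded.shape == (4, 4)
    return labels, expanded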
def main(
        out_dir=None, gpu_dev_id=None, num_samples=10, random_seed=None,
        learning_rate=1e-3, num_epochs=500,
        dataset_kwargs={}, dataloader_kwargs={}, model_kwargs={}):
    if out_dir is None:
        out_dir = os.path.join('~', 'data', 'output', 'seqtools', 'test_gtn')
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    vocabulary = ['a', 'b', 'c', 'd', 'e']
    transition = np.array(
        [[0, 1, 0, 0, 0],
         [0, 0, 1, 1, 0],
         [0, 0, 0, 0, 1],
         [0, 1, 0, 0, 1],
         [0, 0, 0, 0, 0]],
        dtype=float)
    initial = np.array([1, 0, 1, 0, 0], dtype=float)
    final = np.array([0, 1, 0, 0, 1], dtype=float) / 10

    seq_params = (transition, initial, final)
    simulated_dataset = simulate(num_samples, *seq_params)
    label_seqs, obsv_seqs = tuple(zip(*simulated_dataset))
    seq_params = tuple(map(lambda x: -np.log(x), seq_params))

    dataset = torchutils.SequenceDataset(obsv_seqs, label_seqs, **dataset_kwargs)
    data_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

    train_loader = data_loader
    val_loader = data_loader

    transition_weights = torch.tensor(transition, dtype=torch.float).log()
    initial_weights = torch.tensor(initial, dtype=torch.float).log()
    final_weights = torch.tensor(final, dtype=torch.float).log()

    model = libfst.LatticeCrf(
        vocabulary,
        transition_weights=transition_weights,
        initial_weights=initial_weights,
        final_weights=final_weights,
        debug_output_dir=fig_dir,
        **model_kwargs)

    gtn.draw(
        model._transition_fst, os.path.join(fig_dir, 'transitions-init.png'),
        isymbols=model._arc_symbols, osymbols=model._arc_symbols)
    gtn.draw(
        model._duration_fst, os.path.join(fig_dir, 'durations-init.png'),
        isymbols=model._arc_symbols, osymbols=model._arc_symbols)

    if True:  # Diagnostic pass: inspect the loss FSTs before training
        for i, (inputs, targets, seq_id) in enumerate(train_loader):
            arc_scores = model.scores_to_arc(inputs)
            arc_labels = model.labels_to_arc(targets)

            batch_size, num_samples, num_classes = arc_scores.shape

            obs_fst = libfst.linearFstFromArray(
                arc_scores[0].reshape(num_samples, -1))
            gt_fst = libfst.fromSequence(arc_labels[0])

            d1_fst = gtn.compose(obs_fst, model._duration_fst)
            d1_fst = gtn.project_output(d1_fst)
            denom_fst = gtn.compose(d1_fst, model._transition_fst)
            # denom_fst = gtn.project_output(denom_fst)
            num_fst = gtn.compose(denom_fst, gt_fst)
            viterbi_fst = gtn.viterbi_path(denom_fst)
            pred_fst = gtn.remove(gtn.project_output(viterbi_fst))

            loss = gtn.subtract(
                gtn.forward_score(num_fst), gtn.forward_score(denom_fst))
            loss = torch.tensor(loss.item())

            if torch.isinf(loss).any():
                # Loss diverged: draw every intermediate FST, then drop into
                # the debugger for inspection
                denom_alt = gtn.compose(obs_fst, model._transition_fst)
                d1_min = gtn.remove(gtn.project_output(d1_fst))
                denom_alt = gtn.compose(d1_min, model._transition_fst)
                num_alt = gtn.compose(denom_alt, gt_fst)
                gtn.draw(
                    obs_fst, os.path.join(fig_dir, 'observations-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    gt_fst, os.path.join(fig_dir, 'labels-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    d1_fst, os.path.join(fig_dir, 'd1-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    d1_min, os.path.join(fig_dir, 'd1-min-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    denom_fst, os.path.join(fig_dir, 'denominator-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    denom_alt, os.path.join(fig_dir, 'denominator-alt-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    num_fst, os.path.join(fig_dir, 'numerator-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    num_alt, os.path.join(fig_dir, 'numerator-alt-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    viterbi_fst, os.path.join(fig_dir, 'viterbi-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    pred_fst, os.path.join(fig_dir, 'pred-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                import pdb
                pdb.set_trace()

    # Train the model
    train_epoch_log = collections.defaultdict(list)
    val_epoch_log = collections.defaultdict(list)
    metric_dict = {
        'Avg Loss': metrics.AverageLoss(),
        'Accuracy': metrics.Accuracy()
    }

    criterion = model.nllLoss
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1.00)

    model, last_model_wts = torchutils.trainModel(
        model, criterion, optimizer, scheduler, train_loader, val_loader,
        metrics=metric_dict, test_metric='Avg Loss',
        train_epoch_log=train_epoch_log, val_epoch_log=val_epoch_log,
        num_epochs=num_epochs)

    gtn.draw(
        model._transition_fst, os.path.join(fig_dir, 'transitions-trained.png'),
        isymbols=model._arc_symbols, osymbols=model._arc_symbols)
    gtn.draw(
        model._duration_fst, os.path.join(fig_dir, 'durations-trained.png'),
        isymbols=model._arc_symbols, osymbols=model._arc_symbols)

    torchutils.plotEpochLog(
        train_epoch_log, title="Train Epoch Log",
        fn=os.path.join(fig_dir, "train-log.png"))
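
# Illustrative sketch (not part of the original script) of the arithmetic
# behind `loss = gtn.subtract(gtn.forward_score(num_fst),
# gtn.forward_score(denom_fst))` above: each forward score is a log-sum-exp
# over all paths through an FST, so the difference is the log conditional
# likelihood of the ground-truth path. The toy version below reproduces this
# with plain numpy on a hypothetical 2-state, 3-frame lattice, standing in for
# the gtn.compose / gtn.forward_score pipeline.
def _crf_loss_sketch():
    import numpy as np

    def lse(x, axis=None):
        # log-sum-exp; fine numerically for these toy values
        return np.log(np.exp(x).sum(axis=axis))

    obs = np.log(np.array([[0.6, 0.4], [0.3, 0.7], [0.5, 0.5]]))  # (T, S) emission log-probs
    trans = np.log(np.full((2, 2), 0.5))                          # uniform transition log-probs

    # Denominator: forward algorithm, summing over all state paths
    alpha = obs[0]
    for t in range(1, len(obs)):
        alpha = lse(alpha[:, None] + trans, axis=0) + obs[t]
    denom = lse(alpha)

    # Numerator: score of the single ground-truth path [0, 1, 1]
    gt = [0, 1, 1]
    num = obs[0, gt[0]] + sum(
        trans[gt[t - 1], gt[t]] + obs[t, gt[t]] for t in range(1, len(gt)))

    return num - denom  # log p(gt | obs); the script's loss is this difference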
def main(
        out_dir=None, data_dir=None, model_name=None, part_symmetries=None,
        gpu_dev_id=None, batch_size=None, learning_rate=None,
        model_params={}, cv_params={}, train_params={}, viz_params={},
        plot_predictions=None, results_file=None, sweep_param_name=None):
    if part_symmetries is None:
        part_symmetries = {
            'beam_side': (
                'backbeam_hole_1', 'backbeam_hole_2',
                'frontbeam_hole_1', 'frontbeam_hole_2'),
            'beam_top': ('backbeam_hole_3', 'frontbeam_hole_3'),
            'backrest': ('backrest_hole_1', 'backrest_hole_2')
        }

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)
        return tuple(map(loadOne, seq_ids))

    # Load vocab
    with open(os.path.join(data_dir, "part-vocab.yaml"), 'rt') as f:
        link_vocab = yaml.safe_load(f)
    assembly_vocab = joblib.load(os.path.join(data_dir, 'assembly-vocab.pkl'))

    # Load data
    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=')
    feature_seqs = loadAll(trial_ids, 'feature-seq.pkl', data_dir)
    label_seqs = loadAll(trial_ids, 'label-seq.pkl', data_dir)

    if part_symmetries:
        # Construct equivalence classes from vocab
        eq_classes, assembly_eq_classes, eq_class_vocab = makeEqClasses(
            assembly_vocab, part_symmetries)
        lib_assembly.writeAssemblies(
            os.path.join(fig_dir, 'eq-class-vocab.txt'), eq_class_vocab)
        label_seqs = tuple(
            assembly_eq_classes[label_seq] for label_seq in label_seqs)
        saveVariable(eq_class_vocab, 'assembly-vocab')
    else:
        eq_classes = None

    def impute_nan(input_seq):
        input_is_nan = np.isnan(input_seq)
        logger.info(f"{input_is_nan.sum()} NaN elements")
        input_seq[input_is_nan] = 0  # np.nanmean(input_seq)
        return input_seq
    # feature_seqs = tuple(map(impute_nan, feature_seqs))

    for trial_id, label_seq, feat_seq in zip(trial_ids, label_seqs, feature_seqs):
        saveVariable(feat_seq, f"trial={trial_id}_feature-seq")
        saveVariable(label_seq, f"trial={trial_id}_label-seq")

    device = torchutils.selectDevice(gpu_dev_id)

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, label_seqs, trial_ids))
        return split_data

    for cv_index, cv_splits in enumerate(cv_folds):
        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        train_feats, train_labels, train_ids = train_data
        train_set = torchutils.SequenceDataset(
            train_feats, train_labels, device=device, labels_dtype=torch.long,
            seq_ids=train_ids, transpose_data=True)
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True)

        test_feats, test_labels, test_ids = test_data
        test_set = torchutils.SequenceDataset(
            test_feats, test_labels, device=device, labels_dtype=torch.long,
            seq_ids=test_ids, transpose_data=True)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=batch_size, shuffle=False)

        val_feats, val_labels, val_ids = val_data
        val_set = torchutils.SequenceDataset(
            val_feats, val_labels, device=device, labels_dtype=torch.long,
            seq_ids=val_ids, transpose_data=True)
        val_loader = torch.utils.data.DataLoader(
            val_set, batch_size=batch_size, shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)')

        input_dim = train_set.num_obsv_dims
        output_dim = train_set.num_label_types
        if model_name == 'linear':
            model = torchutils.LinearClassifier(
                input_dim, output_dim, **model_params).to(device=device)
        elif model_name == 'dummy':
            model = DummyClassifier(input_dim, output_dim, **model_params)
        elif model_name == 'AssemblyClassifier':
            model = AssemblyClassifier(
                assembly_vocab, link_vocab, eq_classes=eq_classes,
                **model_params)
        else:
            raise AssertionError()

        criterion = torch.nn.CrossEntropyLoss()

        # Initialize the epoch logs outside the training branch so the dummy
        # model can still reach the saveVariable calls below
        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)

        if model_name != 'dummy':
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy(),
                'Precision': metrics.Precision(),
                'Recall': metrics.Recall(),
                'F1': metrics.Fmeasure()
            }

            optimizer_ft = torch.optim.Adam(
                model.parameters(), lr=learning_rate, betas=(0.9, 0.999),
                eps=1e-08, weight_decay=0, amsgrad=False)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer_ft, step_size=1, gamma=1.00)

            model, last_model_wts = torchutils.trainModel(
                model, criterion, optimizer_ft, lr_scheduler,
                train_loader, val_loader,
                device=device, metrics=metric_dict,
                train_epoch_log=train_epoch_log, val_epoch_log=val_epoch_log,
                **train_params)

            if model_name == 'AssemblyClassifier':
                # _scale and _alpha are AssemblyClassifier-specific parameters
                logger.info(f'scale={float(model._scale)}')
                logger.info(f'alpha={float(model._alpha)}')

        # Test model
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }
        test_io_history = torchutils.predictSamples(
            model.to(device=device), test_loader,
            criterion=criterion, device=device,
            metrics=metric_dict, data_labeled=True, update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)

        metric_str = ' '.join(str(m) for m in metric_dict.values())
        logger.info('[TST] ' + metric_str)

        d = {k: v.value for k, v in metric_dict.items()}
        utils.writeResults(results_file, d, sweep_param_name, model_params)

        if plot_predictions:
            io_fig_dir = os.path.join(fig_dir, 'model-io')
            if not os.path.exists(io_fig_dir):
                os.makedirs(io_fig_dir)

            label_names = ('gt', 'pred')
            for batch in test_io_history:
                batch = tuple(
                    x.cpu().numpy() if isinstance(x, torch.Tensor) else x
                    for x in batch)
                for preds, _, inputs, gt_labels, seq_id in zip(*batch):
                    fn = os.path.join(io_fig_dir, f"trial={seq_id}_model-io.png")
                    utils.plot_array(
                        inputs.sum(axis=-1), (gt_labels, preds), label_names,
                        fn=fn, **viz_params)

        def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
            saveVariable(pred_seq, f'trial={trial_id}_pred-label-seq')
            saveVariable(score_seq, f'trial={trial_id}_score-seq')
            saveVariable(label_seq, f'trial={trial_id}_true-label-seq')
        for batch in test_io_history:
            batch = tuple(
                x.cpu().numpy() if isinstance(x, torch.Tensor) else x
                for x in batch)
            for io in zip(*batch):
                saveTrialData(*io)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(train_epoch_log, f'cvfold={cv_index}_{model_name}-train-epoch-log')
        saveVariable(val_epoch_log, f'cvfold={cv_index}_{model_name}-val-epoch-log')
        saveVariable(metric_dict, f'cvfold={cv_index}_{model_name}-metric-dict')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        train_fig_dir = os.path.join(fig_dir, 'train-plots')
        if not os.path.exists(train_fig_dir):
            os.makedirs(train_fig_dir)

        if train_epoch_log:
            torchutils.plotEpochLog(
                train_epoch_log, subfig_size=(10, 2.5),
                title='Training performance',
                fn=os.path.join(train_fig_dir, f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(
                val_epoch_log, subfig_size=(10, 2.5),
                title='Heldout performance',
                fn=os.path.join(train_fig_dir, f'cvfold={cv_index}_val-plot.png'))
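
# Illustrative sketch (not part of the original script) of the relabeling
# step above: makeEqClasses (defined elsewhere in this repo) is assumed to
# return assembly_eq_classes as an integer lookup table mapping each original
# assembly-vocab index to its equivalence-class index, so
# assembly_eq_classes[label_seq] relabels a whole sequence via numpy fancy
# indexing. The table and sequence below are hypothetical.
def _eq_class_sketch():
    import numpy as np

    assembly_eq_classes = np.array([0, 1, 1, 2])   # assemblies 1 and 2 are symmetric
    label_seq = np.array([0, 2, 3, 1, 2])
    collapsed = assembly_eq_classes[label_seq]     # -> [0, 1, 2, 1, 1]
    assert collapsed.tolist() == [0, 1, 2, 1, 1]
    return collapsed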