Example #1
def make_dataset(feats, labels, ids, shuffle=True):
     dataset = torchutils.SequenceDataset(feats,
                                          labels,
                                          device=device,
                                          labels_dtype=labels_dtype,
                                          seq_ids=ids,
                                          **dataset_params)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=shuffle)
     return dataset, loader
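A note on context: make_dataset closes over device, labels_dtype, batch_size, and dataset_params from an enclosing scope, and torchutils.SequenceDataset is project-specific. A minimal self-contained sketch of the same helper pattern, using only stock PyTorch (TensorDataset standing in for SequenceDataset, random tensors as hypothetical data):

import torch
from torch.utils.data import DataLoader, TensorDataset

def make_dataset_sketch(feats, labels, batch_size=8, shuffle=True):
    # TensorDataset stands in for the project-specific SequenceDataset.
    dataset = TensorDataset(feats, labels)
    # Forward the shuffle argument rather than hard-coding it.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataset, loader

# Hypothetical usage with random data:
feats = torch.randn(100, 16)
labels = torch.randint(0, 5, (100,))
dataset, loader = make_dataset_sketch(feats, labels)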
Example #2
def main(out_dir=None,
         data_dir=None,
         cv_data_dir=None,
         score_dirs=[],
         fusion_method='sum',
         prune_imu=None,
         standardize=None,
         decode=None,
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None,
         gpu_dev_id=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={}):

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    score_dirs = tuple(map(os.path.expanduser, score_dirs))
    if cv_data_dir is not None:
        cv_data_dir = os.path.expanduser(cv_data_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)

        return tuple(map(loadOne, seq_ids))

    device = torchutils.selectDevice(gpu_dev_id)

    # Load data
    dir_trial_ids = tuple(
        set(utils.getUniqueIds(d, prefix='trial=', to_array=True))
        for d in score_dirs)
    dir_trial_ids += (set(
        utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)), )
    trial_ids = np.array(list(sorted(set.intersection(*dir_trial_ids))))

    for dir_name, t_ids in zip(score_dirs + (data_dir, ), dir_trial_ids):
        logger.info(f"{len(t_ids)} trial ids from {dir_name}:")
        logger.info(f"  {t_ids}")
    logger.info(f"{len(trial_ids)} trials in intersection: {trial_ids}")

    assembly_seqs = loadAll(trial_ids, 'assembly-seq.pkl', data_dir)
    feature_seqs = tuple(
        loadAll(trial_ids, 'data-scores.pkl', d) for d in score_dirs)
    feature_seqs = tuple(zip(*feature_seqs))

    # Combine feature seqs
    include_indices = []
    for i, seq_feats in enumerate(feature_seqs):
        feat_shapes = tuple(f.shape for f in seq_feats)
        include_seq = all(f == feat_shapes[0] for f in feat_shapes)
        if include_seq:
            include_indices.append(i)
        else:
            warn_str = (
                f'Excluding trial {trial_ids[i]} with mismatched feature shapes: '
                f'{feat_shapes}')
            logger.warning(warn_str)

    trial_ids = trial_ids[include_indices]
    assembly_seqs = tuple(assembly_seqs[i] for i in include_indices)
    feature_seqs = tuple(feature_seqs[i] for i in include_indices)

    feature_seqs = tuple(np.stack(f) for f in feature_seqs)

    # Define cross-validation folds
    if cv_data_dir is None:
        dataset_size = len(trial_ids)
        cv_folds = utils.makeDataSplits(dataset_size, **cv_params)
        cv_fold_trial_ids = tuple(
            tuple(map(lambda x: trial_ids[x], splits)) for splits in cv_folds)
    else:
        fn = os.path.join(cv_data_dir, 'cv-fold-trial-ids.pkl')
        cv_fold_trial_ids = joblib.load(fn)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, assembly_seqs, trial_ids))
        return split_data

    gt_scores = []
    all_scores = []
    num_keyframes_total = 0
    num_rgb_errors_total = 0
    num_correctable_errors_total = 0
    num_oov_total = 0
    num_changed_total = 0
    for cv_index, (train_ids, test_ids) in enumerate(cv_fold_trial_ids):
        try:
            test_idxs = np.array(
                [trial_ids.tolist().index(i) for i in test_ids])
        except ValueError:
            logger.info(
                f"  Skipping fold {cv_index}: missing test data {test_ids}")
            continue

        logger.info(f'CV fold {cv_index + 1}: {len(trial_ids)} total '
                    f'({len(train_ids)} train, {len(test_ids)} test)')

        # TRAIN PHASE
        if cv_data_dir is None:
            train_idxs = np.array(
                [trial_ids.tolist().index(i) for i in train_ids])
            train_assembly_seqs = tuple(assembly_seqs[i] for i in train_idxs)
            train_assemblies = []
            for seq in train_assembly_seqs:
                list(
                    labels.gen_eq_classes(seq,
                                          train_assemblies,
                                          equivalent=None))
            model = None
        else:
            fn = f'cvfold={cv_index}_train-assemblies.pkl'
            train_assemblies = joblib.load(os.path.join(cv_data_dir, fn))
            train_idxs = [
                i for i in range(len(trial_ids)) if i not in test_idxs
            ]

            fn = f'cvfold={cv_index}_model.pkl'
            model = joblib.load(os.path.join(cv_data_dir, fn))

        train_features, train_assembly_seqs, train_ids = getSplit(train_idxs)

        if False:  # training branch left disabled in this example
            train_labels = tuple(
                np.array(
                    list(
                        labels.gen_eq_classes(assembly_seq,
                                              train_assemblies,
                                              equivalent=None)), )
                for assembly_seq in train_assembly_seqs)

            train_set = torchutils.SequenceDataset(train_features,
                                                   train_labels,
                                                   seq_ids=train_ids,
                                                   device=device)
            train_loader = torch.utils.data.DataLoader(train_set,
                                                       batch_size=1,
                                                       shuffle=True)

            train_epoch_log = collections.defaultdict(list)
            # val_epoch_log = collections.defaultdict(list)
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy()
            }

            # Build the model before the optimizer so optimizer_ft
            # receives its parameters.
            model = FusionClassifier(num_sources=train_features[0].shape[0])

            criterion = torch.nn.CrossEntropyLoss()
            optimizer_ft = torch.optim.Adam(model.parameters(),
                                            lr=1e-3,
                                            betas=(0.9, 0.999),
                                            eps=1e-08,
                                            weight_decay=0,
                                            amsgrad=False)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                           step_size=1,
                                                           gamma=1.00)
            model, last_model_wts = torchutils.trainModel(
                model,
                criterion,
                optimizer_ft,
                lr_scheduler,
                train_loader,  # val_loader,
                device=device,
                metrics=metric_dict,
                train_epoch_log=train_epoch_log,
                # val_epoch_log=val_epoch_log,
                **train_params)

            torchutils.plotEpochLog(train_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Training performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_train-plot.png'))

        test_assemblies = train_assemblies.copy()
        # This pass extends test_assemblies with assemblies that appear only
        # in the test split, so the plotting code below can draw all of them.
        for feature_seq, gt_assembly_seq, trial_id in zip(
                *getSplit(test_idxs)):
            gt_seq = np.array(
                list(
                    labels.gen_eq_classes(gt_assembly_seq,
                                          test_assemblies,
                                          equivalent=None)))

        if plot_predictions:
            assembly_fig_dir = os.path.join(fig_dir, 'assembly-imgs')
            if not os.path.exists(assembly_fig_dir):
                os.makedirs(assembly_fig_dir)
            for i, assembly in enumerate(test_assemblies):
                assembly.draw(assembly_fig_dir, i)

        # TEST PHASE
        accuracies = []
        for feature_seq, gt_assembly_seq, trial_id in zip(
                *getSplit(test_idxs)):
            gt_seq = np.array(
                list(
                    labels.gen_eq_classes(gt_assembly_seq,
                                          test_assemblies,
                                          equivalent=None)))

            num_labels = gt_seq.shape[0]
            num_features = feature_seq.shape[-1]
            if num_labels != num_features:
                err_str = (f"Skipping trial {trial_id}: "
                           f"{num_labels} labels != {num_features} features")
                logger.info(err_str)
                continue

            # Ignore OOV states in ground-truth
            sample_idxs = np.arange(feature_seq.shape[-1])
            score_idxs = gt_seq[gt_seq < feature_seq.shape[1]]
            sample_idxs = sample_idxs[gt_seq < feature_seq.shape[1]]
            gt_scores.append(feature_seq[:, score_idxs, sample_idxs])
            all_scores.append(feature_seq.reshape(feature_seq.shape[0], -1))

            if fusion_method == 'sum':
                score_seq = feature_seq.sum(axis=0)
            elif fusion_method == 'rgb_only':
                score_seq = feature_seq[1]
            elif fusion_method == 'imu_only':
                score_seq = feature_seq[0]
            else:
                raise NotImplementedError()

            if not decode:
                model = None

            if model is None:
                pred_seq = score_seq.argmax(axis=0)
            elif isinstance(model, torch.nn.Module):
                inputs = torch.tensor(feature_seq[None, ...],
                                      dtype=torch.float,
                                      device=device)
                outputs = model.forward(inputs)
                pred_seq = model.predict(outputs)[0].cpu().numpy()
            else:
                dummy_samples = np.arange(score_seq.shape[1])
                pred_seq, _, _, _ = model.viterbi(dummy_samples,
                                                  log_likelihoods=score_seq,
                                                  ml_decode=(not decode))

            pred_assemblies = [train_assemblies[i] for i in pred_seq]
            gt_assemblies = [test_assemblies[i] for i in gt_seq]

            acc = metrics.accuracy_upto(pred_assemblies,
                                        gt_assemblies,
                                        equivalence=None)
            accuracies.append(acc)

            rgb_pred_seq = feature_seq[1].argmax(axis=0)
            num_changed = np.sum(rgb_pred_seq != pred_seq)
            rgb_is_wrong = rgb_pred_seq != gt_seq
            num_rgb_errors = np.sum(rgb_is_wrong)
            imu_scores = feature_seq[0]
            imu_scores_gt = np.array([
                imu_scores[s_idx,
                           t] if s_idx < imu_scores.shape[0] else -np.inf
                for t, s_idx in enumerate(gt_seq)
            ])
            imu_scores_rgb = np.array([
                imu_scores[s_idx,
                           t] if s_idx < imu_scores.shape[0] else -np.inf
                for t, s_idx in enumerate(rgb_pred_seq)
            ])
            # imu_scores_gt = imu_scores[gt_seq, range(len(rgb_pred_seq))]
            best_imu_scores = imu_scores.max(axis=0)
            imu_is_right = imu_scores_gt >= best_imu_scores
            rgb_pred_score_is_lower = imu_scores_gt > imu_scores_rgb
            is_correctable_error = rgb_is_wrong & imu_is_right & rgb_pred_score_is_lower
            num_correctable_errors = np.sum(is_correctable_error)
            prop_correctable = num_correctable_errors / num_rgb_errors

            num_oov = np.sum(gt_seq >= len(train_assemblies))
            num_states = len(gt_seq)

            num_keyframes_total += num_states
            num_rgb_errors_total += num_rgb_errors
            num_correctable_errors_total += num_correctable_errors
            num_oov_total += num_oov
            num_changed_total += num_changed

            logger.info(f"  trial {trial_id}: {num_states} keyframes")
            logger.info(f"    accuracy (fused): {acc * 100:.1f}%")
            logger.info(
                f"    {num_oov} OOV states ({num_oov / num_states * 100:.1f}%)"
            )
            logger.info(
                f"    {num_rgb_errors} RGB errors; "
                f"{num_correctable_errors} correctable from IMU ({prop_correctable * 100:.1f}%)"
            )

            saveVariable(score_seq, f'trial={trial_id}_data-scores')
            saveVariable(pred_assemblies,
                         f'trial={trial_id}_pred-assembly-seq')
            saveVariable(gt_assemblies, f'trial={trial_id}_gt-assembly-seq')

            if plot_predictions:
                io_figs_dir = os.path.join(fig_dir, 'system-io')
                if not os.path.exists(io_figs_dir):
                    os.makedirs(io_figs_dir)
                fn = os.path.join(io_figs_dir, f'trial={trial_id:03}.png')
                utils.plot_array(feature_seq, (gt_seq, pred_seq, score_seq),
                                 ('gt', 'pred', 'scores'),
                                 fn=fn)

                score_figs_dir = os.path.join(fig_dir, 'modality-scores')
                if not os.path.exists(score_figs_dir):
                    os.makedirs(score_figs_dir)
                plot_scores(feature_seq,
                            k=25,
                            fn=os.path.join(score_figs_dir,
                                            f"trial={trial_id:03}.png"))

                paths_dir = os.path.join(fig_dir, 'path-imgs')
                if not os.path.exists(paths_dir):
                    os.makedirs(paths_dir)
                assemblystats.drawPath(pred_seq, trial_id,
                                       f"trial={trial_id}_pred-seq", paths_dir,
                                       assembly_fig_dir)
                assemblystats.drawPath(gt_seq, trial_id,
                                       f"trial={trial_id}_gt-seq", paths_dir,
                                       assembly_fig_dir)

                label_seqs = (gt_seq, ) + tuple(
                    scores.argmax(axis=0) for scores in feature_seq)
                label_seqs = np.vstack(label_seqs)
                k = 10
                for i, scores in enumerate(feature_seq):
                    label_score_seqs = tuple(
                        np.array([
                            scores[s_idx,
                                   t] if s_idx < scores.shape[0] else -np.inf
                            for t, s_idx in enumerate(label_seq)
                        ]) for label_seq in label_seqs)
                    label_score_seqs = np.vstack(label_score_seqs)
                    drawPaths(label_seqs,
                              f"trial={trial_id}_pred-scores_modality={i}",
                              paths_dir,
                              assembly_fig_dir,
                              path_scores=label_score_seqs)

                    topk_seq = (-scores).argsort(axis=0)[:k, :]
                    path_scores = np.column_stack(
                        tuple(scores[idxs, i]
                              for i, idxs in enumerate(topk_seq.T)))
                    drawPaths(topk_seq,
                              f"trial={trial_id}_topk_modality={i}",
                              paths_dir,
                              assembly_fig_dir,
                              path_scores=path_scores)
                label_score_seqs = tuple(
                    np.array([
                        score_seq[s_idx,
                                  t] if s_idx < score_seq.shape[0] else -np.inf
                        for t, s_idx in enumerate(label_seq)
                    ]) for label_seq in label_seqs)
                label_score_seqs = np.vstack(label_score_seqs)
                drawPaths(label_seqs,
                          f"trial={trial_id}_pred-scores_fused",
                          paths_dir,
                          assembly_fig_dir,
                          path_scores=label_score_seqs)
                topk_seq = (-score_seq).argsort(axis=0)[:k, :]
                path_scores = np.column_stack(
                    tuple(score_seq[idxs, i]
                          for i, idxs in enumerate(topk_seq.T)))
                drawPaths(topk_seq,
                          f"trial={trial_id}_topk_fused",
                          paths_dir,
                          assembly_fig_dir,
                          path_scores=path_scores)

        if accuracies:
            fold_accuracy = float(np.array(accuracies).mean())
            # logger.info(f'  acc: {fold_accuracy * 100:.1f}%')
            metric_dict = {'Accuracy': fold_accuracy}
            utils.writeResults(results_file, metric_dict, sweep_param_name,
                               model_params)

    num_unexplained_errors = num_rgb_errors_total - (
        num_oov_total + num_correctable_errors_total)
    prop_correctable = num_correctable_errors_total / num_rgb_errors_total
    prop_oov = num_oov_total / num_rgb_errors_total
    prop_unexplained = num_unexplained_errors / num_rgb_errors_total
    prop_changed = num_changed_total / num_keyframes_total
    logger.info("PERFORMANCE ANALYSIS")
    logger.info(
        f"  {num_rgb_errors_total} / {num_keyframes_total} "
        f"RGB errors ({num_rgb_errors_total / num_keyframes_total * 100:.1f}%)"
    )
    logger.info(f"  {num_oov_total} / {num_rgb_errors_total} "
                f"RGB errors are OOV ({prop_oov * 100:.1f}%)")
    logger.info(f"  {num_correctable_errors_total} / {num_rgb_errors_total} "
                f"RGB errors are correctable ({prop_correctable * 100:.1f}%)")
    logger.info(f"  {num_unexplained_errors} / {num_rgb_errors_total} "
                f"RGB errors are unexplained ({prop_unexplained * 100:.1f}%)")
    logger.info(
        f"  {num_changed_total} / {num_keyframes_total} "
        f"Predictions changed after fusion ({prop_changed * 100:.1f}%)")

    gt_scores = np.hstack(tuple(gt_scores))
    plot_hists(np.exp(gt_scores),
               fn=os.path.join(fig_dir, "score-hists_gt.png"))
    all_scores = np.hstack(tuple(all_scores))
    plot_hists(np.exp(all_scores),
               fn=os.path.join(fig_dir, "score-hists_all.png"))
Example #3
def main(out_dir=None,
         gpu_dev_id=None,
         num_samples=10,
         random_seed=None,
         learning_rate=1e-3,
         num_epochs=500,
         dataset_kwargs={},
         dataloader_kwargs={},
         model_kwargs={}):

    if out_dir is None:
        out_dir = os.path.join('~', 'data', 'output', 'seqtools', 'test_gtn')

    out_dir = os.path.expanduser(out_dir)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    vocabulary = ['a', 'b', 'c', 'd', 'e']

    transition = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [0, 0, 0, 0, 1],
                           [0, 1, 0, 0, 1], [0, 0, 0, 0, 0]],
                          dtype=float)
    initial = np.array([1, 0, 1, 0, 0], dtype=float)
    final = np.array([0, 1, 0, 0, 1], dtype=float) / 10

    seq_params = (transition, initial, final)
    simulated_dataset = simulate(num_samples, *seq_params)
    label_seqs, obsv_seqs = tuple(zip(*simulated_dataset))
    seq_params = tuple(map(lambda x: -np.log(x), seq_params))

    dataset = torchutils.SequenceDataset(obsv_seqs, label_seqs,
                                         **dataset_kwargs)
    data_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

    train_loader = data_loader
    val_loader = data_loader

    transition_weights = torch.tensor(transition, dtype=torch.float).log()
    initial_weights = torch.tensor(initial, dtype=torch.float).log()
    final_weights = torch.tensor(final, dtype=torch.float).log()

    model = libfst.LatticeCrf(vocabulary,
                              transition_weights=transition_weights,
                              initial_weights=initial_weights,
                              final_weights=final_weights,
                              debug_output_dir=fig_dir,
                              **model_kwargs)

    gtn.draw(model._transition_fst,
             os.path.join(fig_dir, 'transitions-init.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)

    gtn.draw(model._duration_fst,
             os.path.join(fig_dir, 'durations-init.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)

    if True:  # sanity-check pass: compose the FSTs and flag infinite losses
        for i, (inputs, targets, seq_id) in enumerate(train_loader):
            arc_scores = model.scores_to_arc(inputs)
            arc_labels = model.labels_to_arc(targets)

            # Use seq_len here to avoid shadowing the num_samples argument.
            batch_size, seq_len, num_classes = arc_scores.shape

            obs_fst = libfst.linearFstFromArray(arc_scores[0].reshape(
                seq_len, -1))
            gt_fst = libfst.fromSequence(arc_labels[0])
            d1_fst = gtn.compose(obs_fst, model._duration_fst)
            d1_fst = gtn.project_output(d1_fst)
            denom_fst = gtn.compose(d1_fst, model._transition_fst)
            # denom_fst = gtn.project_output(denom_fst)
            num_fst = gtn.compose(denom_fst, gt_fst)
            viterbi_fst = gtn.viterbi_path(denom_fst)
            pred_fst = gtn.remove(gtn.project_output(viterbi_fst))

            loss = gtn.subtract(gtn.forward_score(num_fst),
                                gtn.forward_score(denom_fst))
            loss = torch.tensor(loss.item())

            if torch.isinf(loss).any():
                d1_min = gtn.remove(gtn.project_output(d1_fst))
                denom_alt = gtn.compose(d1_min, model._transition_fst)
                num_alt = gtn.compose(denom_alt, gt_fst)
                gtn.draw(obs_fst,
                         os.path.join(fig_dir, 'observations-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(gt_fst,
                         os.path.join(fig_dir, 'labels-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(d1_fst,
                         os.path.join(fig_dir, 'd1-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(d1_min,
                         os.path.join(fig_dir, 'd1-min-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(denom_fst,
                         os.path.join(fig_dir, 'denominator-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(denom_alt,
                         os.path.join(fig_dir, 'denominator-alt-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(num_fst,
                         os.path.join(fig_dir, 'numerator-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(num_alt,
                         os.path.join(fig_dir, 'numerator-alt-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(viterbi_fst,
                         os.path.join(fig_dir, 'viterbi-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(pred_fst,
                         os.path.join(fig_dir, 'pred-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                import pdb
                pdb.set_trace()

    # Train the model
    train_epoch_log = collections.defaultdict(list)
    val_epoch_log = collections.defaultdict(list)
    metric_dict = {
        'Avg Loss': metrics.AverageLoss(),
        'Accuracy': metrics.Accuracy()
    }

    criterion = model.nllLoss
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1,
                                                gamma=1.00)

    model, last_model_wts = torchutils.trainModel(
        model,
        criterion,
        optimizer,
        scheduler,
        train_loader,
        val_loader,
        metrics=metric_dict,
        test_metric='Avg Loss',
        train_epoch_log=train_epoch_log,
        val_epoch_log=val_epoch_log,
        num_epochs=num_epochs)

    gtn.draw(model._transition_fst,
             os.path.join(fig_dir, 'transitions-trained.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)
    gtn.draw(model._duration_fst,
             os.path.join(fig_dir, 'durations-trained.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)

    torchutils.plotEpochLog(train_epoch_log,
                            title="Train Epoch Log",
                            fn=os.path.join(fig_dir, "train-log.png"))
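Example #3 depends on a simulate helper that is not shown; judging from how its output is unpacked, it returns num_samples pairs of (label sequence, observation sequence) drawn from the Markov chain defined by transition, initial, and final. A hypothetical sketch under that assumption, with noisy one-hot emissions standing in for whatever the real observation model is:

import numpy as np

def simulate(num_samples, transition, initial, final, seed=0):
    # Hypothetical stand-in: sample label paths from the chain, then emit
    # a noisy one-hot observation for each label.
    rng = np.random.default_rng(seed)
    num_states = len(initial)

    def sample_one():
        state = rng.choice(num_states, p=initial / initial.sum())
        labels = [state]
        while True:
            stop_mass = final[state]
            move_mass = transition[state].sum()
            if rng.random() < stop_mass / (stop_mass + move_mass):
                break
            state = rng.choice(num_states, p=transition[state] / move_mass)
            labels.append(state)
        labels = np.array(labels)
        obsv = np.eye(num_states)[labels]
        obsv += 0.1 * rng.normal(size=obsv.shape)
        return labels, obsv

    return tuple(sample_one() for _ in range(num_samples))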
Example #4
def main(out_dir=None,
         data_dir=None,
         model_name=None,
         pretrained_model_dir=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         independent_signals=None,
         active_only=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None,
         label_mapping=None,
         eval_label_mapping=None):

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)

        return tuple(map(loadOne, seq_ids))

    # Load data
    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)
    feature_seqs = loadAll(trial_ids, 'feature-seq.pkl', data_dir)
    label_seqs = loadAll(trial_ids, 'label-seq.pkl', data_dir)

    device = torchutils.selectDevice(gpu_dev_id)

    if label_mapping is not None:

        def map_labels(labels):
            for i, j in label_mapping.items():
                labels[labels == i] = j
            return labels

        label_seqs = tuple(map(map_labels, label_seqs))

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, label_seqs, trial_ids))
        return split_data

    for cv_index, cv_splits in enumerate(cv_folds):
        if pretrained_model_dir is not None:

            def loadFromPretrain(fn):
                return joblib.load(
                    os.path.join(pretrained_model_dir, f"{fn}.pkl"))

            model = loadFromPretrain(f'cvfold={cv_index}_{model_name}-best')
            train_ids = loadFromPretrain(f'cvfold={cv_index}_train-ids')
            val_ids = loadFromPretrain(f'cvfold={cv_index}_val-ids')
            test_ids = tuple(i for i in trial_ids
                             if i not in (train_ids + val_ids))
            test_idxs = tuple(trial_ids.tolist().index(i) for i in test_ids)
            test_data = getSplit(test_idxs)

            if independent_signals:
                criterion = torch.nn.CrossEntropyLoss()
                labels_dtype = torch.long
                test_data = splitSeqs(*test_data, active_only=False)
            else:
                # FIXME
                # criterion = torch.nn.BCEWithLogitsLoss()
                # labels_dtype = torch.float
                criterion = torch.nn.CrossEntropyLoss()
                labels_dtype = torch.long

            test_feats, test_labels, test_ids = test_data
            test_set = torchutils.SequenceDataset(test_feats,
                                                  test_labels,
                                                  device=device,
                                                  labels_dtype=labels_dtype,
                                                  seq_ids=test_ids,
                                                  transpose_data=True)
            test_loader = torch.utils.data.DataLoader(test_set,
                                                      batch_size=batch_size,
                                                      shuffle=False)

            # Test model
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy(),
                'Precision': metrics.Precision(),
                'Recall': metrics.Recall(),
                'F1': metrics.Fmeasure()
            }
            test_io_history = torchutils.predictSamples(
                model.to(device=device),
                test_loader,
                criterion=criterion,
                device=device,
                metrics=metric_dict,
                data_labeled=True,
                update_model=False,
                seq_as_batch=train_params['seq_as_batch'],
                return_io_history=True)
            if independent_signals:
                test_io_history = tuple(joinSeqs(test_io_history))

            metric_str = '  '.join(str(m) for m in metric_dict.values())
            logger.info('[TST]  ' + metric_str)

            d = {k: v.value for k, v in metric_dict.items()}
            utils.writeResults(results_file, d, sweep_param_name, model_params)

            if plot_predictions:
                imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

            def saveTrialData(pred_seq, score_seq, feat_seq, label_seq,
                              trial_id):
                if label_mapping is not None:

                    def dup_score_cols(scores):
                        num_cols = scores.shape[-1] + len(label_mapping)
                        col_idxs = torch.arange(num_cols)
                        for i, j in label_mapping.items():
                            col_idxs[i] = j
                        return scores[..., col_idxs]

                    score_seq = dup_score_cols(score_seq)
                saveVariable(pred_seq.cpu().numpy(),
                             f'trial={trial_id}_pred-label-seq')
                saveVariable(score_seq.cpu().numpy(),
                             f'trial={trial_id}_score-seq')
                saveVariable(label_seq.cpu().numpy(),
                             f'trial={trial_id}_true-label-seq')

            for io in test_io_history:
                saveTrialData(*io)
            continue

        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        if independent_signals:
            criterion = torch.nn.CrossEntropyLoss()
            labels_dtype = torch.long
            split_ = functools.partial(splitSeqs, active_only=active_only)
            train_data = split_(*train_data)
            val_data = split_(*val_data)
            test_data = splitSeqs(*test_data, active_only=False)
        else:
            # FIXME
            # criterion = torch.nn.BCEWithLogitsLoss()
            # labels_dtype = torch.float
            criterion = torch.nn.CrossEntropyLoss()
            labels_dtype = torch.long

        train_feats, train_labels, train_ids = train_data
        train_set = torchutils.SequenceDataset(train_feats,
                                               train_labels,
                                               device=device,
                                               labels_dtype=labels_dtype,
                                               seq_ids=train_ids,
                                               transpose_data=True)
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=batch_size,
                                                   shuffle=True)

        test_feats, test_labels, test_ids = test_data
        test_set = torchutils.SequenceDataset(test_feats,
                                              test_labels,
                                              device=device,
                                              labels_dtype=labels_dtype,
                                              seq_ids=test_ids,
                                              transpose_data=True)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=batch_size,
                                                  shuffle=False)

        val_feats, val_labels, val_ids = val_data
        val_set = torchutils.SequenceDataset(val_feats,
                                             val_labels,
                                             device=device,
                                             labels_dtype=labels_dtype,
                                             seq_ids=val_ids,
                                             transpose_data=True)
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=batch_size,
                                                 shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)'
        )

        input_dim = train_set.num_obsv_dims
        output_dim = train_set.num_label_types
        if model_name == 'linear':
            model = torchutils.LinearClassifier(
                input_dim, output_dim, **model_params).to(device=device)
        elif model_name == 'conv':
            model = ConvClassifier(input_dim, output_dim,
                                   **model_params).to(device=device)
        elif model_name == 'TCN':
            model = TcnClassifier(input_dim, output_dim,
                                  **model_params).to(device=device)
        else:
            raise ValueError(f"Unrecognized model name: {model_name}")

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }

        optimizer_ft = torch.optim.Adam(model.parameters(),
                                        lr=learning_rate,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0,
                                        amsgrad=False)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=1,
                                                       gamma=1.00)

        model, last_model_wts = torchutils.trainModel(
            model,
            criterion,
            optimizer_ft,
            lr_scheduler,
            train_loader,
            val_loader,
            device=device,
            metrics=metric_dict,
            train_epoch_log=train_epoch_log,
            val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }
        test_io_history = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)
        if independent_signals:
            test_io_history = tuple(joinSeqs(test_io_history))

        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)

        d = {k: v.value for k, v in metric_dict.items()}
        utils.writeResults(results_file, d, sweep_param_name, model_params)

        if plot_predictions:
            # imu.plot_prediction_eg(test_io_history, fig_dir, fig_type=fig_type, **viz_params)
            imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

        def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
            if label_mapping is not None:

                def dup_score_cols(scores):
                    num_cols = scores.shape[-1] + len(label_mapping)
                    col_idxs = torch.arange(num_cols)
                    for i, j in label_mapping.items():
                        col_idxs[i] = j
                    return scores[..., col_idxs]

                score_seq = dup_score_cols(score_seq)
            saveVariable(pred_seq.cpu().numpy(),
                         f'trial={trial_id}_pred-label-seq')
            saveVariable(score_seq.cpu().numpy(),
                         f'trial={trial_id}_score-seq')
            saveVariable(label_seq.cpu().numpy(),
                         f'trial={trial_id}_true-label-seq')

        for io in test_io_history:
            saveTrialData(*io)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(train_epoch_log,
                     f'cvfold={cv_index}_{model_name}-train-epoch-log')
        saveVariable(val_epoch_log,
                     f'cvfold={cv_index}_{model_name}-val-epoch-log')
        saveVariable(metric_dict,
                     f'cvfold={cv_index}_{model_name}-metric-dict')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        model.load_state_dict(last_model_wts)
        saveVariable(model, f'cvfold={cv_index}_{model_name}-last')

        torchutils.plotEpochLog(train_epoch_log,
                                subfig_size=(10, 2.5),
                                title='Training performance',
                                fn=os.path.join(
                                    fig_dir,
                                    f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_val-plot.png'))

        if eval_label_mapping is not None:
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy(),
                'Precision': metrics.Precision(),
                'Recall': metrics.Recall(),
                'F1': metrics.Fmeasure()
            }
            test_io_history = torchutils.predictSamples(
                model.to(device=device),
                test_loader,
                criterion=criterion,
                device=device,
                metrics=metric_dict,
                data_labeled=True,
                update_model=False,
                seq_as_batch=train_params['seq_as_batch'],
                return_io_history=True,
                label_mapping=eval_label_mapping)
            if independent_signals:
                test_io_history = joinSeqs(test_io_history)
            metric_str = '  '.join(str(m) for m in metric_dict.values())
            logger.info('[TST]  ' + metric_str)
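The label_mapping path above collapses selected label values before training (map_labels) and re-expands the model's score columns at save time (dup_score_cols) so the saved scores index into the original label space; it assumes the collapsed labels are the highest indices. A toy torch sketch with a hypothetical mapping {4: 1}:

import torch

label_mapping = {4: 1}  # hypothetical: collapse label 4 into label 1

def map_labels(labels):
    for i, j in label_mapping.items():
        labels[labels == i] = j
    return labels

def dup_score_cols(scores):
    # Recover the original column count, pointing each collapsed label's
    # column at the column it was mapped to.
    num_cols = scores.shape[-1] + len(label_mapping)
    col_idxs = torch.arange(num_cols)
    for i, j in label_mapping.items():
        col_idxs[i] = j
    return scores[..., col_idxs]

labels = torch.tensor([0, 4, 2, 4, 1])
print(map_labels(labels))            # tensor([0, 1, 2, 1, 1])

scores = torch.randn(3, 4)           # model trained on 4 collapsed classes
print(dup_score_cols(scores).shape)  # torch.Size([3, 5])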
Example #5
def main(out_dir=None,
         data_dir=None,
         model_name=None,
         part_symmetries=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None):

    if part_symmetries is None:
        part_symmetries = {
            'beam_side': ('backbeam_hole_1', 'backbeam_hole_2',
                          'frontbeam_hole_1', 'frontbeam_hole_2'),
            'beam_top': ('backbeam_hole_3', 'frontbeam_hole_3'),
            'backrest': ('backrest_hole_1', 'backrest_hole_2')
        }

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)

        return tuple(map(loadOne, seq_ids))

    # Load vocab
    with open(os.path.join(data_dir, "part-vocab.yaml"), 'rt') as f:
        link_vocab = yaml.safe_load(f)
    assembly_vocab = joblib.load(os.path.join(data_dir, 'assembly-vocab.pkl'))

    # Load data
    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=')
    feature_seqs = loadAll(trial_ids, 'feature-seq.pkl', data_dir)
    label_seqs = loadAll(trial_ids, 'label-seq.pkl', data_dir)

    if part_symmetries:
        # Construct equivalence classes from vocab
        eq_classes, assembly_eq_classes, eq_class_vocab = makeEqClasses(
            assembly_vocab, part_symmetries)
        lib_assembly.writeAssemblies(
            os.path.join(fig_dir, 'eq-class-vocab.txt'), eq_class_vocab)
        label_seqs = tuple(assembly_eq_classes[label_seq]
                           for label_seq in label_seqs)
        saveVariable(eq_class_vocab, 'assembly-vocab')
    else:
        eq_classes = None

    def impute_nan(input_seq):
        input_is_nan = np.isnan(input_seq)
        logger.info(f"{input_is_nan.sum()} NaN elements")
        input_seq[input_is_nan] = 0  # np.nanmean(input_seq)
        return input_seq

    # feature_seqs = tuple(map(impute_nan, feature_seqs))

    for trial_id, label_seq, feat_seq in zip(trial_ids, label_seqs,
                                             feature_seqs):
        saveVariable(feat_seq, f"trial={trial_id}_feature-seq")
        saveVariable(label_seq, f"trial={trial_id}_label-seq")

    device = torchutils.selectDevice(gpu_dev_id)

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, label_seqs, trial_ids))
        return split_data

    for cv_index, cv_splits in enumerate(cv_folds):
        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        train_feats, train_labels, train_ids = train_data
        train_set = torchutils.SequenceDataset(train_feats,
                                               train_labels,
                                               device=device,
                                               labels_dtype=torch.long,
                                               seq_ids=train_ids,
                                               transpose_data=True)
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=batch_size,
                                                   shuffle=True)

        test_feats, test_labels, test_ids = test_data
        test_set = torchutils.SequenceDataset(test_feats,
                                              test_labels,
                                              device=device,
                                              labels_dtype=torch.long,
                                              seq_ids=test_ids,
                                              transpose_data=True)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=batch_size,
                                                  shuffle=False)

        val_feats, val_labels, val_ids = val_data
        val_set = torchutils.SequenceDataset(val_feats,
                                             val_labels,
                                             device=device,
                                             labels_dtype=torch.long,
                                             seq_ids=val_ids,
                                             transpose_data=True)
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=batch_size,
                                                 shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)'
        )

        input_dim = train_set.num_obsv_dims
        output_dim = train_set.num_label_types
        if model_name == 'linear':
            model = torchutils.LinearClassifier(
                input_dim, output_dim, **model_params).to(device=device)
        elif model_name == 'dummy':
            model = DummyClassifier(input_dim, output_dim, **model_params)
        elif model_name == 'AssemblyClassifier':
            model = AssemblyClassifier(assembly_vocab,
                                       link_vocab,
                                       eq_classes=eq_classes,
                                       **model_params)
        else:
            raise ValueError(f"Unrecognized model name: {model_name}")

        criterion = torch.nn.CrossEntropyLoss()
        # Initialize the epoch logs unconditionally so the save calls below
        # work even when the dummy model skips training.
        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        if model_name != 'dummy':
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy(),
                'Precision': metrics.Precision(),
                'Recall': metrics.Recall(),
                'F1': metrics.Fmeasure()
            }

            optimizer_ft = torch.optim.Adam(model.parameters(),
                                            lr=learning_rate,
                                            betas=(0.9, 0.999),
                                            eps=1e-08,
                                            weight_decay=0,
                                            amsgrad=False)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                           step_size=1,
                                                           gamma=1.00)

            model, last_model_wts = torchutils.trainModel(
                model,
                criterion,
                optimizer_ft,
                lr_scheduler,
                train_loader,
                val_loader,
                device=device,
                metrics=metric_dict,
                train_epoch_log=train_epoch_log,
                val_epoch_log=val_epoch_log,
                **train_params)

        if hasattr(model, '_scale'):
            logger.info(f'scale={float(model._scale)}')
        if hasattr(model, '_alpha'):
            logger.info(f'alpha={float(model._alpha)}')

        # Test model
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }
        test_io_history = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)

        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)

        d = {k: v.value for k, v in metric_dict.items()}
        utils.writeResults(results_file, d, sweep_param_name, model_params)

        if plot_predictions:
            io_fig_dir = os.path.join(fig_dir, 'model-io')
            if not os.path.exists(io_fig_dir):
                os.makedirs(io_fig_dir)

            label_names = ('gt', 'pred')
            for batch in test_io_history:
                batch = tuple(
                    x.cpu().numpy() if isinstance(x, torch.Tensor) else x
                    for x in batch)
                for preds, _, inputs, gt_labels, seq_id in zip(*batch):
                    fn = os.path.join(io_fig_dir,
                                      f"trial={seq_id}_model-io.png")
                    utils.plot_array(inputs.sum(axis=-1), (gt_labels, preds),
                                     label_names,
                                     fn=fn,
                                     **viz_params)

        def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
            saveVariable(pred_seq, f'trial={trial_id}_pred-label-seq')
            saveVariable(score_seq, f'trial={trial_id}_score-seq')
            saveVariable(label_seq, f'trial={trial_id}_true-label-seq')

        for batch in test_io_history:
            batch = tuple(x.cpu().numpy() if isinstance(x, torch.Tensor) else x
                          for x in batch)
            for io in zip(*batch):
                saveTrialData(*io)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(train_epoch_log,
                     f'cvfold={cv_index}_{model_name}-train-epoch-log')
        saveVariable(val_epoch_log,
                     f'cvfold={cv_index}_{model_name}-val-epoch-log')
        saveVariable(metric_dict,
                     f'cvfold={cv_index}_{model_name}-metric-dict')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        train_fig_dir = os.path.join(fig_dir, 'train-plots')
        if not os.path.exists(train_fig_dir):
            os.makedirs(train_fig_dir)

        if train_epoch_log:
            torchutils.plotEpochLog(train_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Training performance',
                                    fn=os.path.join(
                                        train_fig_dir,
                                        f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(
                                        train_fig_dir,
                                        f'cvfold={cv_index}_val-plot.png'))
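Example #5 saves per-trial outputs by first moving each batched tensor to numpy and then unzipping along the batch dimension. A self-contained toy of that unpack pattern (hypothetical shapes, two trials per batch):

import numpy as np
import torch

# One batch in the (preds, scores, inputs, labels, seq_ids) layout that
# a predictSamples-style helper is assumed to return.
batch = (
    torch.zeros(2, 7, dtype=torch.long),   # pred_seq per trial
    torch.randn(2, 7, 5),                  # score_seq per trial
    torch.randn(2, 7, 12),                 # feat_seq per trial
    torch.zeros(2, 7, dtype=torch.long),   # label_seq per trial
    np.array([101, 102]),                  # trial ids
)

batch = tuple(x.cpu().numpy() if isinstance(x, torch.Tensor) else x
              for x in batch)
for pred_seq, score_seq, feat_seq, label_seq, trial_id in zip(*batch):
    print(trial_id, pred_seq.shape, score_seq.shape)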