Exemplo n.º 1
0
def main(out_dir=None,
         data_dir=None,
         cv_data_dir=None,
         score_dirs=None,
         fusion_method='sum',
         prune_imu=None,
         standardize=None,
         decode=None,
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None,
         gpu_dev_id=None,
         model_params=None,
         cv_params=None,
         train_params=None,
         viz_params=None):
    """Fuse per-source score sequences, decode them and evaluate per CV fold.

    Scores are loaded from every directory in ``score_dirs`` and ground-truth
    assembly sequences from ``data_dir``. Only trials present in all of these
    directories, and whose per-source score arrays share one shape, are kept.
    For each cross-validation fold the test trials are decoded, the fold's
    mean accuracy is written to ``results_file``, and an error analysis that
    compares source 1 (treated as RGB) against source 0 (treated as IMU) is
    logged.

    Args:
        out_dir: Output directory; ``log.txt``, ``figures/`` and ``data/``
            are written here (created if missing).
        data_dir: Directory with ``trial=<id>_assembly-seq.pkl`` files.
        cv_data_dir: Optional directory with pre-computed folds
            (``cv-fold-trial-ids.pkl``), per-fold train assemblies and
            per-fold models. If ``None``, folds come from
            ``utils.makeDataSplits(**cv_params)`` and no model is loaded.
        score_dirs: Directories with ``trial=<id>_data-scores.pkl`` files.
        fusion_method: ``'sum'``, ``'rgb_only'`` or ``'imu_only'``.
        prune_imu, standardize, viz_params: Accepted for config
            compatibility; not used here.
        decode: If falsy, predictions are the per-sample argmax of the fused
            scores and any loaded model is ignored.
        plot_predictions: If truthy, write assembly/score/path figures.
        results_file: CSV for per-fold results; defaults to
            ``<out_dir>/results.csv``.
        sweep_param_name, model_params: Forwarded to ``utils.writeResults``.
        gpu_dev_id: Forwarded to ``torchutils.selectDevice``.
        cv_params: Keyword arguments for ``utils.makeDataSplits``.
        train_params: Keyword arguments for ``torchutils.trainModel`` (only
            used by the disabled training branch).

    Raises:
        NotImplementedError: if ``fusion_method`` is not recognized.
    """
    # None sentinels replace the former mutable ({} / []) defaults.
    score_dirs = () if score_dirs is None else score_dirs
    model_params = {} if model_params is None else model_params
    cv_params = {} if cv_params is None else cv_params
    train_params = {} if train_params is None else train_params
    viz_params = {} if viz_params is None else viz_params

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    score_dirs = tuple(map(os.path.expanduser, score_dirs))
    if cv_data_dir is not None:
        cv_data_dir = os.path.expanduser(cv_data_dir)

    # The log file lives inside out_dir, so the directory must exist first.
    os.makedirs(out_dir, exist_ok=True)
    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    os.makedirs(fig_dir, exist_ok=True)

    out_data_dir = os.path.join(out_dir, 'data')
    os.makedirs(out_data_dir, exist_ok=True)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)

        return tuple(map(loadOne, seq_ids))

    def ratio(num, den):
        # NaN instead of ZeroDivisionError when nothing was counted.
        return num / den if den else float('nan')

    device = torchutils.selectDevice(gpu_dev_id)

    # Load data: keep only trials found in every score dir and in data_dir.
    dir_trial_ids = tuple(
        set(utils.getUniqueIds(d, prefix='trial=', to_array=True))
        for d in score_dirs)
    dir_trial_ids += (set(
        utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)), )
    trial_ids = np.array(sorted(set.intersection(*dir_trial_ids)))

    for dir_name, t_ids in zip(score_dirs + (data_dir, ), dir_trial_ids):
        logger.info(f"{len(t_ids)} trial ids from {dir_name}:")
        logger.info(f"  {t_ids}")
    logger.info(f"{len(trial_ids)} trials in intersection: {trial_ids}")

    assembly_seqs = loadAll(trial_ids, 'assembly-seq.pkl', data_dir)
    feature_seqs = tuple(
        loadAll(trial_ids, 'data-scores.pkl', d) for d in score_dirs)
    # Regroup from per-directory tuples to per-trial tuples of sources.
    feature_seqs = tuple(zip(*feature_seqs))

    # Combine feature seqs; trials whose sources disagree in shape can't be
    # stacked, so they are dropped.
    include_indices = []
    for i, seq_feats in enumerate(feature_seqs):
        feat_shapes = tuple(f.shape for f in seq_feats)
        if all(f == feat_shapes[0] for f in feat_shapes):
            include_indices.append(i)
        else:
            warn_str = (
                f'Excluding trial {trial_ids[i]} with mismatched feature shapes: '
                f'{feat_shapes}')
            logger.warning(warn_str)

    trial_ids = trial_ids[include_indices]
    assembly_seqs = tuple(assembly_seqs[i] for i in include_indices)
    feature_seqs = tuple(feature_seqs[i] for i in include_indices)

    # Axis 0 of each stacked array indexes the score source.
    feature_seqs = tuple(np.stack(f) for f in feature_seqs)

    # Define cross-validation folds
    if cv_data_dir is None:
        dataset_size = len(trial_ids)
        cv_folds = utils.makeDataSplits(dataset_size, **cv_params)
        cv_fold_trial_ids = tuple(
            tuple(map(lambda x: trial_ids[x], splits)) for splits in cv_folds)
    else:
        fn = os.path.join(cv_data_dir, 'cv-fold-trial-ids.pkl')
        cv_fold_trial_ids = joblib.load(fn)

    # trial id -> position in trial_ids. Fold ids loaded from cv_data_dir
    # may refer to trials that were filtered out above.
    trial_index = {t_id: i for i, t_id in enumerate(trial_ids.tolist())}

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, assembly_seqs, trial_ids))
        return split_data

    gt_scores = []
    all_scores = []
    num_keyframes_total = 0
    num_rgb_errors_total = 0
    num_correctable_errors_total = 0
    num_oov_total = 0
    num_changed_total = 0
    for cv_index, (train_ids, test_ids) in enumerate(cv_fold_trial_ids):
        try:
            test_idxs = np.array([trial_index[i] for i in test_ids])
        except KeyError:
            logger.info(
                f"  Skipping fold {cv_index}: missing test data {test_ids}")
            # Without this the fold would run on undefined/stale test_idxs.
            continue

        logger.info(f'CV fold {cv_index + 1}: {len(trial_ids)} total '
                    f'({len(train_ids)} train, {len(test_ids)} test)')

        # TRAIN PHASE
        if cv_data_dir is None:
            train_idxs = np.array([trial_index[i] for i in train_ids])
            train_assembly_seqs = tuple(assembly_seqs[i] for i in train_idxs)
            train_assemblies = []
            for seq in train_assembly_seqs:
                # Labels are discarded; the call is relied on to extend
                # train_assemblies (inferred from later indexing - verify).
                list(
                    labels.gen_eq_classes(seq,
                                          train_assemblies,
                                          equivalent=None))
            model = None
        else:
            fn = f'cvfold={cv_index}_train-assemblies.pkl'
            train_assemblies = joblib.load(os.path.join(cv_data_dir, fn))
            test_idx_set = set(test_idxs.tolist())
            train_idxs = [
                i for i in range(len(trial_ids)) if i not in test_idx_set
            ]

            fn = f'cvfold={cv_index}_model.pkl'
            model = joblib.load(os.path.join(cv_data_dir, fn))

        train_features, train_assembly_seqs, train_ids = getSplit(train_idxs)

        if False:  # Training a FusionClassifier is currently disabled.
            train_labels = tuple(
                np.array(
                    list(
                        labels.gen_eq_classes(assembly_seq,
                                              train_assemblies,
                                              equivalent=None)), )
                for assembly_seq in train_assembly_seqs)

            train_set = torchutils.SequenceDataset(train_features,
                                                   train_labels,
                                                   seq_ids=train_ids,
                                                   device=device)
            train_loader = torch.utils.data.DataLoader(train_set,
                                                       batch_size=1,
                                                       shuffle=True)

            train_epoch_log = collections.defaultdict(list)
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy()
            }

            # The model must exist before its parameters go to the optimizer.
            model = FusionClassifier(num_sources=train_features[0].shape[0])
            criterion = torch.nn.CrossEntropyLoss()
            optimizer_ft = torch.optim.Adam(model.parameters(),
                                            lr=1e-3,
                                            betas=(0.9, 0.999),
                                            eps=1e-08,
                                            weight_decay=0,
                                            amsgrad=False)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                           step_size=1,
                                                           gamma=1.00)

            model, last_model_wts = torchutils.trainModel(
                model,
                criterion,
                optimizer_ft,
                lr_scheduler,
                train_loader,
                device=device,
                metrics=metric_dict,
                train_epoch_log=train_epoch_log,
                **train_params)

            torchutils.plotEpochLog(train_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Training performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_train-plot.png'))

        # Without decoding, predictions are the plain argmax of fused scores.
        if not decode:
            model = None

        # First pass over the test set: register test assemblies unseen in
        # training, so every ground-truth label has an index below.
        test_assemblies = train_assemblies.copy()
        for _, gt_assembly_seq, _ in zip(*getSplit(test_idxs)):
            list(
                labels.gen_eq_classes(gt_assembly_seq,
                                      test_assemblies,
                                      equivalent=None))

        if plot_predictions:
            assembly_fig_dir = os.path.join(fig_dir, 'assembly-imgs')
            os.makedirs(assembly_fig_dir, exist_ok=True)
            for i, assembly in enumerate(test_assemblies):
                assembly.draw(assembly_fig_dir, i)

        # TEST PHASE
        accuracies = []
        for feature_seq, gt_assembly_seq, trial_id in zip(
                *getSplit(test_idxs)):
            gt_seq = np.array(
                list(
                    labels.gen_eq_classes(gt_assembly_seq,
                                          test_assemblies,
                                          equivalent=None)))

            num_labels = gt_seq.shape[0]
            num_features = feature_seq.shape[-1]
            if num_labels != num_features:
                err_str = (f"Skipping trial {trial_id}: "
                           f"{num_labels} labels != {num_features} features")
                logger.info(err_str)
                continue

            # Ignore OOV states in ground-truth (labels with no score row)
            in_vocab = gt_seq < feature_seq.shape[1]
            score_idxs = gt_seq[in_vocab]
            sample_idxs = np.arange(feature_seq.shape[-1])[in_vocab]
            gt_scores.append(feature_seq[:, score_idxs, sample_idxs])
            all_scores.append(feature_seq.reshape(feature_seq.shape[0], -1))

            if fusion_method == 'sum':
                score_seq = feature_seq.sum(axis=0)
            elif fusion_method == 'rgb_only':
                score_seq = feature_seq[1]
            elif fusion_method == 'imu_only':
                score_seq = feature_seq[0]
            else:
                raise NotImplementedError()

            if model is None:
                pred_seq = score_seq.argmax(axis=0)
            elif isinstance(model, torch.nn.Module):
                inputs = torch.tensor(feature_seq[None, ...],
                                      dtype=torch.float,
                                      device=device)
                outputs = model.forward(inputs)
                pred_seq = model.predict(outputs)[0].cpu().numpy()
            else:
                dummy_samples = np.arange(score_seq.shape[1])
                pred_seq, _, _, _ = model.viterbi(dummy_samples,
                                                  log_likelihoods=score_seq,
                                                  ml_decode=(not decode))

            pred_assemblies = [train_assemblies[i] for i in pred_seq]
            gt_assemblies = [test_assemblies[i] for i in gt_seq]

            acc = metrics.accuracy_upto(pred_assemblies,
                                        gt_assemblies,
                                        equivalence=None)
            accuracies.append(acc)

            # Error analysis: source 1 is treated as RGB, source 0 as IMU.
            rgb_pred_seq = feature_seq[1].argmax(axis=0)
            num_changed = np.sum(rgb_pred_seq != pred_seq)
            rgb_is_wrong = rgb_pred_seq != gt_seq
            num_rgb_errors = np.sum(rgb_is_wrong)
            imu_scores = feature_seq[0]
            imu_scores_gt = np.array([
                imu_scores[s_idx,
                           t] if s_idx < imu_scores.shape[0] else -np.inf
                for t, s_idx in enumerate(gt_seq)
            ])
            imu_scores_rgb = np.array([
                imu_scores[s_idx,
                           t] if s_idx < imu_scores.shape[0] else -np.inf
                for t, s_idx in enumerate(rgb_pred_seq)
            ])
            best_imu_scores = imu_scores.max(axis=0)
            imu_is_right = imu_scores_gt >= best_imu_scores
            rgb_pred_score_is_lower = imu_scores_gt > imu_scores_rgb
            # An RGB error is "correctable" if IMU ranks the true state
            # first and above the (wrong) RGB prediction.
            is_correctable_error = (rgb_is_wrong & imu_is_right
                                    & rgb_pred_score_is_lower)
            num_correctable_errors = np.sum(is_correctable_error)
            prop_correctable = ratio(num_correctable_errors, num_rgb_errors)

            num_oov = np.sum(gt_seq >= len(train_assemblies))
            num_states = len(gt_seq)

            num_keyframes_total += num_states
            num_rgb_errors_total += num_rgb_errors
            num_correctable_errors_total += num_correctable_errors
            num_oov_total += num_oov
            num_changed_total += num_changed

            logger.info(f"  trial {trial_id}: {num_states} keyframes")
            logger.info(f"    accuracy (fused): {acc * 100:.1f}%")
            logger.info(
                f"    {num_oov} OOV states "
                f"({ratio(num_oov, num_states) * 100:.1f}%)")
            logger.info(
                f"    {num_rgb_errors} RGB errors; "
                f"{num_correctable_errors} correctable from IMU ({prop_correctable * 100:.1f}%)"
            )

            saveVariable(score_seq, f'trial={trial_id}_data-scores')
            saveVariable(pred_assemblies,
                         f'trial={trial_id}_pred-assembly-seq')
            saveVariable(gt_assemblies, f'trial={trial_id}_gt-assembly-seq')

            if plot_predictions:
                io_figs_dir = os.path.join(fig_dir, 'system-io')
                os.makedirs(io_figs_dir, exist_ok=True)
                fn = os.path.join(io_figs_dir, f'trial={trial_id:03}.png')
                utils.plot_array(feature_seq, (gt_seq, pred_seq, score_seq),
                                 ('gt', 'pred', 'scores'),
                                 fn=fn)

                score_figs_dir = os.path.join(fig_dir, 'modality-scores')
                os.makedirs(score_figs_dir, exist_ok=True)
                plot_scores(feature_seq,
                            k=25,
                            fn=os.path.join(score_figs_dir,
                                            f"trial={trial_id:03}.png"))

                paths_dir = os.path.join(fig_dir, 'path-imgs')
                os.makedirs(paths_dir, exist_ok=True)
                assemblystats.drawPath(pred_seq, trial_id,
                                       f"trial={trial_id}_pred-seq", paths_dir,
                                       assembly_fig_dir)
                assemblystats.drawPath(gt_seq, trial_id,
                                       f"trial={trial_id}_gt-seq", paths_dir,
                                       assembly_fig_dir)

                # Rows: ground truth, then each source's argmax prediction.
                label_seqs = (gt_seq, ) + tuple(
                    scores.argmax(axis=0) for scores in feature_seq)
                label_seqs = np.vstack(label_seqs)
                k = 10
                for i, scores in enumerate(feature_seq):
                    label_score_seqs = tuple(
                        np.array([
                            scores[s_idx,
                                   t] if s_idx < scores.shape[0] else -np.inf
                            for t, s_idx in enumerate(label_seq)
                        ]) for label_seq in label_seqs)
                    label_score_seqs = np.vstack(label_score_seqs)
                    drawPaths(label_seqs,
                              f"trial={trial_id}_pred-scores_modality={i}",
                              paths_dir,
                              assembly_fig_dir,
                              path_scores=label_score_seqs)

                    topk_seq = (-scores).argsort(axis=0)[:k, :]
                    path_scores = np.column_stack(
                        tuple(scores[idxs, t]
                              for t, idxs in enumerate(topk_seq.T)))
                    drawPaths(topk_seq,
                              f"trial={trial_id}_topk_modality={i}",
                              paths_dir,
                              assembly_fig_dir,
                              path_scores=path_scores)
                label_score_seqs = tuple(
                    np.array([
                        score_seq[s_idx,
                                  t] if s_idx < score_seq.shape[0] else -np.inf
                        for t, s_idx in enumerate(label_seq)
                    ]) for label_seq in label_seqs)
                label_score_seqs = np.vstack(label_score_seqs)
                drawPaths(label_seqs,
                          f"trial={trial_id}_pred-scores_fused",
                          paths_dir,
                          assembly_fig_dir,
                          path_scores=label_score_seqs)
                topk_seq = (-score_seq).argsort(axis=0)[:k, :]
                path_scores = np.column_stack(
                    tuple(score_seq[idxs, t]
                          for t, idxs in enumerate(topk_seq.T)))
                drawPaths(topk_seq,
                          f"trial={trial_id}_topk_fused",
                          paths_dir,
                          assembly_fig_dir,
                          path_scores=path_scores)

        if accuracies:
            fold_accuracy = float(np.array(accuracies).mean())
            metric_dict = {'Accuracy': fold_accuracy}
            utils.writeResults(results_file, metric_dict, sweep_param_name,
                               model_params)

    num_unexplained_errors = num_rgb_errors_total - (
        num_oov_total + num_correctable_errors_total)
    prop_correctable = ratio(num_correctable_errors_total, num_rgb_errors_total)
    prop_oov = ratio(num_oov_total, num_rgb_errors_total)
    prop_unexplained = ratio(num_unexplained_errors, num_rgb_errors_total)
    prop_changed = ratio(num_changed_total, num_keyframes_total)
    prop_rgb_errors = ratio(num_rgb_errors_total, num_keyframes_total)
    logger.info("PERFORMANCE ANALYSIS")
    logger.info(
        f"  {num_rgb_errors_total} / {num_keyframes_total} "
        f"RGB errors ({prop_rgb_errors * 100:.1f}%)"
    )
    logger.info(f"  {num_oov_total} / {num_rgb_errors_total} "
                f"RGB errors are OOV ({prop_oov * 100:.1f}%)")
    logger.info(f"  {num_correctable_errors_total} / {num_rgb_errors_total} "
                f"RGB errors are correctable ({prop_correctable * 100:.1f}%)")
    logger.info(f"  {num_unexplained_errors} / {num_rgb_errors_total} "
                f"RGB errors are unexplained ({prop_unexplained * 100:.1f}%)")
    logger.info(
        f"  {num_changed_total} / {num_keyframes_total} "
        f"Predictions changed after fusion ({prop_changed * 100:.1f}%)")

    # np.hstack raises on an empty sequence, so skip plots with no data.
    if gt_scores:
        gt_score_arr = np.hstack(tuple(gt_scores))
        plot_hists(np.exp(gt_score_arr),
                   fn=os.path.join(fig_dir, "score-hists_gt.png"))
    if all_scores:
        all_score_arr = np.hstack(tuple(all_scores))
        plot_hists(np.exp(all_score_arr),
                   fn=os.path.join(fig_dir, "score-hists_all.png"))
Exemplo n.º 2
0
def main(out_dir=None,
         data_dir=None,
         segs_dir=None,
         scores_dir=None,
         vocab_dir=None,
         label_type='edges',
         gpu_dev_id=None,
         start_from=None,
         stop_at=None,
         num_disp_imgs=None,
         results_file=None,
         sweep_param_name=None,
         model_params=None,
         cv_params=None):
    """Evaluate saved edge/assembly predictions per CV fold and per sequence.

    Predicted and true label sequences are read from ``scores_dir``
    (``trial=<id>_pred-label-seq`` / ``trial=<id>_true-label-seq``) and
    converted so that both edge-level and assembly(state)-level metrics can
    be computed. Per-sequence metrics are logged and written to
    ``results_file``. Per-fold train/test edge-frequency plots are saved
    under ``<out_dir>/figures``.

    Args:
        out_dir: Output directory (created if missing).
        data_dir, segs_dir: Only referenced by disabled code; optional.
        scores_dir: Directory with score and label sequences.
        vocab_dir: Directory with the vocabulary files.
        label_type: ``'assembly'`` if the saved labels are assembly indices,
            ``'edge'``/``'edges'`` if they are edge labels.
        gpu_dev_id: Forwarded to ``torchutils.selectDevice``.
        start_from, stop_at: Accepted for config compatibility; unused.
        num_disp_imgs: If not ``None``, render assembly images and I/O plots.
        results_file: CSV for results; defaults to ``<out_dir>/results.csv``.
        sweep_param_name, model_params: Forwarded to ``utils.writeResults``.
        cv_params: Keyword arguments for ``utils.makeDataSplits``; each fold
            is unpacked as (train, val, test) index sets.

    Raises:
        ValueError: if ``label_type`` is not recognized.
    """
    # None sentinels replace the former mutable ({}) defaults.
    model_params = {} if model_params is None else model_params
    cv_params = {} if cv_params is None else cv_params

    # The default is 'edges' while the dispatch used to test only 'edge', so
    # the default call fell through both branches and crashed later. Accept
    # both spellings and reject anything else up front.
    if label_type in ('assembly', 'assemblies'):
        label_type = 'assembly'
    elif label_type in ('edge', 'edges'):
        label_type = 'edge'
    else:
        raise ValueError(
            f"label_type must be 'assembly' or 'edge(s)', got {label_type!r}")

    def expandOptional(path):
        return None if path is None else os.path.expanduser(path)

    data_dir = expandOptional(data_dir)
    segs_dir = expandOptional(segs_dir)
    scores_dir = os.path.expanduser(scores_dir)
    vocab_dir = os.path.expanduser(vocab_dir)
    out_dir = os.path.expanduser(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    io_dir_images = os.path.join(fig_dir, 'model-io_images')
    io_dir_plots = os.path.join(fig_dir, 'model-io_plots')
    out_data_dir = os.path.join(out_dir, 'data')
    for dir_path in (fig_dir, io_dir_images, io_dir_plots, out_data_dir):
        os.makedirs(dir_path, exist_ok=True)

    seq_ids = utils.getUniqueIds(scores_dir,
                                 prefix='trial=',
                                 suffix='score-seq.*',
                                 to_array=True)

    logger.info(
        f"Loaded scores for {len(seq_ids)} sequences from {scores_dir}")

    link_vocab = {}
    joint_vocab = {}
    joint_type_vocab = {}
    vocab, parts_vocab, part_labels = load_vocab(link_vocab, joint_vocab,
                                                 joint_type_vocab, vocab_dir)
    pred_vocab = []  # FIXME

    def loadLabelSeqs(var_name):
        return tuple(
            utils.loadVariable(f"trial={seq_id}_{var_name}", scores_dir)
            for seq_id in seq_ids)

    if label_type == 'assembly':
        logger.info("Converting assemblies -> edges")
        state_pred_seqs = loadLabelSeqs('pred-label-seq')
        state_true_seqs = loadLabelSeqs('true-label-seq')
        # part_labels maps assembly index -> edge labels (fancy indexing).
        edge_pred_seqs = tuple(part_labels[seq] for seq in state_pred_seqs)
        edge_true_seqs = tuple(part_labels[seq] for seq in state_true_seqs)
    else:
        logger.info("Converting edges -> assemblies (will take a few minutes)")
        edge_pred_seqs = loadLabelSeqs('pred-label-seq')
        edge_true_seqs = loadLabelSeqs('true-label-seq')
        state_pred_seqs = tuple(
            edges_to_assemblies(seq, pred_vocab, parts_vocab, part_labels)
            for seq in edge_pred_seqs)
        state_true_seqs = tuple(
            edges_to_assemblies(seq, vocab, parts_vocab, part_labels)
            for seq in edge_true_seqs)

    device = torchutils.selectDevice(gpu_dev_id)
    dataset = sim2real.LabeledConnectionDataset(
        utils.loadVariable('parts-vocab', vocab_dir),
        utils.loadVariable('part-labels', vocab_dir),
        utils.loadVariable('vocab', vocab_dir),
        device=device)

    all_metrics = collections.defaultdict(list)

    # Define cross-validation folds
    cv_folds = utils.makeDataSplits(len(seq_ids), **cv_params)
    utils.saveVariable(cv_folds, 'cv-folds', out_data_dir)

    for cv_index, cv_fold in enumerate(cv_folds):
        train_indices, val_indices, test_indices = cv_fold
        logger.info(
            f"CV FOLD {cv_index + 1} / {len(cv_folds)}: "
            f"{len(train_indices)} train, {len(val_indices)} val, {len(test_indices)} test"
        )

        train_states = np.hstack(
            tuple(state_true_seqs[i] for i in train_indices))
        train_edges = part_labels[train_states]
        train_freq_bigram, train_freq_unigram = edge_joint_freqs(train_edges)

        test_states = np.hstack(
            tuple(state_true_seqs[i] for i in test_indices))
        test_edges = part_labels[test_states]
        test_freq_bigram, test_freq_unigram = edge_joint_freqs(test_edges)

        f, axes = plt.subplots(1, 2)
        axes[0].matshow(train_freq_bigram)
        axes[0].set_title('Train')
        axes[1].matshow(test_freq_bigram)
        axes[1].set_title('Test')
        plt.tight_layout()
        plt.savefig(
            os.path.join(fig_dir, f"edge-freqs-bigram_cvfold={cv_index}.png"))
        # Close each figure; otherwise one figure per fold stays open.
        plt.close(f)

        f, axis = plt.subplots(1)
        axis.stem(train_freq_unigram,
                  label='Train',
                  linefmt='C0-',
                  markerfmt='C0o')
        axis.stem(test_freq_unigram,
                  label='Test',
                  linefmt='C1--',
                  markerfmt='C1o')
        plt.legend()
        plt.tight_layout()
        plt.savefig(
            os.path.join(fig_dir, f"edge-freqs-unigram_cvfold={cv_index}.png"))
        plt.close(f)

        for i in test_indices:
            seq_id = seq_ids[i]
            logger.info(f"  Processing sequence {seq_id}...")

            trial_prefix = f"trial={seq_id}"
            score_seq = utils.loadVariable(f"{trial_prefix}_score-seq",
                                           scores_dir)

            edge_pred_seq = edge_pred_seqs[i]
            edge_true_seq = edge_true_seqs[i]
            state_pred_seq = state_pred_seqs[i]
            state_true_seq = state_true_seqs[i]

            num_types = np.unique(state_pred_seq).shape[0]
            num_samples = state_pred_seq.shape[0]
            num_total = len(pred_vocab)
            logger.info(
                f"    {num_types} assemblies predicted ({num_total} total); "
                f"{num_samples} samples")

            metric_dict = {}
            metric_dict = eval_edge_metrics(edge_pred_seq,
                                            edge_true_seq,
                                            append_to=metric_dict)
            metric_dict = eval_state_metrics(state_pred_seq,
                                             state_true_seq,
                                             append_to=metric_dict)
            for name, value in metric_dict.items():
                logger.info(f"    {name}: {value * 100:.2f}%")
                all_metrics[name].append(value)

            utils.writeResults(results_file, metric_dict, sweep_param_name,
                               model_params)

            if num_disp_imgs is not None:
                pred_images = tuple(
                    render(dataset, vocab[seg_label])
                    for seg_label in utils.computeSegments(state_pred_seq)[0])
                imageprocessing.displayImages(
                    *pred_images,
                    file_path=os.path.join(
                        io_dir_images,
                        f"seq={seq_id:03d}_pred-assemblies.png"),
                    num_rows=None,
                    num_cols=5)
                true_images = tuple(
                    render(dataset, vocab[seg_label])
                    for seg_label in utils.computeSegments(state_true_seq)[0])
                imageprocessing.displayImages(
                    *true_images,
                    file_path=os.path.join(
                        io_dir_images,
                        f"seq={seq_id:03d}_true-assemblies.png"),
                    num_rows=None,
                    num_cols=5)

                utils.plot_array(score_seq.T,
                                 (edge_true_seq.T, edge_pred_seq.T),
                                 ('true', 'pred'),
                                 fn=os.path.join(io_dir_plots,
                                                 f"seq={seq_id:03d}.png"))

    # Report the per-sequence metrics accumulated above, averaged.
    if all_metrics:
        logger.info("AVERAGE METRICS")
        for name, values in all_metrics.items():
            logger.info(f"  {name}: {float(np.mean(values)) * 100:.2f}%")
Exemplo n.º 3
0
def main(out_dir=None,
         data_dir=None,
         attr_dir=None,
         model_name=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None):
    """Cross-validate an image classifier on pickled RGB frame sequences.

    For every fold produced by ``utils.makeDataSplits`` this builds train /
    val / test ``PickledVideoDataset`` loaders, trains the model with
    ``torchutils.trainModel``, evaluates it on the test split, appends the
    test metrics to ``results_file`` and pickles predictions, fold membership,
    epoch logs, metrics and the best/last models under ``out_dir/data``.

    Args:
        out_dir: Output root; ``log.txt``, ``figures/`` and ``data/`` go here.
        data_dir: Directory with ``trial-<id>_rgb-frame-seq.pkl`` files.
        attr_dir: Directory with ``trial-<id>_label-seq.pkl`` files.
        model_name: Model to train; only ``'resnet'`` is supported.
        gpu_dev_id: Forwarded to ``torchutils.selectDevice``.
        batch_size: Forwarded to every ``PickledVideoDataset``.
        learning_rate: Adam learning rate.
        model_params: Extra keyword arguments for ``ImageClassifier``.
        cv_params: Keyword arguments for ``utils.makeDataSplits``.
        train_params: Keyword arguments for ``torchutils.trainModel``; must
            contain ``'seq_as_batch'``, which is also used at test time.
        viz_params: Keyword arguments for ``imu.plot_prediction_eg``.
        plot_predictions: If truthy, plot test-set predictions.
        results_file: Results CSV; defaults to ``out_dir/results.csv``.
        sweep_param_name: Forwarded to ``utils.writeResults``.

    Raises:
        AssertionError: If ``model_name`` is not supported.
    """
    supported_models = ('resnet', )
    if model_name not in supported_models:
        # Validate before touching the filesystem or loading any data.
        raise AssertionError(f"Unsupported model_name {model_name!r}; "
                             f"expected one of {supported_models}")

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    attr_dir = os.path.expanduser(attr_dir)
    os.makedirs(out_dir, exist_ok=True)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    os.makedirs(fig_dir, exist_ok=True)

    out_data_dir = os.path.join(out_dir, 'data')
    os.makedirs(out_data_dir, exist_ok=True)

    def loadData(seq_id):
        """Load one trial's RGB frame array and swap its axes 1 and 3."""
        var_name = f"trial-{seq_id}_rgb-frame-seq"
        data = joblib.load(os.path.join(data_dir, f'{var_name}.pkl'))
        # NOTE(review): if frames are stored (T, H, W, C) this yields
        # (T, C, W, H), i.e. channels-first with H/W transposed — confirm.
        return data.swapaxes(1, 3)

    def loadLabels(seq_id):
        """Load one trial's label sequence from ``attr_dir``."""
        var_name = f"trial-{seq_id}_label-seq"
        return joblib.load(os.path.join(attr_dir, f'{var_name}.pkl'))

    def saveVariable(var, var_name):
        """Pickle ``var`` to ``out_data_dir/<var_name>.pkl``."""
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
        """Save one test trial's predictions, scores and true labels.

        ``feat_seq`` is accepted so io-history tuples can be splatted
        directly, but it is not saved.
        """
        saveVariable(pred_seq.cpu().numpy(),
                     f'trial={trial_id}_pred-label-seq')
        saveVariable(score_seq.cpu().numpy(), f'trial={trial_id}_score-seq')
        saveVariable(label_seq.cpu().numpy(),
                     f'trial={trial_id}_true-label-seq')

    def makeMetricDict():
        """Return fresh metric accumulators (a new set per train/test pass)."""
        return {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }

    # Load data
    trial_ids = utils.getUniqueIds(data_dir)
    label_seqs = tuple(map(loadLabels, trial_ids))

    device = torchutils.selectDevice(gpu_dev_id)

    # BCEWithLogitsLoss expects float targets.
    labels_dtype = torch.float

    def makeLoader(labels, seq_ids, shuffle):
        """Wrap one split in a PickledVideoDataset and its DataLoader."""
        dataset = torchutils.PickledVideoDataset(loadData,
                                                 labels,
                                                 device=device,
                                                 labels_dtype=labels_dtype,
                                                 seq_ids=seq_ids,
                                                 batch_size=batch_size)
        # batch_size goes to the dataset, not the loader; presumably the
        # dataset batches internally, so the loader yields one item per step.
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=shuffle)
        return dataset, loader

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    def getSplit(split_idxs):
        """Select the (label_seqs, trial_ids) entries at ``split_idxs``."""
        return tuple(
            tuple(s[i] for i in split_idxs) for s in (label_seqs, trial_ids))

    for cv_index, cv_splits in enumerate(cv_folds):
        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        criterion = torch.nn.BCEWithLogitsLoss()

        train_labels, train_ids = train_data
        train_set, train_loader = makeLoader(train_labels,
                                             train_ids,
                                             shuffle=True)

        test_labels, test_ids = test_data
        _, test_loader = makeLoader(test_labels, test_ids, shuffle=False)

        val_labels, val_ids = val_data
        _, val_loader = makeLoader(val_labels, val_ids, shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)'
        )

        # model_name was validated above; 'resnet' is the only option.
        output_dim = train_set.num_label_types
        model = ImageClassifier(output_dim, **model_params).to(device=device)

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = makeMetricDict()

        optimizer_ft = torch.optim.Adam(model.parameters(),
                                        lr=learning_rate,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0,
                                        amsgrad=False)
        # gamma=1.0 keeps the learning rate constant across epochs.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=1,
                                                       gamma=1.00)

        model, last_model_wts = torchutils.trainModel(
            model,
            criterion,
            optimizer_ft,
            lr_scheduler,
            train_loader,
            val_loader,
            device=device,
            metrics=metric_dict,
            train_epoch_log=train_epoch_log,
            val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = makeMetricDict()
        test_io_history = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)
        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)

        utils.writeResults(results_file, metric_dict, sweep_param_name,
                           model_params)

        if plot_predictions:
            imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

        for io in test_io_history:
            saveTrialData(*io)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(train_epoch_log,
                     f'cvfold={cv_index}_{model_name}-train-epoch-log')
        saveVariable(val_epoch_log,
                     f'cvfold={cv_index}_{model_name}-val-epoch-log')
        saveVariable(metric_dict,
                     f'cvfold={cv_index}_{model_name}-metric-dict')
        # trainModel returns the selected ("best") model; the final-epoch
        # weights are loaded afterwards and saved separately.
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        model.load_state_dict(last_model_wts)
        saveVariable(model, f'cvfold={cv_index}_{model_name}-last')

        torchutils.plotEpochLog(train_epoch_log,
                                subfig_size=(10, 2.5),
                                title='Training performance',
                                fn=os.path.join(
                                    fig_dir,
                                    f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_val-plot.png'))
def main(out_dir=None,
         data_dir=None,
         model_name=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         load_masks_params={},
         kornia_tfs={},
         only_edge=None,
         num_disp_imgs=None,
         results_file=None,
         sweep_param_name=None):
    """Train and test a sim2real model on datasets generated from an assembly vocab.

    The assembly vocabulary is built from the ``trial=<id>_assembly-seq``
    variables in ``data_dir`` and expanded into a parts vocabulary and part
    labels with ``labels_lib.make_parts_vocab``. Train, val and test datasets
    are then separate instances of the dataset class chosen by ``model_name``,
    all constructed from that vocabulary. Metrics, test predictions, the
    trained model and training curves are written under ``out_dir``.

    Args:
        out_dir: Output root; ``log.txt``, ``figures/`` and ``data/`` go here.
        data_dir: Directory with ``trial=<id>_assembly-seq`` variables.
        model_name: One of ``'AAE'``, ``'Resnet'``, ``'Connections'``,
            ``'Labeled Connections'``.
        gpu_dev_id: Forwarded to ``torchutils.selectDevice``.
        batch_size: DataLoader batch size.
        learning_rate: Adam learning rate.
        model_params: Extra keyword arguments for the model constructor
            (not used by ``'AAE'``).
        cv_params: Unused; kept for interface compatibility.
        train_params: Keyword arguments for ``torchutils.trainModel``; must
            contain ``'seq_as_batch'``, which is also used at test time.
        viz_params: Unused; kept for interface compatibility.
        load_masks_params: Keyword arguments for ``loadMasks``.
        kornia_tfs: Forwarded to the dataset constructor.
        only_edge: If not None, keep only this part-label column (and, for
            ``'Labeled Connections'``, only this edge).
        num_disp_imgs: If not None, call ``model.plotBatches`` on the test
            outputs.
        results_file: Results CSV; defaults to ``out_dir/results.csv``.
        sweep_param_name: Forwarded to ``utils.writeResults``.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    supported_models = ('AAE', 'Resnet', 'Connections', 'Labeled Connections')
    if model_name not in supported_models:
        # Without this check an unknown name only surfaced later as a
        # NameError on the never-assigned dataset class.
        raise ValueError(f"Unsupported model_name {model_name!r}; "
                         f"expected one of {supported_models}")

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    os.makedirs(fig_dir, exist_ok=True)

    io_dir = os.path.join(fig_dir, 'model-io')
    os.makedirs(io_dir, exist_ok=True)

    out_data_dir = os.path.join(out_dir, 'data')
    os.makedirs(out_data_dir, exist_ok=True)

    def saveVariable(var, var_name, to_dir=out_data_dir):
        """Save ``var`` via ``utils.saveVariable`` (default dir: ``out_data_dir``)."""
        return utils.saveVariable(var, var_name, to_dir)

    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)

    # Seed the vocabulary with a default-constructed BlockAssembly (presumably
    # the empty assembly) and every single-block state, then register every
    # assembly that occurs in the data.
    vocab = [BlockAssembly()] + [
        make_single_block_state(i) for i in range(len(defn.blocks))
    ]
    for seq_id in trial_ids:
        assembly_seq = utils.loadVariable(f"trial={seq_id}_assembly-seq",
                                          data_dir)
        for assembly in assembly_seq:
            # Return value is discarded: called for its side effect, presumably
            # appending unseen assemblies to ``vocab`` — confirm in utils.getIndex.
            utils.getIndex(assembly, vocab)
    parts_vocab, part_labels = labels_lib.make_parts_vocab(
        vocab, lower_tri_only=True, append_to_vocab=True)

    if only_edge is not None:
        # Slice (not index) so the column axis is preserved.
        part_labels = part_labels[:, only_edge:only_edge + 1]

    logger.info(
        f"Loaded {len(trial_ids)} sequences; {len(vocab)} unique assemblies")

    saveVariable(vocab, 'vocab')
    saveVariable(parts_vocab, 'parts-vocab')
    saveVariable(part_labels, 'part-labels')

    device = torchutils.selectDevice(gpu_dev_id)

    if model_name == 'AAE':
        Dataset = sim2real.DenoisingDataset
    elif model_name == 'Resnet':
        Dataset = sim2real.RenderDataset
    elif model_name == 'Connections':
        Dataset = sim2real.ConnectionDataset
    else:  # 'Labeled Connections' (validated above)
        Dataset = sim2real.LabeledConnectionDataset

    occlusion_masks = loadMasks(**load_masks_params)
    if occlusion_masks is not None:
        logger.info(f"Loaded {occlusion_masks.shape[0]} occlusion masks")

    def make_data(shuffle=True):
        """Construct one dataset instance from the vocabulary and its DataLoader."""
        dataset = Dataset(
            parts_vocab,
            part_labels,
            vocab,
            device=device,
            occlusion_masks=occlusion_masks,
            kornia_tfs=kornia_tfs,
        )
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  shuffle=shuffle)
        return dataset, data_loader

    # A single run: every split is generated from the same vocabulary, so
    # there are no trial-level folds to iterate over.
    for cv_index in range(1):
        cv_str = f"cvfold={cv_index}"

        train_set, train_loader = make_data(shuffle=True)
        test_set, test_loader = make_data(shuffle=False)
        val_set, val_loader = make_data(shuffle=True)

        if model_name == 'AAE':
            model = sim2real.AugmentedAutoEncoder(train_set.data_shape,
                                                  train_set.num_classes)
            criterion = torchutils.BootstrappedCriterion(
                0.25,
                base_criterion=torch.nn.functional.mse_loss,
            )
            metric_names = ('Reciprocal Loss', )
        elif model_name == 'Resnet':
            model = sim2real.ImageClassifier(train_set.num_classes,
                                             **model_params)
            criterion = torch.nn.CrossEntropyLoss()
            metric_names = ('Loss', 'Accuracy')
        elif model_name == 'Connections':
            model = sim2real.ConnectionClassifier(train_set.label_shape[0],
                                                  **model_params)
            criterion = torch.nn.BCEWithLogitsLoss()
            metric_names = ('Loss', 'Accuracy', 'Precision', 'Recall', 'F1')
        else:  # 'Labeled Connections'
            out_dim = int(part_labels.max()) + 1
            num_vertices = len(defn.blocks)
            # One (row, col) pair per strictly-lower-triangular vertex pair.
            edges = np.column_stack(np.tril_indices(num_vertices, k=-1))
            if only_edge is not None:
                edges = edges[only_edge:only_edge + 1]
                logger.info(f"Class freqs: {train_set.class_freqs}")
            model = sim2real.LabeledConnectionClassifier(
                out_dim, num_vertices, edges, **model_params)
            # NOTE: unweighted; class-frequency weighting was tried for the
            # single-edge case but is not in use.
            criterion = torch.nn.CrossEntropyLoss()
            metric_names = ('Loss', 'Accuracy', 'Precision', 'Recall', 'F1')

        model = model.to(device=device)

        optimizer_ft = torch.optim.Adam(model.parameters(),
                                        lr=learning_rate,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0,
                                        amsgrad=False)
        # gamma=1.0 keeps the learning rate constant across epochs.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=1,
                                                       gamma=1.00)

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = {name: metrics.makeMetric(name) for name in metric_names}
        model, last_model_wts = torchutils.trainModel(
            model,
            criterion,
            optimizer_ft,
            lr_scheduler,
            train_loader,
            val_loader,
            device=device,
            metrics=metric_dict,
            train_epoch_log=train_epoch_log,
            val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = {name: metrics.makeMetric(name) for name in metric_names}
        test_io_batches = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)
        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)
        utils.writeResults(results_file, metric_dict, sweep_param_name,
                           model_params)

        for pred_seq, score_seq, feat_seq, label_seq, trial_id in test_io_batches:
            trial_str = f"trial={trial_id}"
            saveVariable(pred_seq.cpu().numpy(), f'{trial_str}_pred-label-seq')
            saveVariable(score_seq.cpu().numpy(), f'{trial_str}_score-seq')
            saveVariable(label_seq.cpu().numpy(),
                         f'{trial_str}_true-label-seq')

        saveVariable(model, f'{cv_str}_model-best')

        if train_epoch_log:
            torchutils.plotEpochLog(train_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Training performance',
                                    fn=os.path.join(
                                        fig_dir, f'{cv_str}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(fig_dir,
                                                    f'{cv_str}_val-plot.png'))

        if num_disp_imgs is not None:
            model.plotBatches(test_io_batches, io_dir, dataset=test_set)
Exemplo n.º 5
0
def main(out_dir=None,
         rgb_data_dir=None,
         rgb_attributes_dir=None,
         rgb_vocab_dir=None,
         imu_data_dir=None,
         imu_attributes_dir=None,
         modalities=['rgb', 'imu'],
         gpu_dev_id=None,
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={}):
    """Score every trial with an AttributeModel built from per-modality edge labels.

    Trial ids come from ``rgb_data_dir``. When ``modalities`` is not
    RGB-only, they are intersected with the ids in ``imu_data_dir``. For each
    trial, the fused attribute features are scored and decoded by
    ``AttributeModel``. Transposed scores and true labels are saved under
    ``out_dir/data``, and per-trial metrics are logged and appended to
    ``results_file``. No training happens here.

    Args:
        out_dir: Output root; ``log.txt``, ``figures/`` and ``data/`` go here.
        rgb_data_dir, rgb_attributes_dir, imu_data_dir, imu_attributes_dir:
            Forwarded to ``FusionDataset``. Each may be None when its
            modality is not used.
        rgb_vocab_dir: Directory holding the ``part-labels`` variable.
        modalities: Modality names (``'rgb'`` and/or ``'imu'``), in the order
            their edge labels are passed to ``AttributeModel``.
        gpu_dev_id: Forwarded to ``torchutils.selectDevice``.
        plot_predictions: If truthy, save edge-label and per-trial plots.
        results_file: Results CSV; defaults to ``out_dir/results.csv``.
        sweep_param_name, model_params: Forwarded to ``utils.writeResults``.
        cv_params, train_params, viz_params: Unused; kept for compatibility.
    """
    def expandOptional(path):
        """Expand ``~`` in ``path``, passing None through for unused inputs."""
        return None if path is None else os.path.expanduser(path)

    out_dir = os.path.expanduser(out_dir)
    rgb_data_dir = expandOptional(rgb_data_dir)
    rgb_attributes_dir = expandOptional(rgb_attributes_dir)
    rgb_vocab_dir = expandOptional(rgb_vocab_dir)
    imu_data_dir = expandOptional(imu_data_dir)
    imu_attributes_dir = expandOptional(imu_attributes_dir)

    # The log file lives in out_dir, so it must exist first.
    os.makedirs(out_dir, exist_ok=True)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    os.makedirs(fig_dir, exist_ok=True)

    out_data_dir = os.path.join(out_dir, 'data')
    os.makedirs(out_data_dir, exist_ok=True)

    def saveVariable(var, var_name, to_dir=out_data_dir):
        """Save ``var`` via ``utils.saveVariable`` (default dir: ``out_data_dir``)."""
        utils.saveVariable(var, var_name, to_dir)

    # Load data. list() so a tuple such as ('rgb',) also takes the RGB-only path.
    if list(modalities) == ['rgb']:
        trial_ids = utils.getUniqueIds(rgb_data_dir,
                                       prefix='trial=',
                                       to_array=True)
        logger.info(f"Processing {len(trial_ids)} videos")
    else:
        rgb_trial_ids = utils.getUniqueIds(rgb_data_dir,
                                           prefix='trial=',
                                           to_array=True)
        imu_trial_ids = utils.getUniqueIds(imu_data_dir,
                                           prefix='trial=',
                                           to_array=True)
        trial_ids = np.array(
            sorted(set(rgb_trial_ids.tolist()) & set(imu_trial_ids.tolist())))
        logger.info(
            f"Processing {len(trial_ids)} videos common to "
            f"RGB ({len(rgb_trial_ids)} total) and IMU ({len(imu_trial_ids)} total)"
        )

    device = torchutils.selectDevice(gpu_dev_id)
    dataset = FusionDataset(trial_ids,
                            rgb_attributes_dir,
                            rgb_data_dir,
                            imu_attributes_dir,
                            imu_data_dir,
                            device=device,
                            modalities=modalities)
    utils.saveMetadata(dataset.metadata, out_data_dir)
    saveVariable(dataset.vocab, 'vocab')

    # RGB edge labels are precomputed on disk; IMU edge labels are derived
    # from each vocab assembly via labels.inSameComponent.
    edge_labels = {
        'rgb':
        utils.loadVariable('part-labels', rgb_vocab_dir),
        'imu':
        np.stack([
            labels.inSameComponent(a, lower_tri_only=True)
            for a in dataset.vocab
        ])
    }

    attribute_labels = tuple(edge_labels[name] for name in modalities)

    logger.info('Making transition probs...')
    transition_probs = make_transition_scores(dataset.vocab)
    saveVariable(transition_probs, 'transition-probs')

    model = AttributeModel(*attribute_labels, device=device)

    if plot_predictions:
        figsize = (12, 3)
        _, axis = plt.subplots(1, figsize=figsize)
        axis.imshow(edge_labels['rgb'].T, interpolation='none', aspect='auto')
        plt.savefig(os.path.join(fig_dir, "edge-labels.png"))
        plt.close()

    for trial_id in trial_ids:
        logger.info(f"Processing sequence {trial_id}...")

        trial_prefix = f"trial={trial_id}"

        true_label_seq = dataset.loadTargets(trial_id)
        attribute_feats = dataset.loadInputs(trial_id)

        score_seq = model(attribute_feats)
        pred_label_seq = model.predict(score_seq)

        attribute_feats = attribute_feats.cpu().numpy()
        score_seq = score_seq.cpu().numpy()
        true_label_seq = true_label_seq.cpu().numpy()
        pred_label_seq = pred_label_seq.cpu().numpy()

        # Saved transposed relative to the arrays returned by the model/dataset.
        saveVariable(score_seq.T, f'{trial_prefix}_score-seq')
        saveVariable(true_label_seq.T, f'{trial_prefix}_label-seq')

        if plot_predictions:
            fn = os.path.join(fig_dir, f'{trial_prefix}.png')
            utils.plot_array(attribute_feats.T,
                             (true_label_seq, pred_label_seq, score_seq),
                             ('gt', 'pred', 'scores'),
                             fn=fn)

        metric_dict = eval_metrics(pred_label_seq, true_label_seq)
        for name, value in metric_dict.items():
            logger.info(f"  {name}: {value * 100:.2f}%")

        utils.writeResults(results_file, metric_dict, sweep_param_name,
                           model_params)
Exemplo n.º 6
0
def main(out_dir=None,
         data_dir=None,
         model_name=None,
         predict_mode='classify',
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         independent_signals=None,
         active_only=None,
         output_dim_from_vocab=False,
         prefix='trial=',
         feature_fn_format='feature-seq.pkl',
         label_fn_format='label_seq.pkl',
         dataset_params={},
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         metric_names=['Loss', 'Accuracy', 'Precision', 'Recall', 'F1'],
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None):
    """Cross-validate a sequence classifier on pickled feature/label sequences.

    Sequences are loaded through ``utils.CvDataset``. For each fold from
    ``utils.makeDataSplits``, a model is trained with ``torchutils.trainModel``
    and evaluated on the test split. Per-sequence predictions, scores and true
    labels are saved under ``out_dir/data``.

    Args:
        out_dir: Output root; ``log.txt``, ``figures/`` and ``data/`` go here.
        data_dir: Directory with ``<prefix><id>_<feature_fn_format>`` files.
        model_name: One of ``'linear'``, ``'conv'``, ``'TCN'``, ``'LSTM'``.
        predict_mode: ``'binary multiclass'`` (BCE-with-logits, float labels),
            ``'multiclass'`` or ``'classify'`` (cross-entropy, long labels).
        gpu_dev_id: Forwarded to ``torchutils.selectDevice``.
        batch_size: DataLoader batch size.
        learning_rate: Adam learning rate.
        independent_signals: If truthy, split each sequence with ``splitSeqs``
            before training and rejoin test outputs with ``joinSeqs``.
        active_only: Forwarded to ``splitSeqs`` for train/val splits.
        output_dim_from_vocab: Use ``len(dataset.vocab)`` as output size.
        prefix, feature_fn_format, label_fn_format: File-naming scheme.
        dataset_params: Extra keyword arguments for ``SequenceDataset``.
        model_params: Extra keyword arguments for the model constructor.
        cv_params: Keyword arguments for ``utils.makeDataSplits``.
        train_params: Keyword arguments for ``torchutils.trainModel``; must
            contain ``'seq_as_batch'``, which is also used at test time.
        viz_params: Unused; kept for interface compatibility.
        metric_names: Names passed to ``metrics.makeMetric``.
        plot_predictions: If truthy, plot each test sequence's model I/O.
        results_file: Results CSV; defaults to ``out_dir/results.csv``.
        sweep_param_name: Forwarded to ``utils.writeResults``.

    Raises:
        AssertionError: If ``predict_mode`` or ``model_name`` is unsupported.
    """
    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    os.makedirs(fig_dir, exist_ok=True)

    out_data_dir = os.path.join(out_dir, 'data')
    os.makedirs(out_data_dir, exist_ok=True)

    def saveVariable(var, var_name, to_dir=out_data_dir):
        """Save ``var`` via ``utils.saveVariable`` (default dir: ``out_data_dir``)."""
        return utils.saveVariable(var, var_name, to_dir)

    def toNumpy(batch):
        """Convert every tensor in ``batch`` to a NumPy array; keep other items."""
        return tuple(x.cpu().numpy() if isinstance(x, torch.Tensor) else x
                     for x in batch)

    # Load data
    device = torchutils.selectDevice(gpu_dev_id)
    trial_ids = utils.getUniqueIds(data_dir,
                                   prefix=prefix,
                                   suffix=feature_fn_format,
                                   to_array=True)
    dataset = utils.CvDataset(
        trial_ids,
        data_dir,
        prefix=prefix,
        feature_fn_format=feature_fn_format,
        label_fn_format=label_fn_format,
    )
    utils.saveMetadata(dataset.metadata, out_data_dir)
    saveVariable(dataset.vocab, 'vocab')

    # Define cross-validation folds
    cv_folds = utils.makeDataSplits(len(trial_ids), **cv_params)
    saveVariable(cv_folds, 'cv-folds')

    if predict_mode == 'binary multiclass':
        criterion = torch.nn.BCEWithLogitsLoss()
        labels_dtype = torch.float
    elif predict_mode in ('multiclass', 'classify'):
        criterion = torch.nn.CrossEntropyLoss()
        labels_dtype = torch.long
    else:
        raise AssertionError(f"Unsupported predict_mode {predict_mode!r}")

    def make_dataset(feats, labels, ids, shuffle=True):
        """Wrap one split in a SequenceDataset and a DataLoader."""
        dataset = torchutils.SequenceDataset(feats,
                                             labels,
                                             device=device,
                                             labels_dtype=labels_dtype,
                                             seq_ids=ids,
                                             **dataset_params)
        # Honour ``shuffle`` so the test split is read in a fixed order.
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=shuffle)
        return dataset, loader

    for cv_index, cv_fold in enumerate(cv_folds):
        train_data, val_data, test_data = dataset.getFold(cv_fold)
        if independent_signals:
            train_data = splitSeqs(*train_data, active_only=active_only)
            val_data = splitSeqs(*val_data, active_only=active_only)
            test_data = splitSeqs(*test_data, active_only=False)
        train_set, train_loader = make_dataset(*train_data, shuffle=True)
        test_set, test_loader = make_dataset(*test_data, shuffle=False)
        val_set, val_loader = make_dataset(*val_data, shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(dataset.trial_ids)} total '
            f'({len(train_set)} train, {len(val_set)} val, {len(test_set)} test)'
        )

        logger.info(f'{train_set.num_label_types} unique labels in train set; '
                    f'vocab size is {len(dataset.vocab)}')

        input_dim = train_set.num_obsv_dims
        output_dim = train_set.num_label_types
        if output_dim_from_vocab:
            output_dim = len(dataset.vocab)

        num_multiclass = None
        if model_name in ('TCN', 'LSTM') and predict_mode == 'multiclass':
            # Size of the last label axis of the first training item, and an
            # output size covering the label types seen in any split.
            num_multiclass = train_set[0][1].shape[-1]
            output_dim = max(train_set.num_label_types,
                             test_set.num_label_types, val_set.num_label_types)

        if model_name == 'linear':
            model = torchutils.LinearClassifier(input_dim, output_dim,
                                                **model_params)
        elif model_name == 'conv':
            model = ConvClassifier(input_dim, output_dim, **model_params)
        elif model_name == 'TCN':
            model = TcnClassifier(input_dim,
                                  output_dim,
                                  num_multiclass=num_multiclass,
                                  **model_params)
        elif model_name == 'LSTM':
            model = LstmClassifier(input_dim,
                                   output_dim,
                                   num_multiclass=num_multiclass,
                                   **model_params)
        else:
            raise AssertionError(f"Unsupported model_name {model_name!r}")
        model = model.to(device=device)

        optimizer_ft = torch.optim.Adam(model.parameters(),
                                        lr=learning_rate,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0,
                                        amsgrad=False)
        # gamma=1.0 keeps the learning rate constant across epochs.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=1,
                                                       gamma=1.00)

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = {name: metrics.makeMetric(name) for name in metric_names}
        model, last_model_wts = torchutils.trainModel(
            model,
            criterion,
            optimizer_ft,
            lr_scheduler,
            train_loader,
            val_loader,
            device=device,
            metrics=metric_dict,
            train_epoch_log=train_epoch_log,
            val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = {name: metrics.makeMetric(name) for name in metric_names}
        test_io_history = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)
        if independent_signals:
            test_io_history = tuple(joinSeqs(test_io_history))

        logger.info('[TST]  ' +
                    '  '.join(str(m) for m in metric_dict.values()))
        utils.writeResults(results_file,
                           {k: v.value
                            for k, v in metric_dict.items()}, sweep_param_name,
                           model_params)

        # Convert each batch once; both the plotting and saving loops reuse it.
        numpy_batches = tuple(map(toNumpy, test_io_history))

        if plot_predictions:
            io_fig_dir = os.path.join(fig_dir, 'model-io')
            os.makedirs(io_fig_dir, exist_ok=True)

            label_names = ('gt', 'pred')
            for batch in numpy_batches:
                for preds, _, inputs, gt_labels, seq_id in zip(*batch):
                    fn = os.path.join(io_fig_dir,
                                      f"{prefix}{seq_id}_model-io.png")
                    utils.plot_array(inputs, (gt_labels.T, preds.T),
                                     label_names,
                                     fn=fn)

        for batch in numpy_batches:
            for pred_seq, score_seq, _, label_seq, trial_id in zip(*batch):
                saveVariable(pred_seq, f'{prefix}{trial_id}_pred-label-seq')
                saveVariable(score_seq, f'{prefix}{trial_id}_score-seq')
                saveVariable(label_seq, f'{prefix}{trial_id}_true-label-seq')

        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        train_fig_dir = os.path.join(fig_dir, 'train-plots')
        os.makedirs(train_fig_dir, exist_ok=True)

        if train_epoch_log:
            torchutils.plotEpochLog(train_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Training performance',
                                    fn=os.path.join(
                                        train_fig_dir,
                                        f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(
                                        train_fig_dir,
                                        f'cvfold={cv_index}_val-plot.png'))
Exemplo n.º 7
0
def main(out_dir=None,
         data_dir=None,
         model_name=None,
         pretrained_model_dir=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         independent_signals=None,
         active_only=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None,
         label_mapping=None,
         eval_label_mapping=None):
    """Cross-validate a sequence classifier on per-trial feature sequences.

    Loads ``trial=<id>_feature-seq.pkl`` and ``trial=<id>_label-seq.pkl``
    from ``data_dir``. Then, for every fold produced by
    ``utils.makeDataSplits``:

    * if ``pretrained_model_dir`` is given, it loads that fold's saved
      ``<model_name>-best`` model and train/val ids, and evaluates the model
      on every trial not in either id collection;
    * otherwise it trains a new ``'linear'``, ``'conv'`` or ``'TCN'`` model
      and evaluates it on the fold's test split.

    Test metrics are appended to ``results_file`` (default
    ``<out_dir>/results.csv``). Per-trial predictions, scores and true
    labels, plus split ids, epoch logs and models, are written to
    ``<out_dir>/data``. Plots go to ``<out_dir>/figures``.

    Args:
        out_dir: Output directory. ``~`` is expanded and the directory is
            created if missing.
        data_dir: Directory holding the per-trial input pickles.
        model_name: ``'linear'``, ``'conv'`` or ``'TCN'``. Also used in the
            names of saved/loaded model files.
        pretrained_model_dir: If not None, evaluate the models saved in this
            directory instead of training new ones.
        gpu_dev_id: Passed to ``torchutils.selectDevice``.
        batch_size: DataLoader batch size.
        learning_rate: Adam learning rate.
        independent_signals: If truthy, sequences are split with
            ``splitSeqs`` before use and re-joined with ``joinSeqs`` after
            prediction.
        active_only: Forwarded to ``splitSeqs`` for the train/val splits.
        model_params: Keyword arguments for the model constructor.
        cv_params: Keyword arguments for ``utils.makeDataSplits``.
        train_params: Keyword arguments for ``torchutils.trainModel``. Must
            contain ``'seq_as_batch'``, which is also used at test time.
        viz_params: Keyword arguments for ``imu.plot_prediction_eg``.
        plot_predictions: If truthy, plot test predictions.
        results_file: CSV file that test metrics are appended to.
        sweep_param_name: Forwarded to ``utils.writeResults``.
        label_mapping: Optional ``{old_label: new_label}`` dict applied to
            the loaded label sequences. Every key is matched against the
            ORIGINAL labels, so the mapping is applied all at once and never
            chained.
        eval_label_mapping: If not None, each newly trained model is
            evaluated once more with this ``label_mapping`` passed to
            ``torchutils.predictSamples``. Those metrics are only logged.
    """

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)

        return tuple(map(loadOne, seq_ids))

    def loadFromPretrain(fn):
        return joblib.load(os.path.join(pretrained_model_dir, f"{fn}.pkl"))

    # Load data
    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)
    feature_seqs = loadAll(trial_ids, 'feature-seq.pkl', data_dir)
    label_seqs = loadAll(trial_ids, 'label-seq.pkl', data_dir)

    device = torchutils.selectDevice(gpu_dev_id)

    if label_mapping is not None:

        def map_labels(labels):
            # Build every mask from the original labels before assigning
            # anything. Otherwise a chained mapping such as {1: 2, 2: 3}
            # would send label 1 to 2 and then on to 3.
            masks = tuple(
                (labels == i, j) for i, j in label_mapping.items())
            for mask, j in masks:
                labels[mask] = j
            return labels

        label_seqs = tuple(map(map_labels, label_seqs))

    # FIXME: for non-independent signals the intended loss appears to be
    # torch.nn.BCEWithLogitsLoss with float labels. For now both modes use
    # CrossEntropyLoss with integer labels.
    criterion = torch.nn.CrossEntropyLoss()
    labels_dtype = torch.long

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, label_seqs, trial_ids))
        return split_data

    def makeMetricDict():
        # Fresh metric accumulators for one training or testing pass.
        return {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }

    def makeLoader(split_data, shuffle):
        # Wrap one (features, labels, ids) split in a dataset and a loader.
        split_feats, split_labels, split_ids = split_data
        dataset = torchutils.SequenceDataset(split_feats,
                                             split_labels,
                                             device=device,
                                             labels_dtype=labels_dtype,
                                             seq_ids=split_ids,
                                             transpose_data=True)
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=shuffle)
        return dataset, loader

    def testModel(model, loader, eval_mapping=None):
        # Run `model` over `loader` without updating it and log the metrics.
        # Returns (io_history, metric_dict). The io history is re-joined
        # when the signals were split.
        metric_dict = makeMetricDict()
        extra_kwargs = {}
        if eval_mapping is not None:
            extra_kwargs['label_mapping'] = eval_mapping
        io_history = torchutils.predictSamples(
            model.to(device=device),
            loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True,
            **extra_kwargs)
        if independent_signals:
            io_history = tuple(joinSeqs(io_history))

        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)
        return io_history, metric_dict

    def dupScoreCols(scores):
        # Re-expand the score columns to the pre-mapping label space.
        # Column i (a mapped-away label) reuses column j's scores. Indexing
        # only succeeds if the mapped keys are exactly the top
        # len(label_mapping) indices and every target j < scores.shape[-1].
        num_cols = scores.shape[-1] + len(label_mapping)
        col_idxs = torch.arange(num_cols)
        for i, j in label_mapping.items():
            col_idxs[i] = j
        return scores[..., col_idxs]

    def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
        if label_mapping is not None:
            score_seq = dupScoreCols(score_seq)
        saveVariable(pred_seq.cpu().numpy(),
                     f'trial={trial_id}_pred-label-seq')
        saveVariable(score_seq.cpu().numpy(),
                     f'trial={trial_id}_score-seq')
        saveVariable(label_seq.cpu().numpy(),
                     f'trial={trial_id}_true-label-seq')

    def evaluateAndSave(model, loader):
        # Test the model, append metrics to the results file, optionally
        # plot, and save per-trial outputs. Returns the test metric dict.
        io_history, metric_dict = testModel(model, loader)

        d = {k: v.value for k, v in metric_dict.items()}
        utils.writeResults(results_file, d, sweep_param_name, model_params)

        if plot_predictions:
            imu.plot_prediction_eg(io_history, fig_dir, **viz_params)

        for io in io_history:
            saveTrialData(*io)
        return metric_dict

    for cv_index, cv_splits in enumerate(cv_folds):
        if pretrained_model_dir is not None:
            model = loadFromPretrain(f'cvfold={cv_index}_{model_name}-best')
            train_ids = loadFromPretrain(f'cvfold={cv_index}_train-ids')
            val_ids = loadFromPretrain(f'cvfold={cv_index}_val-ids')

            # Every trial that was neither trained nor validated on is
            # tested. The concatenation is hoisted out of the scan.
            seen_ids = train_ids + val_ids
            test_idxs = tuple(idx for idx, i in enumerate(trial_ids)
                              if i not in seen_ids)
            test_data = getSplit(test_idxs)
            if independent_signals:
                test_data = splitSeqs(*test_data, active_only=False)

            _, test_loader = makeLoader(test_data, shuffle=False)
            evaluateAndSave(model, test_loader)
            continue

        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        if independent_signals:
            split_ = functools.partial(splitSeqs, active_only=active_only)
            train_data = split_(*train_data)
            val_data = split_(*val_data)
            # Test sequences are always split with active_only=False.
            test_data = splitSeqs(*test_data, active_only=False)

        train_set, train_loader = makeLoader(train_data, shuffle=True)
        _, test_loader = makeLoader(test_data, shuffle=False)
        _, val_loader = makeLoader(val_data, shuffle=True)
        train_ids, val_ids, test_ids = train_data[2], val_data[2], test_data[2]

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)'
        )

        input_dim = train_set.num_obsv_dims
        output_dim = train_set.num_label_types
        if model_name == 'linear':
            model = torchutils.LinearClassifier(input_dim, output_dim,
                                                **model_params)
        elif model_name == 'conv':
            model = ConvClassifier(input_dim, output_dim, **model_params)
        elif model_name == 'TCN':
            model = TcnClassifier(input_dim, output_dim, **model_params)
        else:
            raise AssertionError(f"Unrecognized model_name: {model_name}")
        # Move every model type to the device before its optimizer is built,
        # as recommended by torch.optim.
        model = model.to(device=device)

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)

        optimizer_ft = torch.optim.Adam(model.parameters(),
                                        lr=learning_rate,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0,
                                        amsgrad=False)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=1,
                                                       gamma=1.00)

        model, last_model_wts = torchutils.trainModel(
            model,
            criterion,
            optimizer_ft,
            lr_scheduler,
            train_loader,
            val_loader,
            device=device,
            metrics=makeMetricDict(),
            train_epoch_log=train_epoch_log,
            val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = evaluateAndSave(model, test_loader)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(train_epoch_log,
                     f'cvfold={cv_index}_{model_name}-train-epoch-log')
        saveVariable(val_epoch_log,
                     f'cvfold={cv_index}_{model_name}-val-epoch-log')
        saveVariable(metric_dict,
                     f'cvfold={cv_index}_{model_name}-metric-dict')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        model.load_state_dict(last_model_wts)
        saveVariable(model, f'cvfold={cv_index}_{model_name}-last')

        if train_epoch_log:
            torchutils.plotEpochLog(train_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Training performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_val-plot.png'))

        if eval_label_mapping is not None:
            # NOTE: `model` now holds the last-epoch weights loaded above,
            # not the best ones.
            testModel(model, test_loader, eval_mapping=eval_label_mapping)
Exemplo n.º 8
0
def main(out_dir=None,
         data_dir=None,
         prefix='trial=',
         model_name=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         file_fn_format=None,
         label_fn_format=None,
         start_from=None,
         stop_at=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         num_disp_imgs=None,
         viz_templates=None,
         results_file=None,
         sweep_param_name=None):
    """Cross-validate an ``ImageClassifier`` on per-trial video data.

    Loads the ``vocab`` variable and the trial ids (files starting with
    ``prefix``) from ``data_dir``. It then builds a ``utils.CvDataset``, and
    for each CV fold with index in ``[start_from, stop_at]`` it:

    * trains an ``ImageClassifier`` with ``len(vocab)`` outputs;
    * evaluates it on the fold's test split;
    * appends the test metrics to ``results_file``;
    * saves per-batch predictions, scores and true labels, the test set's
      ``unflatten`` attribute, and the best model to ``<out_dir>/data``.

    Args:
        out_dir: Output directory. ``~`` is expanded and the directory is
            created if missing.
        data_dir: Directory holding the per-trial data and ``vocab``.
        prefix: Filename prefix used to discover trial ids. Also passed to
            ``utils.CvDataset``.
        model_name: Used only in the names of saved model files.
        gpu_dev_id: Passed to ``torchutils.selectDevice``.
        batch_size: Batch size passed to ``VideoDataset``.
        learning_rate: Adam learning rate.
        file_fn_format, label_fn_format: Filename formats for
            ``utils.CvDataset``.
        start_from, stop_at: Inclusive range of CV fold indices to run.
        model_params: Keyword arguments for ``ImageClassifier``.
        cv_params: Keyword arguments for ``utils.makeDataSplits``.
        train_params: Keyword arguments for ``torchutils.trainModel``. Must
            contain ``'seq_as_batch'``, which is also used at test time.
        viz_params, num_disp_imgs, viz_templates: Currently unused.
        results_file: CSV file for test metrics (default
            ``<out_dir>/results.csv``).
        sweep_param_name: Forwarded to ``utils.writeResults``.
    """

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    io_dir = os.path.join(fig_dir, 'model-io')
    if not os.path.exists(io_dir):
        os.makedirs(io_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name, to_dir=out_data_dir):
        utils.saveVariable(var, var_name, to_dir)

    # Load data
    trial_ids = utils.getUniqueIds(data_dir, prefix=prefix, to_array=True)
    vocab = utils.loadVariable('vocab', data_dir)
    saveVariable(vocab, 'vocab')

    # Define cross-validation folds
    data_loader = utils.CvDataset(trial_ids,
                                  data_dir,
                                  vocab=vocab,
                                  prefix=prefix,
                                  feature_fn_format=file_fn_format,
                                  label_fn_format=label_fn_format)
    cv_folds = utils.makeDataSplits(len(data_loader.trial_ids), **cv_params)

    device = torchutils.selectDevice(gpu_dev_id)
    labels_dtype = torch.long
    criterion = torch.nn.CrossEntropyLoss()
    metric_names = ('Loss', 'Accuracy')

    def make_dataset(fns, labels, ids, batch_mode='sample', shuffle=True):
        # VideoDataset does its own batching (batch_size), so the
        # DataLoader hands out one pre-built batch at a time.
        dataset = VideoDataset(fns,
                               labels,
                               device=device,
                               labels_dtype=labels_dtype,
                               seq_ids=ids,
                               batch_size=batch_size,
                               batch_mode=batch_mode)
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=shuffle)
        return dataset, loader

    for cv_index, cv_fold in enumerate(cv_folds):
        if start_from is not None and cv_index < start_from:
            continue

        if stop_at is not None and cv_index > stop_at:
            break

        train_data, val_data, test_data = data_loader.getFold(cv_fold)
        train_set, train_loader = make_dataset(*train_data,
                                               batch_mode='flatten',
                                               shuffle=True)
        test_set, test_loader = make_dataset(*test_data,
                                             batch_mode='flatten',
                                             shuffle=False)
        val_set, val_loader = make_dataset(*val_data,
                                           batch_mode='flatten',
                                           shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(data_loader.trial_ids)} total '
            f'({len(train_set)} train, {len(val_set)} val, {len(test_set)} test)'
        )

        # Move the model to the device before building its optimizer, as
        # recommended by torch.optim.
        model = ImageClassifier(len(vocab), **model_params).to(device=device)

        optimizer_ft = torch.optim.Adam(model.parameters(),
                                        lr=learning_rate,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0,
                                        amsgrad=False)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=1,
                                                       gamma=1.00)

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = {name: metrics.makeMetric(name) for name in metric_names}
        model, last_model_wts = torchutils.trainModel(
            model,
            criterion,
            optimizer_ft,
            lr_scheduler,
            train_loader,
            val_loader,
            device=device,
            metrics=metric_dict,
            train_epoch_log=train_epoch_log,
            val_epoch_log=val_epoch_log,
            **train_params)

        # Test model with fresh metric accumulators.
        metric_dict = {name: metrics.makeMetric(name) for name in metric_names}
        test_io_history = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)
        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)

        utils.writeResults(results_file,
                           {name: m.value
                            for name, m in metric_dict.items()},
                           sweep_param_name, model_params)

        for pred_seq, score_seq, feat_seq, label_seq, batch_id in test_io_history:
            # Use a distinct name here: reusing `prefix` would clobber the
            # trial-file prefix argument for the rest of the function.
            batch_prefix = f'cvfold={cv_index}_batch={batch_id}'
            saveVariable(pred_seq.cpu().numpy(),
                         f'{batch_prefix}_pred-label-seq')
            saveVariable(score_seq.cpu().numpy(), f'{batch_prefix}_score-seq')
            saveVariable(label_seq.cpu().numpy(),
                         f'{batch_prefix}_true-label-seq')
        saveVariable(test_set.unflatten,
                     f'cvfold={cv_index}_test-set-unflatten')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        if train_epoch_log:
            torchutils.plotEpochLog(train_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Training performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_val-plot.png'))
Exemplo n.º 9
0
def main(out_dir=None,
         data_dir=None,
         background_data_dir=None,
         learn_bg_model=False,
         gpu_dev_id=None,
         start_from=None,
         stop_at=None,
         num_disp_imgs=None,
         depth_bg_detection_kwargs={},
         rgb_bg_detection_kwargs={}):
    """Compute depth- and RGB-based background masks for each trial.

    For each trial in ``data_dir`` with index in ``[start_from, stop_at]``,
    this loads the RGB/depth frame sequences from ``data_dir`` and the
    "before first touch" frames from ``background_data_dir``. It then:

    1. fits a depth background model and detects background in the depth
       frames, reusing masks already saved in ``<out_dir>/data`` if they
       load without ``FileNotFoundError``;
    2. fits an RGB background image from the RGB frames and the depth masks;
    3. derives an RGB background mask. With ``learn_bg_model`` set, it
       trains an ``RgbBackgroundModel``. Otherwise it thresholds the
       per-pixel distance to the RGB background image, choosing the
       threshold that maximises agreement with the depth masks.

    The masks are saved as boolean arrays to ``<out_dir>/data`` and
    diagnostic figures to ``<out_dir>/figures``.

    Args:
        out_dir: Output directory. ``~`` is expanded and the directory is
            created if missing.
        data_dir: Directory holding the per-trial frame sequences.
        background_data_dir: Directory holding the before-first-touch frames.
        learn_bg_model: Train an RGB background model instead of using a
            threshold search.
        gpu_dev_id: Passed to ``torchutils.selectDevice``.
        start_from, stop_at: Inclusive range of trial indices to process.
        num_disp_imgs: If not None, plot at most this many random frames.
        depth_bg_detection_kwargs: Keyword arguments for the depth
            background fitting/detection functions.
        rgb_bg_detection_kwargs: Currently unused.
    """

    out_dir = os.path.expanduser(out_dir)
    data_dir = os.path.expanduser(data_dir)
    background_data_dir = os.path.expanduser(background_data_dir)
    # The log file lives in out_dir, so the directory must exist before the
    # logger is set up.
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def loadFromDir(var_name, dir_name):
        return joblib.load(os.path.join(dir_name, f"{var_name}.pkl"))

    def saveToWorkingDir(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f"{var_name}.pkl"))

    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)

    device = torchutils.selectDevice(gpu_dev_id)

    camera_pose = render.camera_pose
    camera_params = render.intrinsic_matrix

    for seq_idx, trial_id in enumerate(trial_ids):

        if start_from is not None and seq_idx < start_from:
            continue

        if stop_at is not None and seq_idx > stop_at:
            break

        trial_str = f"trial={trial_id}"

        logger.info(
            f"Processing video {seq_idx + 1} / {len(trial_ids)}  (trial {trial_id})"
        )

        logger.info("  Loading data...")
        try:
            rgb_frame_seq = loadFromDir(f"{trial_str}_rgb-frame-seq", data_dir)
            depth_frame_seq = loadFromDir(f"{trial_str}_depth-frame-seq",
                                          data_dir)
            rgb_train = loadFromDir(
                f"{trial_str}_rgb-frame-seq-before-first-touch",
                background_data_dir)
            depth_train = loadFromDir(
                f"{trial_str}_depth-frame-seq-before-first-touch",
                background_data_dir)

            # NOTE(review): a tuple here appears to mark missing depth data.
            # This only skips when BOTH are tuples; confirm whether `or` was
            # intended.
            if isinstance(depth_train, tuple) and isinstance(
                    depth_frame_seq, tuple):
                logger.info("  Skipping video: depth frames missing")
                continue

            rgb_frame_seq = np.stack(tuple(
                skimage.img_as_float(f) for f in rgb_frame_seq),
                                     axis=0)
            rgb_train = np.stack(tuple(
                skimage.img_as_float(f) for f in rgb_train),
                                 axis=0)

        except FileNotFoundError as e:
            logger.info(e)
            continue

        logger.info("  Removing background...")

        # Reuse previously computed depth masks when they are on disk.
        try:
            bg_mask_depth_train = loadFromDir(
                f'{trial_str}_bg-mask-depth-train', out_data_dir)
            bg_mask_seq_depth = loadFromDir(f'{trial_str}_bg-mask-seq-depth',
                                            out_data_dir)
        except FileNotFoundError:
            bg_model_depth, bg_mask_depth_train = fitBackgroundDepth(
                depth_train,
                camera_params=camera_params,
                camera_pose=camera_pose,
                **depth_bg_detection_kwargs)

            bg_mask_seq_depth = detectBackgroundDepth(
                bg_model_depth,
                depth_frame_seq,
                camera_params=camera_params,
                camera_pose=camera_pose,
                **depth_bg_detection_kwargs)

            __, bg_model_depth_image, __ = render.renderPlane(
                bg_model_depth,
                camera_pose=camera_pose,
                camera_params=camera_params)
            imageprocessing.displayImage(
                bg_model_depth_image,
                file_path=os.path.join(fig_dir,
                                       f'{trial_str}_bg-image-depth.png'))

        bg_model_rgb = fitBackgroundRgb(
            np.vstack((rgb_train, rgb_frame_seq)),
            np.vstack((bg_mask_depth_train, bg_mask_seq_depth)))

        if learn_bg_model:
            model = RgbBackgroundModel(bg_model_rgb,
                                       update_bg=True,
                                       device=device)
            # `fit_metrics` rather than `metrics`, so the `metrics` module
            # name is not shadowed.
            losses, fit_metrics = model.fit(
                np.vstack((rgb_train, rgb_frame_seq)),
                np.vstack((bg_mask_depth_train, bg_mask_seq_depth)),
                num_epochs=100)

            outputs = model.forward(
                torch.tensor(rgb_frame_seq, dtype=torch.float,
                             device=device).permute(0, -1, 1, 2))
            bg_mask_seq_rgb = model.predict(outputs).cpu().numpy().squeeze()
            # A dict, not a 1-tuple: no trailing comma here.
            plot_dict = {'Loss': losses, 'Accuracy': fit_metrics}
        else:

            def acc(preds, targets):
                matches = preds == targets
                return matches.mean()

            # Pick the distance threshold whose RGB mask best agrees with
            # the depth-based mask.
            bg_dists = np.linalg.norm(rgb_frame_seq - bg_model_rgb[None, ...],
                                      axis=-1)
            thresh_vals = np.linspace(0, 1, num=50)
            scores = np.array(
                [acc(bg_dists < t, bg_mask_seq_depth) for t in thresh_vals])
            best_index = scores.argmax()
            best_thresh = thresh_vals[best_index]
            bg_mask_seq_rgb = bg_dists < best_thresh
            plot_dict = {'Accuracy': scores}

        torchutils.plotEpochLog(plot_dict,
                                subfig_size=(10, 2.5),
                                title='Training performance',
                                fn=os.path.join(fig_dir,
                                                f'{trial_str}_train-plot.png'))

        logger.info("  Saving output...")
        saveToWorkingDir(bg_mask_depth_train.astype(bool),
                         f'{trial_str}_bg-mask-depth-train')
        saveToWorkingDir(bg_mask_seq_depth.astype(bool),
                         f'{trial_str}_bg-mask-seq-depth')
        saveToWorkingDir(bg_mask_seq_rgb.astype(bool),
                         f'{trial_str}_bg-mask-seq-rgb')

        if num_disp_imgs is not None:
            if rgb_frame_seq.shape[0] > num_disp_imgs:
                idxs = np.arange(rgb_frame_seq.shape[0])
                np.random.shuffle(idxs)
                idxs = idxs[:num_disp_imgs]
            else:
                idxs = slice(None, None, None)
            imageprocessing.displayImages(*(rgb_frame_seq[idxs]),
                                          *(bg_mask_seq_rgb[idxs]),
                                          *(depth_frame_seq[idxs]),
                                          *(bg_mask_seq_depth[idxs]),
                                          num_rows=4,
                                          file_path=os.path.join(
                                              fig_dir,
                                              f'{trial_str}_best-frames.png'))
            imageprocessing.displayImage(bg_model_rgb,
                                         file_path=os.path.join(
                                             fig_dir,
                                             f'{trial_str}_bg-image-rgb.png'))
Exemplo n.º 10
0
def main(out_dir=None,
         scores_dir=None,
         preprocessed_data_dir=None,
         keyframe_model_name=None,
         subsample_period=None,
         window_size=None,
         corpus_name=None,
         default_annotator=None,
         cv_scheme=None,
         max_trials_per_fold=None,
         model_name=None,
         numeric_backend=None,
         gpu_dev_id=None,
         visualize=False,
         model_config={},
         camera_params_config={}):
    """Select keyframes, cross-validate an assembly model, and decode test videos.

    For every trial listed in ``preprocessed_data_dir/trial_ids.pkl`` this loads
    the RGB, depth, segment, frame-score and background-plane sequences from
    ``scores_dir``, selects keyframes, and resamples the annotated assembly
    sequence at those keyframes. A model of class ``models.<model_name>`` is
    then trained on each cross-validation fold and used to decode the fold's
    test videos. Per-trial results are pickled to ``out_dir``.

    Args:
        out_dir: Directory that receives the pickled outputs (and figures when
            ``visualize`` is set).
        scores_dir: Directory holding the per-trial ``trial-<id>_*.pkl`` inputs.
        preprocessed_data_dir: Directory holding ``trial_ids.pkl``.
        keyframe_model_name, window_size, numeric_backend: Not read by this
            function; kept for config compatibility.
        subsample_period: Factor used to approximate original-video frame
            indices from subsampled frame indices.
        corpus_name: Corpus name passed to ``duplocorpus.DuploCorpus``.
        default_annotator: Annotator whose labels are read from the corpus.
        cv_scheme: Either ``'leave one out'`` or ``'train on child'``.
        max_trials_per_fold: If not None, decode at most this many test
            trials per fold.
        model_name: Name of the model class looked up in ``models``.
        gpu_dev_id: Passed to ``torchutils.selectDevice``.
        visualize: If True, render predictions and save comparison figures.
        model_config: Dict with ``'init_kwargs'``, ``'fit_kwargs'`` and
            ``'decode_kwargs'`` entries. It is not modified.
        camera_params_config: Keyword arguments for ``render.loadCameraParams``.

    Raises:
        ValueError: If ``cv_scheme`` is not one of the supported schemes.
    """

    out_dir = os.path.expanduser(out_dir)
    scores_dir = os.path.expanduser(scores_dir)
    preprocessed_data_dir = os.path.expanduser(preprocessed_data_dir)

    m.set_backend('numpy')

    def loadFromWorkingDir(var_name):
        return joblib.load(os.path.join(scores_dir, f"{var_name}.pkl"))

    def saveToWorkingDir(var, var_name):
        joblib.dump(var, os.path.join(out_dir, f"{var_name}.pkl"))

    # Load camera parameters from an external file and merge them into a
    # *copy* of the model's constructor kwargs. This leaves the caller's
    # config dict (or the shared mutable default) untouched.
    init_kwargs = dict(model_config['init_kwargs'])
    init_kwargs.update(
        render.loadCameraParams(**camera_params_config, as_dict=True))

    trial_ids = joblib.load(
        os.path.join(preprocessed_data_dir, 'trial_ids.pkl'))

    corpus = duplocorpus.DuploCorpus(corpus_name)
    assembly_seqs = tuple(
        labels.parseLabelSeq(
            corpus.readLabels(trial_id, default_annotator)[0])
        for trial_id in trial_ids)

    logger.info("Selecting keyframes...")
    keyframe_idx_seqs = []
    rgb_keyframe_seqs = []
    depth_keyframe_seqs = []
    seg_keyframe_seqs = []
    background_keyframe_seqs = []
    assembly_keyframe_seqs = []
    for seq_idx, trial_id in enumerate(trial_ids):
        trial_str = f"trial-{trial_id}"
        rgb_frame_seq = loadFromWorkingDir(f'{trial_str}_rgb-frame-seq')
        depth_frame_seq = loadFromWorkingDir(f'{trial_str}_depth-frame-seq')
        segment_seq = loadFromWorkingDir(f'{trial_str}_segment-seq')
        frame_scores = loadFromWorkingDir(f'{trial_str}_frame-scores')
        background_plane_seq = loadFromWorkingDir(
            f'{trial_str}_background-plane-seq')

        assembly_seq = assembly_seqs[seq_idx]
        # FIXME: Get the real frame index numbers instead of approximating
        assembly_seq[-1].end_idx = len(rgb_frame_seq) * subsample_period

        keyframe_idxs = videoprocessing.selectSegmentKeyframes(
            frame_scores, score_thresh=0, prepend_first=True)

        selectKeyframes = functools.partial(utils.select, keyframe_idxs)
        rgb_keyframe_seq = selectKeyframes(rgb_frame_seq)
        depth_keyframe_seq = selectKeyframes(depth_frame_seq)
        seg_keyframe_seq = selectKeyframes(segment_seq)
        background_keyframe_seq = selectKeyframes(background_plane_seq)

        # FIXME: Get the real frame index numbers instead of approximating
        keyframe_idxs_orig = keyframe_idxs * subsample_period
        assembly_keyframe_seq = labels.resampleStateSeq(
            keyframe_idxs_orig, assembly_seq)

        # Store all keyframe sequences in memory
        keyframe_idx_seqs.append(keyframe_idxs)
        rgb_keyframe_seqs.append(rgb_keyframe_seq)
        depth_keyframe_seqs.append(depth_keyframe_seq)
        seg_keyframe_seqs.append(seg_keyframe_seq)
        background_keyframe_seqs.append(background_keyframe_seq)
        assembly_keyframe_seqs.append(assembly_keyframe_seq)

    # Split into train and test sets
    if cv_scheme == 'leave one out':
        num_seqs = len(trial_ids)
        cv_folds = []
        for i in range(num_seqs):
            test_fold = (i, )
            train_fold = tuple(range(0, i)) + tuple(range(i + 1, num_seqs))
            cv_folds.append((train_fold, test_fold))
    elif cv_scheme == 'train on child':
        # Train on the child corpus and test on every sequence loaded above.
        # Child label sequences are appended after the loaded ones, so their
        # fold indices start at ``num_easy``.
        child_corpus = duplocorpus.DuploCorpus('child')
        child_trial_ids = utils.loadVariable('trial_ids',
                                             'preprocess-all-data', 'child')
        child_assembly_seqs = [
            labels.parseLabelSeq(
                child_corpus.readLabels(trial_id, 'Cathryn')[0])
            for trial_id in child_trial_ids
        ]
        num_easy = len(assembly_keyframe_seqs)
        num_child = len(child_assembly_seqs)
        cv_folds = [(tuple(range(num_easy, num_easy + num_child)),
                     tuple(range(num_easy)))]
        assembly_keyframe_seqs = assembly_keyframe_seqs + child_assembly_seqs
    else:
        # Fail explicitly here. Otherwise ``cv_folds`` would be unbound and
        # the run would die later with a NameError.
        raise ValueError(f"Unrecognized cv_scheme: {cv_scheme!r}")

    rgb_keyframe_seqs = tuple(
        tuple(
            imageprocessing.saturateImage(rgb_image,
                                          background_mask=segment_image == 0)
            for rgb_image, segment_image in zip(rgb_frame_seq, seg_frame_seq))
        for rgb_frame_seq, seg_frame_seq in zip(rgb_keyframe_seqs,
                                                seg_keyframe_seqs))

    depth_keyframe_seqs = tuple(
        tuple(depth_image.astype(float) for depth_image in depth_frame_seq)
        for depth_frame_seq in depth_keyframe_seqs)

    device = torchutils.selectDevice(gpu_dev_id)
    m.set_backend('torch')
    m.set_default_device(device)

    assembly_keyframe_seqs = tuple(
        tuple(a.to(device=device, in_place=False) for a in seq)
        for seq in assembly_keyframe_seqs)
    assembly_seqs = tuple(
        tuple(a.to(device=device, in_place=False) for a in seq)
        for seq in assembly_seqs)

    rgb_keyframe_seqs = tuple(
        tuple(m.np.array(frame, dtype=torch.float) for frame in rgb_frame_seq)
        for rgb_frame_seq in rgb_keyframe_seqs)
    depth_keyframe_seqs = tuple(
        tuple(
            m.np.array(frame, dtype=torch.float) for frame in depth_frame_seq)
        for depth_frame_seq in depth_keyframe_seqs)
    seg_keyframe_seqs = tuple(
        tuple(m.np.array(frame, dtype=torch.int) for frame in seg_frame_seq)
        for seg_frame_seq in seg_keyframe_seqs)

    num_cv_folds = len(cv_folds)
    saveToWorkingDir(cv_folds, 'cv-folds')
    for fold_index, (train_idxs, test_idxs) in enumerate(cv_folds):
        logger.info(f"CV FOLD {fold_index + 1} / {num_cv_folds}")

        # Initialize and train model
        utils.validateCvFold(train_idxs, test_idxs)
        selectTrain = functools.partial(utils.select, train_idxs)
        train_assembly_seqs = selectTrain(assembly_keyframe_seqs)
        model = getattr(models, model_name)(**init_kwargs)
        logger.info(
            f"  Training {model_name} on {len(train_idxs)} sequences...")
        model.fit(train_assembly_seqs, **model_config['fit_kwargs'])
        logger.info(
            f'    Model trained on {model.num_states} unique assembly states')
        # saveToWorkingDir(model, f'model-fold{fold_index}')

        # Decode on the test set
        selectTest = functools.partial(utils.select, test_idxs)
        test_trial_ids = selectTest(trial_ids)
        test_rgb_keyframe_seqs = selectTest(rgb_keyframe_seqs)
        test_depth_keyframe_seqs = selectTest(depth_keyframe_seqs)
        test_seg_keyframe_seqs = selectTest(seg_keyframe_seqs)
        test_background_keyframe_seqs = selectTest(background_keyframe_seqs)
        test_assembly_keyframe_seqs = selectTest(assembly_keyframe_seqs)
        test_assembly_seqs = selectTest(assembly_seqs)

        logger.info(f"  Testing model on {len(test_idxs)} sequences...")
        for i, trial_id in enumerate(test_trial_ids):
            if max_trials_per_fold is not None and i >= max_trials_per_fold:
                break

            rgb_frame_seq = test_rgb_keyframe_seqs[i]
            depth_frame_seq = test_depth_keyframe_seqs[i]
            seg_frame_seq = test_seg_keyframe_seqs[i]
            background_plane_seq = test_background_keyframe_seqs[i]
            true_assembly_seq = test_assembly_keyframe_seqs[i]
            true_assembly_seq_orig = test_assembly_seqs[i]

            rgb_background_seq, depth_background_seq = utils.batchProcess(
                model.renderPlane, background_plane_seq, unzip=True)

            logger.info(f'    Decoding video {trial_id}...')
            start_time = time.process_time()
            out = model.predictSeq(rgb_frame_seq, depth_frame_seq,
                                   seg_frame_seq, rgb_background_seq,
                                   depth_background_seq,
                                   **model_config['decode_kwargs'])
            pred_assembly_seq, pred_idx_seq, max_log_probs, log_likelihoods, poses_seq = out
            end_time = time.process_time()
            logger.info(utils.makeProcessTimeStr(end_time - start_time))

            num_correct, num_total = metrics.numberCorrect(
                true_assembly_seq, pred_assembly_seq)
            logger.info(f'    ACCURACY: {num_correct} / {num_total}')
            num_correct, num_total = metrics.numberCorrect(
                true_assembly_seq, pred_assembly_seq, ignore_empty_true=True)
            logger.info(f'    RECALL: {num_correct} / {num_total}')
            num_correct, num_total = metrics.numberCorrect(
                true_assembly_seq, pred_assembly_seq, ignore_empty_pred=True)
            logger.info(f'    PRECISION: {num_correct} / {num_total}')

            # Save intermediate results
            logger.info("Saving output...")
            # Reload this trial's own segment sequence. The name
            # ``segment_seq`` left over from the keyframe loop above always
            # refers to the *last* trial loaded, not to this one.
            trial_segment_seq = loadFromWorkingDir(
                f'trial-{trial_id}_segment-seq')
            saveToWorkingDir(trial_segment_seq, f'segment_seq-{trial_id}')
            saveToWorkingDir(true_assembly_seq_orig,
                             f'true_state_seq_orig-{trial_id}')
            saveToWorkingDir(true_assembly_seq, f'true_state_seq-{trial_id}')
            saveToWorkingDir(pred_assembly_seq, f'pred_state_seq-{trial_id}')
            saveToWorkingDir(poses_seq, f'poses_seq-{trial_id}')
            saveToWorkingDir(background_plane_seq,
                             f'background_plane_seq-{trial_id}')
            saveToWorkingDir(max_log_probs, f'max_log_probs-{trial_id}')
            saveToWorkingDir(log_likelihoods, f'log_likelihoods-{trial_id}')

            # Save figures
            if visualize:
                rgb_rendered_seq, depth_rendered_seq, label_rendered_seq = utils.batchProcess(
                    model.renderScene,
                    pred_assembly_seq,
                    poses_seq,
                    rgb_background_seq,
                    depth_background_seq,
                    unzip=True,
                    static_kwargs={'as_numpy': True})
                if utils.in_ipython_console():
                    # Display inline instead of writing a file.
                    file_path = None
                else:
                    trial_str = f"trial-{trial_id}"
                    file_path = os.path.join(out_dir,
                                             f'{trial_str}_best-frames.png')
                rgb_frame_seq = tuple(img.cpu().numpy()
                                      for img in rgb_frame_seq)
                imageprocessing.displayImages(*rgb_frame_seq,
                                              *rgb_rendered_seq,
                                              num_rows=2,
                                              file_path=file_path)
Exemplo n.º 11
0
def main(
        out_dir=None, modalities=('rgb', 'imu'), gpu_dev_id=None, plot_io=None,
        rgb_data_dir=None, rgb_attributes_dir=None, imu_data_dir=None, imu_attributes_dir=None):
    """Build a ``FusionDataset`` and export per-trial attribute features and labels.

    Trial ids are read from ``rgb_data_dir``. When IMU is requested, the
    ids are intersected with those in ``imu_data_dir``. For each trial the
    dataset's inputs and targets are saved as
    ``trial=<id>_feature-seq`` / ``trial=<id>_label-seq`` in ``out_dir/data``.

    Args:
        out_dir: Output root. Created if missing; gets ``log.txt``,
            ``figures/`` and ``data/``.
        modalities: Sequence of modality names. ``['rgb']`` selects RGB-only
            mode, in which the IMU directories may be left as None.
        gpu_dev_id: Passed to ``torchutils.selectDevice``.
        plot_io: If truthy, save a plot of features vs. labels per trial.
        rgb_data_dir, rgb_attributes_dir, imu_data_dir, imu_attributes_dir:
            Input directories (``~`` is expanded); passed to ``FusionDataset``.
    """
    def expandPath(path):
        # Optional directories (e.g. the IMU ones in RGB-only mode) may be None.
        return None if path is None else os.path.expanduser(path)

    # Normalize to a fresh list. The RGB-only check below compares against a
    # list, so a tuple such as ('rgb',) would otherwise silently take the
    # RGB+IMU branch.
    modalities = list(modalities)

    out_dir = os.path.expanduser(out_dir)
    rgb_data_dir = expandPath(rgb_data_dir)
    rgb_attributes_dir = expandPath(rgb_attributes_dir)
    imu_data_dir = expandPath(imu_data_dir)
    imu_attributes_dir = expandPath(imu_attributes_dir)

    # The log file lives inside out_dir, so the directory must exist first.
    os.makedirs(out_dir, exist_ok=True)
    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    # Load data
    if modalities == ['rgb']:
        trial_ids = utils.getUniqueIds(rgb_data_dir, prefix='trial=', to_array=True)
        logger.info(f"Processing {len(trial_ids)} videos")
    else:
        rgb_trial_ids = utils.getUniqueIds(rgb_data_dir, prefix='trial=', to_array=True)
        imu_trial_ids = utils.getUniqueIds(imu_data_dir, prefix='trial=', to_array=True)
        # Keep only trials present in both modalities, in sorted order.
        trial_ids = np.array(sorted(set(rgb_trial_ids.tolist()) & set(imu_trial_ids.tolist())))
        logger.info(
            f"Processing {len(trial_ids)} videos common to "
            f"RGB ({len(rgb_trial_ids)} total) and IMU ({len(imu_trial_ids)} total)"
        )

    device = torchutils.selectDevice(gpu_dev_id)
    dataset = FusionDataset(
        trial_ids, rgb_attributes_dir, rgb_data_dir, imu_attributes_dir, imu_data_dir,
        device=device, modalities=modalities,
    )
    utils.saveMetadata(dataset.metadata, out_data_dir)
    utils.saveVariable(dataset.vocab, 'vocab', out_data_dir)

    for trial_id in trial_ids:
        logger.info(f"Processing sequence {trial_id}...")

        true_label_seq = dataset.loadTargets(trial_id)
        attribute_feats = dataset.loadInputs(trial_id)

        # (Process the samples here if we need to)

        attribute_feats = attribute_feats.cpu().numpy()
        true_label_seq = true_label_seq.cpu().numpy()

        trial_prefix = f"trial={trial_id}"
        utils.saveVariable(attribute_feats, f'{trial_prefix}_feature-seq', out_data_dir)
        utils.saveVariable(true_label_seq, f'{trial_prefix}_label-seq', out_data_dir)

        if plot_io:
            fn = os.path.join(fig_dir, f'{trial_prefix}.png')
            utils.plot_array(
                attribute_feats.T,
                (true_label_seq,),
                ('gt',),
                fn=fn
            )
Exemplo n.º 12
0
def main(out_dir=None,
         data_dir=None,
         model_name=None,
         part_symmetries=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None):
    """Train and evaluate a sequence classifier under cross-validation.

    Loads ``part-vocab.yaml``, ``assembly-vocab.pkl`` and the per-trial
    ``trial=<id>_feature-seq.pkl`` / ``trial=<id>_label-seq.pkl`` files from
    ``data_dir``. If ``part_symmetries`` is truthy, assembly labels are mapped
    to equivalence classes. For each CV fold the chosen model is trained
    (skipped for ``'dummy'``) and evaluated on the test split. Metrics are
    appended to ``results_file``. Predictions, logs and the model are pickled
    to ``out_dir/data``.

    Args:
        out_dir: Output root. Created if missing.
        data_dir: Input directory.
        model_name: One of ``'linear'``, ``'dummy'``, ``'AssemblyClassifier'``.
        part_symmetries: Mapping used by ``makeEqClasses``. None selects a
            built-in default; an empty mapping disables equivalence classes.
        gpu_dev_id: Passed to ``torchutils.selectDevice``.
        batch_size: DataLoader batch size.
        learning_rate: Adam learning rate.
        model_params: Extra constructor kwargs for the model.
        cv_params: Kwargs for ``utils.makeDataSplits``.
        train_params: Kwargs for ``torchutils.trainModel``. Must contain
            ``'seq_as_batch'``, which is also used at test time.
        viz_params: Extra kwargs for ``utils.plot_array``.
        plot_predictions: If truthy, plot model inputs vs. labels per sequence.
        results_file: CSV for metrics; defaults to ``out_dir/results.csv``.
        sweep_param_name: Passed through to ``utils.writeResults``.

    Raises:
        AssertionError: If ``model_name`` is not recognized.
    """

    if part_symmetries is None:
        part_symmetries = {
            'beam_side': ('backbeam_hole_1', 'backbeam_hole_2',
                          'frontbeam_hole_1', 'frontbeam_hole_2'),
            'beam_top': ('backbeam_hole_3', 'frontbeam_hole_3'),
            'backrest': ('backrest_hole_1', 'backrest_hole_2')
        }

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)

        return tuple(map(loadOne, seq_ids))

    # Load vocab
    with open(os.path.join(data_dir, "part-vocab.yaml"), 'rt') as f:
        link_vocab = yaml.safe_load(f)
    assembly_vocab = joblib.load(os.path.join(data_dir, 'assembly-vocab.pkl'))

    # Load data
    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=')
    feature_seqs = loadAll(trial_ids, 'feature-seq.pkl', data_dir)
    label_seqs = loadAll(trial_ids, 'label-seq.pkl', data_dir)

    if part_symmetries:
        # Construct equivalence classes from vocab
        eq_classes, assembly_eq_classes, eq_class_vocab = makeEqClasses(
            assembly_vocab, part_symmetries)
        lib_assembly.writeAssemblies(
            os.path.join(fig_dir, 'eq-class-vocab.txt'), eq_class_vocab)
        # Index-map each label to its equivalence class.
        label_seqs = tuple(assembly_eq_classes[label_seq]
                           for label_seq in label_seqs)
        saveVariable(eq_class_vocab, 'assembly-vocab')
    else:
        eq_classes = None

    def impute_nan(input_seq):
        input_is_nan = np.isnan(input_seq)
        logger.info(f"{input_is_nan.sum()} NaN elements")
        input_seq[input_is_nan] = 0  # np.nanmean(input_seq)
        return input_seq

    # feature_seqs = tuple(map(impute_nan, feature_seqs))

    for trial_id, label_seq, feat_seq in zip(trial_ids, label_seqs,
                                             feature_seqs):
        saveVariable(feat_seq, f"trial={trial_id}_feature-seq")
        saveVariable(label_seq, f"trial={trial_id}_label-seq")

    device = torchutils.selectDevice(gpu_dev_id)

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, label_seqs, trial_ids))
        return split_data

    for cv_index, cv_splits in enumerate(cv_folds):
        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        train_feats, train_labels, train_ids = train_data
        train_set = torchutils.SequenceDataset(train_feats,
                                               train_labels,
                                               device=device,
                                               labels_dtype=torch.long,
                                               seq_ids=train_ids,
                                               transpose_data=True)
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=batch_size,
                                                   shuffle=True)

        test_feats, test_labels, test_ids = test_data
        test_set = torchutils.SequenceDataset(test_feats,
                                              test_labels,
                                              device=device,
                                              labels_dtype=torch.long,
                                              seq_ids=test_ids,
                                              transpose_data=True)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=batch_size,
                                                  shuffle=False)

        val_feats, val_labels, val_ids = val_data
        val_set = torchutils.SequenceDataset(val_feats,
                                             val_labels,
                                             device=device,
                                             labels_dtype=torch.long,
                                             seq_ids=val_ids,
                                             transpose_data=True)
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=batch_size,
                                                 shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)'
        )

        input_dim = train_set.num_obsv_dims
        output_dim = train_set.num_label_types
        if model_name == 'linear':
            model = torchutils.LinearClassifier(
                input_dim, output_dim, **model_params).to(device=device)
        elif model_name == 'dummy':
            model = DummyClassifier(input_dim, output_dim, **model_params)
        elif model_name == 'AssemblyClassifier':
            model = AssemblyClassifier(assembly_vocab,
                                       link_vocab,
                                       eq_classes=eq_classes,
                                       **model_params)
        else:
            raise AssertionError(f"Unrecognized model_name: {model_name!r}")

        criterion = torch.nn.CrossEntropyLoss()
        # Created for every model type so the saving/plotting code below also
        # works when training is skipped (model_name == 'dummy').
        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        if model_name != 'dummy':
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy(),
                'Precision': metrics.Precision(),
                'Recall': metrics.Recall(),
                'F1': metrics.Fmeasure()
            }

            optimizer_ft = torch.optim.Adam(model.parameters(),
                                            lr=learning_rate,
                                            betas=(0.9, 0.999),
                                            eps=1e-08,
                                            weight_decay=0,
                                            amsgrad=False)
            # gamma=1.0 keeps the learning rate constant.
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                           step_size=1,
                                                           gamma=1.00)

            model, last_model_wts = torchutils.trainModel(
                model,
                criterion,
                optimizer_ft,
                lr_scheduler,
                train_loader,
                val_loader,
                device=device,
                metrics=metric_dict,
                train_epoch_log=train_epoch_log,
                val_epoch_log=val_epoch_log,
                **train_params)

        # Not every model class exposes these learned parameters; log them
        # only when present instead of raising AttributeError.
        if hasattr(model, '_scale'):
            logger.info(f'scale={float(model._scale)}')
        if hasattr(model, '_alpha'):
            logger.info(f'alpha={float(model._alpha)}')

        # Test model
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }
        test_io_history = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)

        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)

        d = {k: v.value for k, v in metric_dict.items()}
        utils.writeResults(results_file, d, sweep_param_name, model_params)

        if plot_predictions:
            io_fig_dir = os.path.join(fig_dir, 'model-io')
            if not os.path.exists(io_fig_dir):
                os.makedirs(io_fig_dir)

            label_names = ('gt', 'pred')
            for batch in test_io_history:
                batch = tuple(
                    x.cpu().numpy() if isinstance(x, torch.Tensor) else x
                    for x in batch)
                for preds, _, inputs, gt_labels, seq_id in zip(*batch):
                    fn = os.path.join(io_fig_dir,
                                      f"trial={seq_id}_model-io.png")
                    utils.plot_array(inputs.sum(axis=-1), (gt_labels, preds),
                                     label_names,
                                     fn=fn,
                                     **viz_params)

        def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
            # feat_seq is accepted to match the io-history tuple layout but
            # is not saved.
            saveVariable(pred_seq, f'trial={trial_id}_pred-label-seq')
            saveVariable(score_seq, f'trial={trial_id}_score-seq')
            saveVariable(label_seq, f'trial={trial_id}_true-label-seq')

        for batch in test_io_history:
            batch = tuple(x.cpu().numpy() if isinstance(x, torch.Tensor) else x
                          for x in batch)
            for io in zip(*batch):
                saveTrialData(*io)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(train_epoch_log,
                     f'cvfold={cv_index}_{model_name}-train-epoch-log')
        saveVariable(val_epoch_log,
                     f'cvfold={cv_index}_{model_name}-val-epoch-log')
        saveVariable(metric_dict,
                     f'cvfold={cv_index}_{model_name}-metric-dict')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        train_fig_dir = os.path.join(fig_dir, 'train-plots')
        if not os.path.exists(train_fig_dir):
            os.makedirs(train_fig_dir)

        if train_epoch_log:
            torchutils.plotEpochLog(train_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Training performance',
                                    fn=os.path.join(
                                        train_fig_dir,
                                        f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(
                                        train_fig_dir,
                                        f'cvfold={cv_index}_val-plot.png'))
Exemplo n.º 13
0
def main(out_dir=None,
         data_dir=None,
         segs_dir=None,
         pretrained_model_dir=None,
         model_name=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         start_from=None,
         stop_at=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         num_disp_imgs=None,
         viz_templates=None,
         results_file=None,
         sweep_param_name=None):
    """Train and evaluate a block-connection classifier under cross-validation.

    Vocabularies are loaded from ``pretrained_model_dir`` and re-saved to
    ``out_dir/data``. Videos are read through a ``VideoLoader`` over
    ``data_dir``/``segs_dir``. For each CV fold in ``[start_from, stop_at]``
    the chosen model is trained and tested. Metrics are appended to
    ``results_file``, and per-batch predictions and the model are pickled.

    Args:
        out_dir: Output root. Created if missing.
        data_dir, segs_dir, pretrained_model_dir: Input directories (``~``
            expanded).
        model_name: ``'template'`` or ``'pretrained'``.
        gpu_dev_id: Passed to ``torchutils.selectDevice``.
        batch_size: Batch size given to the dataset. The DataLoader itself
            uses batch_size=1.
        learning_rate: Adam learning rate.
        start_from, stop_at: Optional inclusive range of fold indices to run.
        model_params: Extra constructor kwargs for the model.
        cv_params: Kwargs for ``utils.makeDataSplits``.
        train_params: Kwargs for ``torchutils.trainModel``. Must contain
            ``'seq_as_batch'``, which is also used at test time.
        viz_params: Not read by this function.
        num_disp_imgs: If not None, save model input/output figures.
        viz_templates: If truthy, call ``sim2real.viz_model_params``.
        results_file: CSV for metrics; defaults to ``out_dir/results.csv``.
        sweep_param_name: Passed through to ``utils.writeResults``.

    Raises:
        AssertionError: If ``model_name`` is not recognized.
    """

    data_dir = os.path.expanduser(data_dir)
    segs_dir = os.path.expanduser(segs_dir)
    pretrained_model_dir = os.path.expanduser(pretrained_model_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    io_dir = os.path.join(fig_dir, 'model-io')
    if not os.path.exists(io_dir):
        os.makedirs(io_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name, to_dir=out_data_dir):
        utils.saveVariable(var, var_name, to_dir)

    # Load data
    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)
    vocab = utils.loadVariable('vocab', pretrained_model_dir)
    parts_vocab = utils.loadVariable('parts-vocab', pretrained_model_dir)
    edge_labels = utils.loadVariable('part-labels', pretrained_model_dir)
    saveVariable(vocab, 'vocab')
    saveVariable(parts_vocab, 'parts-vocab')
    saveVariable(edge_labels, 'part-labels')

    # Define cross-validation folds
    data_loader = VideoLoader(trial_ids,
                              data_dir,
                              segs_dir,
                              vocab=vocab,
                              label_fn_format='assembly-seq')
    cv_folds = utils.makeDataSplits(len(data_loader.trial_ids), **cv_params)

    Dataset = sim2real.BlocksConnectionDataset
    device = torchutils.selectDevice(gpu_dev_id)
    label_dtype = torch.long
    labels_dtype = torch.long  # FIXME
    criterion = torch.nn.CrossEntropyLoss()
    # Metrics tracked during training and testing. Defined once, up front, so
    # every model type has them (previously only 'pretrained' did).
    metric_names = ('Loss', 'Accuracy', 'Precision', 'Recall', 'F1')

    def make_dataset(labels, ids, batch_mode='sample', shuffle=True):
        dataset = Dataset(vocab,
                          edge_labels,
                          label_dtype,
                          data_loader.loadData,
                          labels,
                          device=device,
                          labels_dtype=labels_dtype,
                          seq_ids=ids,
                          batch_size=batch_size,
                          batch_mode=batch_mode)
        # The dataset yields whole batches itself, hence batch_size=1 here.
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=shuffle)
        return dataset, loader

    for cv_index, cv_fold in enumerate(cv_folds):
        if start_from is not None and cv_index < start_from:
            continue

        if stop_at is not None and cv_index > stop_at:
            break

        train_data, val_data, test_data = data_loader.getFold(cv_fold)
        train_set, train_loader = make_dataset(*train_data,
                                               batch_mode='sample',
                                               shuffle=True)
        test_set, test_loader = make_dataset(*test_data,
                                             batch_mode='flatten',
                                             shuffle=False)
        val_set, val_loader = make_dataset(*val_data,
                                           batch_mode='sample',
                                           shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(data_loader.trial_ids)} total '
            f'({len(train_set)} train, {len(val_set)} val, {len(test_set)} test)'
        )

        logger.info(
            f"Class freqs (train): {np.squeeze(train_set.class_freqs)}")
        logger.info(f"Class freqs   (val): {np.squeeze(val_set.class_freqs)}")
        logger.info(f"Class freqs  (test): {np.squeeze(test_set.class_freqs)}")

        if model_name == 'template':
            model = sim2real.AssemblyClassifier(vocab, **model_params)
        elif model_name == 'pretrained':
            # Reloaded for every fold; presumably so each fold starts from
            # the same pretrained weights -- confirm whether training updates
            # the wrapped model in place.
            pretrained_model = utils.loadVariable("cvfold=0_model-best",
                                                  pretrained_model_dir)
            model = sim2real.SceneClassifier(pretrained_model, **model_params)
            # criterion = torchutils.BootstrappedCriterion(
            #     0.25, base_criterion=torch.nn.functional.cross_entropy,
            # )
        else:
            raise AssertionError(f"Unrecognized model_name: {model_name!r}")

        optimizer_ft = torch.optim.Adam(model.parameters(),
                                        lr=learning_rate,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0,
                                        amsgrad=False)
        # gamma=1.0 keeps the learning rate constant.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=1,
                                                       gamma=1.00)

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = {name: metrics.makeMetric(name) for name in metric_names}
        model, last_model_wts = torchutils.trainModel(
            model,
            criterion,
            optimizer_ft,
            lr_scheduler,
            train_loader,
            val_loader,
            device=device,
            metrics=metric_dict,
            train_epoch_log=train_epoch_log,
            val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = {name: metrics.makeMetric(name) for name in metric_names}
        test_io_history = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)
        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)

        utils.writeResults(results_file,
                           {name: m.value
                            for name, m in metric_dict.items()},
                           sweep_param_name, model_params)

        for pred_seq, score_seq, feat_seq, label_seq, batch_id in test_io_history:
            prefix = f'cvfold={cv_index}_batch={batch_id}'
            saveVariable(pred_seq.cpu().numpy(), f'{prefix}_pred-label-seq')
            saveVariable(score_seq.cpu().numpy(), f'{prefix}_score-seq')
            saveVariable(label_seq.cpu().numpy(), f'{prefix}_true-label-seq')
        saveVariable(test_set.unflatten,
                     f'cvfold={cv_index}_test-set-unflatten')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        if train_epoch_log:
            torchutils.plotEpochLog(train_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Training performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_val-plot.png'))

        if model_name == 'pretrained' and num_disp_imgs is not None:
            cvfold_dir = os.path.join(io_dir, f'cvfold={cv_index}')
            if not os.path.exists(cvfold_dir):
                os.makedirs(cvfold_dir)
            model.plotBatches(test_io_history,
                              cvfold_dir,
                              images_per_fig=num_disp_imgs,
                              dataset=test_set)

        if model_name == 'template' and num_disp_imgs is not None:
            # io_dir was created above, before the fold loop.
            plot_topk(model, test_io_history, num_disp_imgs,
                      os.path.join(io_dir, f"cvfold={cv_index}.png"))

        if viz_templates:
            sim2real.viz_model_params(model, templates_dir=None)
Exemplo n.º 14
0
def main(out_dir=None,
         data_dir=None,
         gpu_dev_id=None,
         batch_size=None,
         start_from=None,
         stop_at=None,
         num_disp_imgs=None):
    """Detect people in each trial's RGB frames and save per-frame person masks.

    For every ``trial=<id>`` found in ``data_dir``, this loads
    ``trial=<id>_rgb-frame-seq.pkl``. It runs torchvision's pretrained Mask
    R-CNN on the frames, in batches, through ``detectCategories``. The
    resulting boolean person masks are saved to
    ``<out_dir>/data/trial=<id>_person-mask-seq.pkl``.

    Args:
        out_dir: Output directory. ``log.txt``, ``figures/`` and ``data/`` are
            written inside it.
        data_dir: Directory holding the ``trial=<id>_rgb-frame-seq.pkl`` files.
        gpu_dev_id: Forwarded to ``torchutils.selectDevice``.
        batch_size: Frames per detection batch. ``None`` processes each video
            as a single batch. Each video gets its own whole-video batch; the
            size is not carried over from the first video.
        start_from: Skip trials whose 0-based position is below this value.
        stop_at: Stop after the trial at this 0-based position (inclusive).
        num_disp_imgs: If not ``None``, save a figure showing up to this many
            randomly chosen frames, together with their person masks and the
            second mask output of ``detectCategories`` (``bg_mask_seq``).

    Raises:
        ValueError: If ``batch_size`` is given but is not a positive integer.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    out_dir = os.path.expanduser(out_dir)
    data_dir = os.path.expanduser(data_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def loadFromDir(var_name, dir_name):
        return joblib.load(os.path.join(dir_name, f"{var_name}.pkl"))

    def saveToWorkingDir(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f"{var_name}.pkl"))

    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)

    device = torchutils.selectDevice(gpu_dev_id)

    # The detector does not depend on the trial, so build it once rather
    # than re-creating it (and re-moving it to the device) for every video.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=True)
    model = model.to(device=device)
    model.device = device
    model.eval()

    def detectInBatches(frames, frames_per_batch):
        """Run ``detectCategories`` over ``frames`` in chunks.

        Returns one array per output of ``detectCategories``. Each array is
        the per-batch results stacked along axis 0, after squeezing axis 1
        (assumed to be a singleton channel axis -- TODO confirm).
        """
        num_batches = math.ceil(frames.shape[0] / frames_per_batch)
        batch_outputs = []
        for batch_index in range(num_batches):
            start = frames_per_batch * batch_index
            in_batch = torch.tensor(frames[start:start + frames_per_batch],
                                    dtype=torch.float)
            out_batches = detectCategories(model, in_batch)
            batch_outputs.append(
                tuple(batch.numpy().squeeze(axis=1) for batch in out_batches))
        return tuple(map(np.vstack, zip(*batch_outputs)))

    for seq_idx, trial_id in enumerate(trial_ids):

        if start_from is not None and seq_idx < start_from:
            continue

        if stop_at is not None and seq_idx > stop_at:
            break

        trial_str = f"trial={trial_id}"

        logger.info(
            f"Processing video {seq_idx + 1} / {len(trial_ids)}  (trial {trial_id})"
        )

        logger.info("  Loading data...")
        try:
            raw_frames = loadFromDir(f"{trial_str}_rgb-frame-seq", data_dir)
        except FileNotFoundError as e:
            logger.info(e)
            continue

        # np.stack cannot stack zero arrays, and an empty video would also
        # produce a zero batch size below; skip such trials.
        if len(raw_frames) == 0:
            logger.info(f"  No frames found for {trial_str}; skipping")
            continue

        rgb_frame_seq = np.stack(
            tuple(skimage.img_as_float(f) for f in raw_frames), axis=0)

        logger.info("  Detecting objects...")
        # (frames, H, W, C) -> (frames, C, H, W), the channel-first layout
        # the detector expects.
        inputs = np.moveaxis(rgb_frame_seq, 3, 1)

        # Per-trial batch size: never overwrite the caller's `batch_size`,
        # otherwise later videos would reuse the first video's length.
        trial_batch_size = inputs.shape[0] if batch_size is None else batch_size

        person_mask_seq, bg_mask_seq = detectInBatches(inputs, trial_batch_size)
        person_mask_seq = person_mask_seq.astype(bool)

        logger.info("  Saving output...")
        saveToWorkingDir(person_mask_seq, f'{trial_str}_person-mask-seq')

        if num_disp_imgs is not None:
            if rgb_frame_seq.shape[0] > num_disp_imgs:
                idxs = np.arange(rgb_frame_seq.shape[0])
                np.random.shuffle(idxs)
                idxs = idxs[:num_disp_imgs]
            else:
                idxs = slice(None, None, None)
            imageprocessing.displayImages(*(rgb_frame_seq[idxs]),
                                          *(person_mask_seq[idxs]),
                                          *(bg_mask_seq[idxs]),
                                          num_rows=3,
                                          file_path=os.path.join(
                                              fig_dir,
                                              f'{trial_str}_best-frames.png'))