def main(
        out_dir=None, data_dir=None, use_vid_ids_from=None,
        output_data=None, magnitude_centering=None, resting_from_gt=None,
        remove_before_first_touch=None, include_signals=None, fig_type=None):

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    logger.info(f"Reading from: {data_dir}")
    logger.info(f"Writing to: {out_dir}")

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def loadAll(seq_ids, var_name, from_dir=data_dir, prefix='trial='):
        all_data = tuple(
            utils.loadVariable(f"{prefix}{seq_id}_{var_name}", from_dir)
            for seq_id in seq_ids
        )
        return all_data

    def saveVariable(var, var_name, to_dir=out_data_dir):
        utils.saveVariable(var, var_name, to_dir)

    if fig_type is None:
        fig_type = 'multi'

    # Load data
    if use_vid_ids_from is None:
        trial_ids = utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)
    else:
        use_vid_ids_from = os.path.expanduser(use_vid_ids_from)
        trial_ids = utils.getUniqueIds(use_vid_ids_from, prefix='trial-', to_array=True)

    accel_seqs = loadAll(trial_ids, 'accel-samples.pkl')
    gyro_seqs = loadAll(trial_ids, 'gyro-samples.pkl')
    action_seqs = loadAll(trial_ids, 'action-seq.pkl')
    rgb_timestamp_seqs = loadAll(trial_ids, 'rgb-frame-timestamp-seq.pkl')

    def validate_imu(seqs):
        def is_valid(d):
            return not any(np.isnan(x).any() for x in d.values())
        return np.array([is_valid(d) for d in seqs])

    imu_is_valid = validate_imu(accel_seqs) & validate_imu(gyro_seqs)
    logger.info(
        f"Ignoring {(~imu_is_valid).sum()} IMU sequences with NaN-valued samples "
        f"(of {len(imu_is_valid)} total)"
    )

    def chooseValid(seq):
        return tuple(x for x, is_valid in zip(seq, imu_is_valid) if is_valid)
    trial_ids = np.array(list(chooseValid(trial_ids)))
    accel_seqs = chooseValid(accel_seqs)
    gyro_seqs = chooseValid(gyro_seqs)
    action_seqs = chooseValid(action_seqs)
    rgb_timestamp_seqs = chooseValid(rgb_timestamp_seqs)

    vocab = []
    metadata = utils.loadMetadata(data_dir, rows=trial_ids)
    utils.saveMetadata(metadata, out_data_dir)
    utils.saveVariable(vocab, 'vocab', out_data_dir)

    def norm(x):
        # L2 norm of each IMU sample; keep a trailing axis for broadcasting
        return np.linalg.norm(imu.getImuSamples(x), axis=1)[:, None]
    accel_mag_seqs = tuple(map(lambda x: dictToArray(x, transform=norm), accel_seqs))
    gyro_mag_seqs = tuple(map(lambda x: dictToArray(x, transform=norm), gyro_seqs))

    imu_timestamp_seqs = utils.batchProcess(makeTimestamps, accel_seqs, gyro_seqs)

    if remove_before_first_touch:
        before_first_touch_seqs = utils.batchProcess(
            beforeFirstTouch, action_seqs, rgb_timestamp_seqs, imu_timestamp_seqs
        )

        num_ignored = sum(b is None for b in before_first_touch_seqs)
        logger.info(
            f"Ignoring {num_ignored} sequences without first-touch annotations "
            f"(of {len(before_first_touch_seqs)} total)"
        )
        trials_missing_first_touch = [
            i for b, i in zip(before_first_touch_seqs, trial_ids)
            if b is None
        ]
        logger.info(f"Trials without first touch: {trials_missing_first_touch}")

        def clip(signal, bool_array):
            return signal[~bool_array, ...]
        accel_mag_seqs = tuple(
            clip(signal, b) for signal, b in zip(accel_mag_seqs, before_first_touch_seqs)
            if b is not None
        )
        gyro_mag_seqs = tuple(
            clip(signal, b) for signal, b in zip(gyro_mag_seqs, before_first_touch_seqs)
            if b is not None
        )
        imu_timestamp_seqs = tuple(
            clip(signal, b) for signal, b in zip(imu_timestamp_seqs, before_first_touch_seqs)
            if b is not None
        )
        trial_ids = tuple(
            x for x, b in zip(trial_ids, before_first_touch_seqs)
            if b is not None
        )
        action_seqs = tuple(
            x for x, b in zip(action_seqs, before_first_touch_seqs)
            if b is not None
        )
        rgb_timestamp_seqs = tuple(
            x for x, b in zip(rgb_timestamp_seqs, before_first_touch_seqs)
            if b is not None
        )

    assembly_seqs = utils.batchProcess(
        parseActions,
        action_seqs, rgb_timestamp_seqs, imu_timestamp_seqs
    )

    if output_data == 'components':
        accel_feat_seqs = accel_mag_seqs
        gyro_feat_seqs = gyro_mag_seqs
        unique_components = {frozenset(): 0}
        imu_label_seqs = tuple(zip(
            *tuple(
                labels.componentLabels(*args, unique_components)
                for args in zip(action_seqs, rgb_timestamp_seqs, imu_timestamp_seqs)
            )
        ))  # materialize: imu_label_seqs is iterated more than once below
        saveVariable(unique_components, 'unique_components')
    elif output_data == 'pairwise components':
        imu_label_seqs = utils.batchProcess(
            labels.pairwiseComponentLabels, assembly_seqs,
            static_kwargs={'lower_tri_only': True, 'include_action_labels': False}
        )
        accel_feat_seqs = tuple(map(imu.pairwiseFeats, accel_mag_seqs))
        gyro_feat_seqs = tuple(map(imu.pairwiseFeats, gyro_mag_seqs))
    else:
        raise AssertionError(f"Unrecognized output_data: {output_data}")

    signals = {'accel': accel_feat_seqs, 'gyro': gyro_feat_seqs}
    if include_signals is None:
        include_signals = tuple(signals.keys())
    signals = tuple(signals[key] for key in include_signals)
    imu_feature_seqs = tuple(np.stack(x, axis=-1).squeeze(axis=-1) for x in zip(*signals))

    video_seqs = tuple(zip(imu_feature_seqs, imu_label_seqs, trial_ids))
    imu.plot_prediction_eg(video_seqs, fig_dir, fig_type=fig_type, output_data=output_data)

    video_seqs = tuple(
        zip(assembly_seqs, imu_feature_seqs, imu_timestamp_seqs, imu_label_seqs, trial_ids)
    )
    for assembly_seq, feature_seq, timestamp_seq, label_seq, trial_id in video_seqs:
        id_string = f"trial={trial_id}"
        saveVariable(assembly_seq, f'{id_string}_assembly-seq')
        saveVariable(feature_seq, f'{id_string}_feature-seq')
        saveVariable(timestamp_seq, f'{id_string}_timestamp-seq')
        saveVariable(label_seq, f'{id_string}_label-seq')
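
A minimal way to drive the function above, for orientation; the paths and option values here are hypothetical placeholders, not defaults taken from the project (in the original codebase these would typically come from a config file):

# Hypothetical invocation of the preprocessing main() above.
if __name__ == '__main__':
    main(
        out_dir='~/experiments/make-imu-data',
        data_dir='~/data/imu-samples',
        output_data='pairwise components',  # or 'components'
        remove_before_first_touch=True,
        include_signals=('accel', 'gyro'),
        fig_type='multi',
    )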
Example #2
def main(out_dir=None,
         data_dir=None,
         attr_dir=None,
         model_name=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None):

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    attr_dir = os.path.expanduser(attr_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def loadData(seq_id):
        var_name = f"trial-{seq_id}_rgb-frame-seq"
        data = joblib.load(os.path.join(data_dir, f'{var_name}.pkl'))
        return data.swapaxes(1, 3)

    def loadLabels(seq_id):
        var_name = f"trial-{seq_id}_label-seq"
        return joblib.load(os.path.join(attr_dir, f'{var_name}.pkl'))

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    # Load data
    trial_ids = utils.getUniqueIds(data_dir)
    label_seqs = tuple(map(loadLabels, trial_ids))

    device = torchutils.selectDevice(gpu_dev_id)

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs) for s in (label_seqs, trial_ids))
        return split_data

    for cv_index, cv_splits in enumerate(cv_folds):
        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        criterion = torch.nn.BCEWithLogitsLoss()
        labels_dtype = torch.float

        train_labels, train_ids = train_data
        train_set = torchutils.PickledVideoDataset(loadData,
                                                   train_labels,
                                                   device=device,
                                                   labels_dtype=labels_dtype,
                                                   seq_ids=train_ids,
                                                   batch_size=batch_size)
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=1,
                                                   shuffle=True)

        test_labels, test_ids = test_data
        test_set = torchutils.PickledVideoDataset(loadData,
                                                  test_labels,
                                                  device=device,
                                                  labels_dtype=labels_dtype,
                                                  seq_ids=test_ids,
                                                  batch_size=batch_size)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=1,
                                                  shuffle=False)

        val_labels, val_ids = val_data
        val_set = torchutils.PickledVideoDataset(loadData,
                                                 val_labels,
                                                 device=device,
                                                 labels_dtype=labels_dtype,
                                                 seq_ids=val_ids,
                                                 batch_size=batch_size)
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)'
        )

        if model_name == 'resnet':
            # input_dim = train_set.num_obsv_dims
            output_dim = train_set.num_label_types
            model = ImageClassifier(output_dim,
                                    **model_params).to(device=device)
        else:
            raise AssertionError(f"Unrecognized model name: {model_name}")

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }

        optimizer_ft = torch.optim.Adam(model.parameters(),
                                        lr=learning_rate,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0,
                                        amsgrad=False)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=1,
                                                       gamma=1.00)

        model, last_model_wts = torchutils.trainModel(
            model,
            criterion,
            optimizer_ft,
            lr_scheduler,
            train_loader,
            val_loader,
            device=device,
            metrics=metric_dict,
            train_epoch_log=train_epoch_log,
            val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }
        test_io_history = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)
        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)

        d = {k: v.value for k, v in metric_dict.items()}
        utils.writeResults(results_file, d, sweep_param_name, model_params)

        if plot_predictions:
            # imu.plot_prediction_eg(test_io_history, fig_dir, fig_type=fig_type, **viz_params)
            imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

        def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
            saveVariable(pred_seq.cpu().numpy(),
                         f'trial={trial_id}_pred-label-seq')
            saveVariable(score_seq.cpu().numpy(),
                         f'trial={trial_id}_score-seq')
            saveVariable(label_seq.cpu().numpy(),
                         f'trial={trial_id}_true-label-seq')

        for io in test_io_history:
            saveTrialData(*io)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(train_epoch_log,
                     f'cvfold={cv_index}_{model_name}-train-epoch-log')
        saveVariable(val_epoch_log,
                     f'cvfold={cv_index}_{model_name}-val-epoch-log')
        saveVariable(metric_dict,
                     f'cvfold={cv_index}_{model_name}-metric-dict')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        model.load_state_dict(last_model_wts)
        saveVariable(model, f'cvfold={cv_index}_{model_name}-last')

        torchutils.plotEpochLog(train_epoch_log,
                                subfig_size=(10, 2.5),
                                title='Training performance',
                                fn=os.path.join(
                                    fig_dir,
                                    f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_val-plot.png'))
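
`ImageClassifier` is referenced above but not defined in this listing. Below is a minimal sketch of what such a wrapper could look like, assuming a torchvision ResNet-18 backbone; the class name and constructor signature match the call site, but the body is an illustration, not the project's implementation:

import torch
import torchvision

class ImageClassifier(torch.nn.Module):
    # Hypothetical stand-in: a ResNet-18 backbone with its final
    # fully-connected layer resized to the label vocabulary.
    def __init__(self, out_dim, pretrained=True):
        super().__init__()
        self.backbone = torchvision.models.resnet18(pretrained=pretrained)
        self.backbone.fc = torch.nn.Linear(self.backbone.fc.in_features, out_dim)

    def forward(self, x):
        return self.backbone(x)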
Example #3
def main(out_dir=None,
         data_dir=None,
         model_name=None,
         pretrained_model_dir=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         independent_signals=None,
         active_only=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         plot_predictions=None,
         results_file=None,
         sweep_param_name=None,
         label_mapping=None,
         eval_label_mapping=None):

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)

        return tuple(map(loadOne, seq_ids))

    # Load data
    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)
    feature_seqs = loadAll(trial_ids, 'feature-seq.pkl', data_dir)
    label_seqs = loadAll(trial_ids, 'label-seq.pkl', data_dir)

    device = torchutils.selectDevice(gpu_dev_id)

    if label_mapping is not None:

        def map_labels(labels):
            for i, j in label_mapping.items():
                labels[labels == i] = j
            return labels

        label_seqs = tuple(map(map_labels, label_seqs))

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, label_seqs, trial_ids))
        return split_data

    for cv_index, cv_splits in enumerate(cv_folds):
        if pretrained_model_dir is not None:

            def loadFromPretrain(fn):
                return joblib.load(
                    os.path.join(pretrained_model_dir, f"{fn}.pkl"))

            model = loadFromPretrain(f'cvfold={cv_index}_{model_name}-best')
            train_ids = loadFromPretrain(f'cvfold={cv_index}_train-ids')
            val_ids = loadFromPretrain(f'cvfold={cv_index}_val-ids')
            test_ids = tuple(i for i in trial_ids
                             if i not in (train_ids + val_ids))
            test_idxs = tuple(trial_ids.tolist().index(i) for i in test_ids)
            test_data = getSplit(test_idxs)

            if independent_signals:
                criterion = torch.nn.CrossEntropyLoss()
                labels_dtype = torch.long
                test_data = splitSeqs(*test_data, active_only=False)
            else:
                # FIXME
                # criterion = torch.nn.BCEWithLogitsLoss()
                # labels_dtype = torch.float
                criterion = torch.nn.CrossEntropyLoss()
                labels_dtype = torch.long

            test_feats, test_labels, test_ids = test_data
            test_set = torchutils.SequenceDataset(test_feats,
                                                  test_labels,
                                                  device=device,
                                                  labels_dtype=labels_dtype,
                                                  seq_ids=test_ids,
                                                  transpose_data=True)
            test_loader = torch.utils.data.DataLoader(test_set,
                                                      batch_size=batch_size,
                                                      shuffle=False)

            # Test model
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy(),
                'Precision': metrics.Precision(),
                'Recall': metrics.Recall(),
                'F1': metrics.Fmeasure()
            }
            test_io_history = torchutils.predictSamples(
                model.to(device=device),
                test_loader,
                criterion=criterion,
                device=device,
                metrics=metric_dict,
                data_labeled=True,
                update_model=False,
                seq_as_batch=train_params['seq_as_batch'],
                return_io_history=True)
            if independent_signals:
                test_io_history = tuple(joinSeqs(test_io_history))

            metric_str = '  '.join(str(m) for m in metric_dict.values())
            logger.info('[TST]  ' + metric_str)

            d = {k: v.value for k, v in metric_dict.items()}
            utils.writeResults(results_file, d, sweep_param_name, model_params)

            if plot_predictions:
                imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

            def saveTrialData(pred_seq, score_seq, feat_seq, label_seq,
                              trial_id):
                if label_mapping is not None:

                    def dup_score_cols(scores):
                        num_cols = scores.shape[-1] + len(label_mapping)
                        col_idxs = torch.arange(num_cols)
                        for i, j in label_mapping.items():
                            col_idxs[i] = j
                        return scores[..., col_idxs]

                    score_seq = dup_score_cols(score_seq)
                saveVariable(pred_seq.cpu().numpy(),
                             f'trial={trial_id}_pred-label-seq')
                saveVariable(score_seq.cpu().numpy(),
                             f'trial={trial_id}_score-seq')
                saveVariable(label_seq.cpu().numpy(),
                             f'trial={trial_id}_true-label-seq')

            for io in test_io_history:
                saveTrialData(*io)
            continue

        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        if independent_signals:
            criterion = torch.nn.CrossEntropyLoss()
            labels_dtype = torch.long
            split_ = functools.partial(splitSeqs, active_only=active_only)
            train_data = split_(*train_data)
            val_data = split_(*val_data)
            test_data = splitSeqs(*test_data, active_only=False)
        else:
            # FIXME
            # criterion = torch.nn.BCEWithLogitsLoss()
            # labels_dtype = torch.float
            criterion = torch.nn.CrossEntropyLoss()
            labels_dtype = torch.long

        train_feats, train_labels, train_ids = train_data
        train_set = torchutils.SequenceDataset(train_feats,
                                               train_labels,
                                               device=device,
                                               labels_dtype=labels_dtype,
                                               seq_ids=train_ids,
                                               transpose_data=True)
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=batch_size,
                                                   shuffle=True)

        test_feats, test_labels, test_ids = test_data
        test_set = torchutils.SequenceDataset(test_feats,
                                              test_labels,
                                              device=device,
                                              labels_dtype=labels_dtype,
                                              seq_ids=test_ids,
                                              transpose_data=True)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=batch_size,
                                                  shuffle=False)

        val_feats, val_labels, val_ids = val_data
        val_set = torchutils.SequenceDataset(val_feats,
                                             val_labels,
                                             device=device,
                                             labels_dtype=labels_dtype,
                                             seq_ids=val_ids,
                                             transpose_data=True)
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=batch_size,
                                                 shuffle=True)

        logger.info(
            f'CV fold {cv_index + 1} / {len(cv_folds)}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)'
        )

        input_dim = train_set.num_obsv_dims
        output_dim = train_set.num_label_types
        if model_name == 'linear':
            model = torchutils.LinearClassifier(
                input_dim, output_dim, **model_params).to(device=device)
        elif model_name == 'conv':
            model = ConvClassifier(input_dim, output_dim,
                                   **model_params).to(device=device)
        elif model_name == 'TCN':
            model = TcnClassifier(input_dim, output_dim,
                                  **model_params).to(device=device)
        else:
            raise AssertionError(f"Unrecognized model name: {model_name}")

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }

        optimizer_ft = torch.optim.Adam(model.parameters(),
                                        lr=learning_rate,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0,
                                        amsgrad=False)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=1,
                                                       gamma=1.00)

        model, last_model_wts = torchutils.trainModel(
            model,
            criterion,
            optimizer_ft,
            lr_scheduler,
            train_loader,
            val_loader,
            device=device,
            metrics=metric_dict,
            train_epoch_log=train_epoch_log,
            val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = {
            'Avg Loss': metrics.AverageLoss(),
            'Accuracy': metrics.Accuracy(),
            'Precision': metrics.Precision(),
            'Recall': metrics.Recall(),
            'F1': metrics.Fmeasure()
        }
        test_io_history = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)
        if independent_signals:
            test_io_history = tuple(joinSeqs(test_io_history))

        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)

        d = {k: v.value for k, v in metric_dict.items()}
        utils.writeResults(results_file, d, sweep_param_name, model_params)

        if plot_predictions:
            # imu.plot_prediction_eg(test_io_history, fig_dir, fig_type=fig_type, **viz_params)
            imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

        def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
            if label_mapping is not None:

                def dup_score_cols(scores):
                    num_cols = scores.shape[-1] + len(label_mapping)
                    col_idxs = torch.arange(num_cols)
                    for i, j in label_mapping.items():
                        col_idxs[i] = j
                    return scores[..., col_idxs]

                score_seq = dup_score_cols(score_seq)
            saveVariable(pred_seq.cpu().numpy(),
                         f'trial={trial_id}_pred-label-seq')
            saveVariable(score_seq.cpu().numpy(),
                         f'trial={trial_id}_score-seq')
            saveVariable(label_seq.cpu().numpy(),
                         f'trial={trial_id}_true-label-seq')

        for io in test_io_history:
            saveTrialData(*io)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(train_epoch_log,
                     f'cvfold={cv_index}_{model_name}-train-epoch-log')
        saveVariable(val_epoch_log,
                     f'cvfold={cv_index}_{model_name}-val-epoch-log')
        saveVariable(metric_dict,
                     f'cvfold={cv_index}_{model_name}-metric-dict')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        model.load_state_dict(last_model_wts)
        saveVariable(model, f'cvfold={cv_index}_{model_name}-last')

        torchutils.plotEpochLog(train_epoch_log,
                                subfig_size=(10, 2.5),
                                title='Training performance',
                                fn=os.path.join(
                                    fig_dir,
                                    f'cvfold={cv_index}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(
                                        fig_dir,
                                        f'cvfold={cv_index}_val-plot.png'))

        if eval_label_mapping is not None:
            metric_dict = {
                'Avg Loss': metrics.AverageLoss(),
                'Accuracy': metrics.Accuracy(),
                'Precision': metrics.Precision(),
                'Recall': metrics.Recall(),
                'F1': metrics.Fmeasure()
            }
            test_io_history = torchutils.predictSamples(
                model.to(device=device),
                test_loader,
                criterion=criterion,
                device=device,
                metrics=metric_dict,
                data_labeled=True,
                update_model=False,
                seq_as_batch=train_params['seq_as_batch'],
                return_io_history=True,
                label_mapping=eval_label_mapping)
            if independent_signals:
                test_io_history = joinSeqs(test_io_history)
            metric_str = '  '.join(str(m) for m in metric_dict.values())
            logger.info('[TST]  ' + metric_str)
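
`splitSeqs` and `joinSeqs` are used throughout this example but defined elsewhere in the project. Here is a rough sketch of the contract `splitSeqs` appears to satisfy when `independent_signals` is set, splitting each multi-channel sequence into one sequence per channel (`joinSeqs` would invert this, regrouping per-channel predictions by their parent trial); this illustrates the interface only and is not the project's code:

def splitSeqs(feature_seqs, label_seqs, seq_ids, active_only=False):
    # Hypothetical sketch: one output sequence per input channel.
    split_feats, split_labels, split_ids = [], [], []
    for feats, labels, seq_id in zip(feature_seqs, label_seqs, seq_ids):
        for channel in range(feats.shape[1]):
            labels_ch = labels[:, channel] if labels.ndim > 1 else labels
            if active_only and not labels_ch.any():
                continue  # drop channels that are never active
            split_feats.append(feats[:, channel:channel + 1])
            split_labels.append(labels_ch)
            split_ids.append((seq_id, channel))
    return tuple(split_feats), tuple(split_labels), tuple(split_ids)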
Example #4
def main(out_dir=None,
         data_dir=None,
         scores_dir=None,
         model_name=None,
         results_file=None,
         sweep_param_name=None,
         independent_signals=None,
         active_only=None,
         label_mapping=None,
         eval_label_mapping=None,
         pre_init_pw=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         plot_predictions=None):

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def loadVariable(var_name):
        return joblib.load(os.path.join(data_dir, f'{var_name}.pkl'))

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    # Load data
    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial={seq_id}_{var_name}')
            return joblib.load(fn)

        return tuple(map(loadOne, seq_ids))

    # Load data
    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=')
    feature_seqs = loadAll(trial_ids, 'feature-seq.pkl', data_dir)
    label_seqs = loadAll(trial_ids, 'label-seq.pkl', data_dir)

    if scores_dir is not None:
        scores_dir = os.path.expanduser(scores_dir)
        feature_seqs = tuple(
            joblib.load(
                os.path.join(scores_dir, f'trial={trial_id}_score-seq.pkl')
            ).swapaxes(0, 1)
            for trial_id in trial_ids
        )

    if label_mapping is not None:

        def map_labels(labels):
            for i, j in label_mapping.items():
                labels[labels == i] = j
            return labels

        label_seqs = tuple(map(map_labels, label_seqs))
        if scores_dir is not None:
            num_labels = feature_seqs[0].shape[-1]
            idxs = [i for i in range(num_labels) if i not in label_mapping]
            feature_seqs = tuple(x[..., idxs] for x in feature_seqs)

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    metric_dict = {'accuracy': [], 'edit_score': [], 'overlap_score': []}

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, label_seqs, trial_ids))
        return split_data

    for cv_index, cv_splits in enumerate(cv_folds):
        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        if independent_signals:
            split_ = functools.partial(splitSeqs, active_only=active_only)
            train_samples, train_labels, train_ids = split_(*train_data)
            val_samples, val_labels, val_ids = split_(*val_data)
            test_samples, test_labels, test_ids = splitSeqs(*test_data,
                                                            active_only=False)

            # Transpose input data so they have shape (num_features, num_samples),
            # to conform with LCTM interface
            train_samples = preprocess(train_samples)
            test_samples = preprocess(test_samples)
            val_samples = preprocess(val_samples)

        else:
            raise NotImplementedError()

        logger.info(
            f'CV fold {cv_index + 1}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)'
        )

        model = getattr(LCTM.models, model_name)(**model_params)

        if pre_init_pw:
            pretrain = train_params.get('pretrain', True)
            model = pre_init(model,
                             train_samples,
                             train_labels,
                             pretrain=pretrain)
        else:
            model.fit(train_samples, train_labels, **train_params)
            # FIXME: Is this even necessary?
            if model_params.get('inference', None) == 'segmental':
                model.max_segs = LCTM.utils.max_seg_count(train_labels)

        plot_weights(model,
                     fn=os.path.join(
                         fig_dir,
                         f"cvfold={cv_index}_model-weights-trained.png"))
        plot_train(model.logger.objectives,
                   fn=os.path.join(fig_dir,
                                   f"cvfold={cv_index}_train-loss.png"))

        # Test model
        pred_labels = model.predict(test_samples)
        # test_samples = tuple(map(lambda x: x.swapaxes(0, 1), test_samples))
        test_io_history = tuple(
            zip([pred_labels], [test_samples], [test_samples], [test_labels],
                [test_ids]))

        if independent_signals:
            test_io_history = tuple(joinSeqs(test_io_history))

        for name in metric_dict.keys():
            value = getattr(LCTM.metrics, name)(pred_labels, test_labels)
            metric_dict[name] += [value]
        metric_str = '  '.join(f"{k}: {v[-1]:.1f}%"
                               for k, v in metric_dict.items())
        logger.info('[TST]  ' + metric_str)

        all_labels = np.hstack(test_labels)
        label_hist = utils.makeHistogram(len(np.unique(all_labels)),
                                         all_labels,
                                         normalize=True)
        logger.info(f'Label distribution: {label_hist}')

        d = {k: v[-1] / 100 for k, v in metric_dict.items()}
        utils.writeResults(results_file, d, sweep_param_name, model_params)

        if plot_predictions:
            imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

        def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
            if False:  # label_mapping is not None:

                def dup_score_cols(scores):
                    num_cols = scores.shape[-1] + len(label_mapping)
                    col_idxs = np.arange(num_cols)
                    for i, j in label_mapping.items():
                        col_idxs[i] = j
                    return scores[..., col_idxs]

                score_seq = dup_score_cols(score_seq)
            saveVariable(pred_seq, f'trial={trial_id}_pred-label-seq')
            saveVariable(score_seq, f'trial={trial_id}_score-seq')
            saveVariable(label_seq, f'trial={trial_id}_true-label-seq')

        for io in test_io_history:
            saveTrialData(*io)

        saveVariable(train_ids, f'cvfold={cv_index}_train-ids')
        saveVariable(test_ids, f'cvfold={cv_index}_test-ids')
        saveVariable(val_ids, f'cvfold={cv_index}_val-ids')
        saveVariable(metric_dict,
                     f'cvfold={cv_index}_{model_name}-metric-dict')
        saveVariable(model, f'cvfold={cv_index}_{model_name}-best')

        if eval_label_mapping is not None:

            def map_labels(labels):
                labels = labels.copy()
                for i, j in eval_label_mapping.items():
                    labels[labels == i] = j
                return labels

            pred_labels = tuple(map(map_labels, pred_labels))
            test_labels = tuple(map(map_labels, test_labels))
            for name in metric_dict.keys():
                value = getattr(LCTM.metrics, name)(pred_labels, test_labels)
                metric_dict[name] += [value]
            metric_str = '  '.join(f"{k}: {v[-1]:.1f}%"
                                   for k, v in metric_dict.items())
            logger.info('[TST]  ' + metric_str)

            all_labels = np.hstack(test_labels)
            label_hist = utils.makeHistogram(len(np.unique(all_labels)),
                                             all_labels,
                                             normalize=True)
            logger.info(f'Label distribution: {label_hist}')
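
`preprocess` is not shown in this listing; the comment above its call site says the samples are transposed to shape (num_features, num_samples) for the LCTM interface. A one-function sketch under that assumption, not the project's actual implementation:

def preprocess(sample_seqs):
    # Hypothetical sketch: LCTM models expect (num_features, num_samples),
    # so transpose each (num_samples, num_features) array.
    return tuple(x.swapaxes(0, 1) for x in sample_seqs)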
Example #5
def main(out_dir=None,
         data_dir=None,
         scores_dir=None,
         model_name=None,
         model_params={},
         results_file=None,
         sweep_param_name=None,
         cv_params={},
         viz_params={},
         plot_predictions=None):

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def loadVariable(var_name):
        return joblib.load(os.path.join(data_dir, f'{var_name}.pkl'))

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    # Load data
    trial_ids = loadVariable('trial_ids')
    feature_seqs = loadVariable('imu_sample_seqs')
    label_seqs = loadVariable('imu_label_seqs')

    if scores_dir is not None:
        scores_dir = os.path.expanduser(scores_dir)
        feature_seqs = tuple(
            joblib.load(
                os.path.join(scores_dir, f'trial={trial_id}_score-seq.pkl')
            ).swapaxes(0, 1)
            for trial_id in trial_ids
        )

    # Define cross-validation folds
    dataset_size = len(trial_ids)
    cv_folds = utils.makeDataSplits(dataset_size, **cv_params)

    metric_dict = {'accuracy': [], 'edit_score': [], 'overlap_score': []}

    def getSplit(split_idxs):
        split_data = tuple(
            tuple(s[i] for i in split_idxs)
            for s in (feature_seqs, label_seqs, trial_ids))
        return split_data

    for cv_index, cv_splits in enumerate(cv_folds):
        train_data, val_data, test_data = tuple(map(getSplit, cv_splits))

        train_ids = train_data[-1]
        test_ids = test_data[-1]
        val_ids = val_data[-1]

        logger.info(
            f'CV fold {cv_index + 1}: {len(trial_ids)} total '
            f'({len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test)'
        )

        for name in metric_dict.keys():
            value = float('nan')  # FIXME: compute the real metric value
            metric_dict[name] += [value]
        metric_str = '  '.join(f"{k}: {v[-1]:.1f}%"
                               for k, v in metric_dict.items())
        logger.info('[TST]  ' + metric_str)

        d = {k: v[-1] for k, v in metric_dict.items()}
        utils.writeResults(results_file, d, sweep_param_name, model_params)

        test_io_history = None  # FIXME

        if plot_predictions and test_io_history is not None:
            imu.plot_prediction_eg(test_io_history, fig_dir, **viz_params)

        def saveTrialData(pred_seq, score_seq, feat_seq, label_seq, trial_id):
            saveVariable(pred_seq, f'trial={trial_id}_pred-label-seq')
            saveVariable(score_seq, f'trial={trial_id}_score-seq')
            saveVariable(label_seq, f'trial={trial_id}_true-label-seq')

        for io in test_io_history or ():
            saveTrialData(*io)
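
The FIXME placeholders above leave this evaluation loop as a stub. Following the pattern of Example #4, the missing pieces might look roughly like the sketch below (it would sit inside the fold loop); `LCTM.models` and the data splits are assumed to behave as they do there, and none of this is the project's actual code.

# Hypothetical completion of the FIXMEs above, mirroring Example #4:
# fit an LCTM model on the train split, predict on the test split,
# then score the predictions with the LCTM metrics.
model = getattr(LCTM.models, model_name)(**model_params)
train_feats, train_labels, _ = train_data
test_feats, test_labels, _ = test_data
model.fit(train_feats, train_labels)
pred_labels = model.predict(test_feats)
for name in metric_dict.keys():
    metric_dict[name] += [getattr(LCTM.metrics, name)(pred_labels, test_labels)]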