Example #1
0
def parseActions(assembly_actions, num_frames, vocab):
    def makeJoint(part1, part2):
        return tuple(sorted([part1, part2]))

    def updateAssembly(assembly, joint):
        return tuple(sorted(assembly + (joint, )))

    assembly_index_seq = np.zeros(num_frames, dtype=int)

    cur_assembly = tuple()
    prev_start = -1
    prev_end = -1
    for i, row in assembly_actions.iterrows():
        # Each new (start, end) pair marks a new action segment: frames up to
        # the end of that segment keep the assembly as it stood beforehand.
        if row.start != prev_start or row.end != prev_end:
            cur_assembly_index = utils.getIndex(cur_assembly, vocab)
            assembly_index_seq[prev_end:row.end + 1] = cur_assembly_index
            prev_start = row.start
            prev_end = row.end

        if row.action == 'connect':
            joint = makeJoint(row.part1, row.part2)
            cur_assembly = updateAssembly(cur_assembly, joint)
        elif row.action == 'pin':
            continue
        else:
            raise ValueError()
    cur_assembly_index = utils.getIndex(cur_assembly, vocab)
    assembly_index_seq[prev_end:] = cur_assembly_index
    return assembly_index_seq
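
A minimal usage sketch for parseActions (hypothetical data; the getIndex helper
below is a stand-in for the project's utils.getIndex, assumed to append unseen
items to the vocab and return their index):

import numpy as np
import pandas as pd

def getIndex(item, vocab):
    # Stand-in for utils.getIndex: register the item if unseen, return its index.
    if item not in vocab:
        vocab.append(item)
    return vocab.index(item)

assembly_actions = pd.DataFrame([
    {'start': 0, 'end': 9, 'action': 'connect', 'part1': 'seat', 'part2': 'leg1'},
    {'start': 10, 'end': 19, 'action': 'connect', 'part1': 'seat', 'part2': 'leg2'},
])
vocab = []
# index_seq = parseActions(assembly_actions, num_frames=30, vocab=vocab)
# Each frame of index_seq would hold the vocab index of the assembly built so far.
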
Example #2
0
    def gen_kinem_labels(actions):
        state = lib_asm.Assembly()
        action_segs = tuple(gen_segments(actions))
        for start, end in action_segs:
            segment = actions.loc[start:end]
            for row in segment.itertuples(index=False):
                # label = row.label
                # i_start = row.start
                # i_end = row.end
                arg1 = row.arg1
                arg2 = row.arg2

                parent = lib_asm.Link(arg1)
                child = lib_asm.Link(arg2)
                joint = lib_asm.Joint((arg1, arg2), 'rigid', arg1, arg2)
                state = state.add_joint(
                    joint,
                    parent,
                    child,
                    # directed=False,
                    in_place=False)
                # Add the reverse joint too, so the connection is effectively
                # undirected in the assembly graph.
                joint = lib_asm.Joint((arg2, arg1), 'rigid', arg2, arg1)
                state = state.add_joint(
                    joint,
                    child,
                    parent,
                    # directed=False,
                    in_place=False)
            start_idx = actions.loc[start]['start']
            end_idx = actions.loc[end]['end']
            state_idx = utils.getIndex(state, kinem_vocab)
            yield start_idx, end_idx, state_idx
Example #3
0
def actionsFromAssemblies(assembly_segs, action_vocab):
    action_segs = (
        (blockassembly.AssemblyAction(), ) +
        tuple(n - c for c, n in zip(assembly_segs[:-1], assembly_segs[1:])))
    action_index_segs = np.array(
        [utils.getIndex(a, action_vocab) for a in action_segs])
    return action_segs, action_index_segs
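
Here `n - c` relies on the assembly class defining subtraction so that the
difference of consecutive assemblies is the action performed between them; a
conceptual sketch, with assemblies reduced to plain sets of joints:

prev_assembly = frozenset({('leg1', 'seat')})
next_assembly = frozenset({('leg1', 'seat'), ('leg2', 'seat')})
action = next_assembly - prev_assembly
# frozenset({('leg2', 'seat')}): the joint added by this step
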
Example #4
0
    def loadAssemblies(self, seq_id, var_name, vocab, prefix='trial='):
        assembly_seq = utils.loadVariable(f"{prefix}{seq_id}_{var_name}",
                                          self.data_dir)
        labels = np.zeros(assembly_seq[-1].end_idx, dtype=int)
        for assembly in assembly_seq:
            i = utils.getIndex(assembly, vocab)
            labels[assembly.start_idx:assembly.end_idx] = i
        return labels
Example #5
0
def actionsFromAssemblies(assembly_index_seq, assembly_vocab, action_vocab):
    assembly_index_segs, seg_lens = utils.computeSegments(assembly_index_seq)
    assembly_segs = tuple(assembly_vocab[i] for i in assembly_index_segs)
    action_segs = (
        (lib_assembly.AssemblyAction(),)
        + tuple(n - c for c, n in zip(assembly_segs[:-1], assembly_segs[1:]))
    )
    action_index_segs = tuple(utils.getIndex(a, action_vocab) for a in action_segs)
    action_index_seq = np.array(utils.fromSegments(action_index_segs, seg_lens))

    return action_segs, action_index_seq
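
A sketch of the assumed semantics of the segment helpers (run-length encoding
and decoding), which is what lets the per-frame assembly sequence round-trip
through per-segment actions; these are stand-ins, not the real utils functions:

def computeSegments(seq):
    # Assumed behavior of utils.computeSegments: run-length encode a sequence.
    seg_labels, seg_lens = [], []
    for x in seq:
        if seg_labels and seg_labels[-1] == x:
            seg_lens[-1] += 1
        else:
            seg_labels.append(x)
            seg_lens.append(1)
    return seg_labels, seg_lens

def fromSegments(seg_labels, seg_lens):
    # Assumed behavior of utils.fromSegments: expand segments back to frames.
    return [x for x, n in zip(seg_labels, seg_lens) for _ in range(n)]

computeSegments([3, 3, 5, 5, 5])   # -> ([3, 5], [2, 3])
fromSegments([3, 5], [2, 3])       # -> [3, 3, 5, 5, 5]
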
Example #6
0
    def update_vocabs(self):
        # Keep a reference of link names and their global vocab indices, sorted
        # in increasing order
        # self.link_names = tuple(link.name for link_id, link in self.links.items())
        # self.link_indices = np.array([
        #     utils.getIndex(name, self.link_vocab)
        #     for name in self.link_names
        # ])
        # sort_indices = np.argsort(self.link_indices)
        # self.link_names = tuple(self.link_names[i] for i in sort_indices)
        # self.link_indices = self.link_indices[sort_indices]

        # Keep a reference of joint names and their global vocab indices, sorted
        # in increasing order
        # self.joint_names = tuple(joint.name for joint_id, joint in self.joints.items())
        # self.joint_indices = np.array([
        #     utils.getIndex(name, self.joint_vocab)
        #     for name in self.joint_names
        # ])
        # sort_indices = np.argsort(self.joint_indices)
        # self.joint_names = tuple(self.joint_names[i] for i in sort_indices)
        # self.joint_indices = self.joint_indices[sort_indices]
        # getIndex is used here purely for its side effect: it registers each
        # name in the corresponding vocab (the returned index is unused).
        for link in self.links.values():
            utils.getIndex(link.name, self.link_vocab)
        for joint in self.joints.values():
            utils.getIndex(joint.name, self.joint_vocab)
            utils.getIndex(joint.joint_type, self.joint_type_vocab)

        self.index_symmetries = {
            self.link_vocab[name]:
            tuple(self.link_vocab.get(m, -1) for m in matches)
            for name, matches in self.symmetries.items()
            if name in self.link_vocab  # FIXME
        }
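
A toy illustration of the index_symmetries mapping built above, treating
link_vocab as a plain name-to-index dict for the sake of the example: each
link's vocab index maps to the vocab indices of its symmetric counterparts
(missing counterparts map to -1).

link_vocab = {'leg1': 0, 'leg2': 1, 'seat': 2}
symmetries = {'leg1': ('leg2',), 'leg2': ('leg1',), 'seat': ()}

index_symmetries = {
    link_vocab[name]: tuple(link_vocab.get(m, -1) for m in matches)
    for name, matches in symmetries.items()
    if name in link_vocab
}
# {0: (1,), 1: (0,), 2: ()}
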
Example #7
0
def makeEqClasses(assembly_vocab, part_symmetries):
    def renameLink(assembly, old_name, new_name):
        link_dict = {}
        for link in assembly.links.values():
            new_link = lib_assembly.Link(link.name.replace(old_name, new_name),
                                         pose=link.pose)
            if new_link.name in link_dict:
                if link_dict[new_link.name] != new_link:
                    raise AssertionError()
                continue
            else:
                link_dict[new_link.name] = new_link
        new_links = list(link_dict.values())

        new_joints = [
            lib_assembly.Joint(joint.name.replace(old_name, new_name),
                               joint.joint_type,
                               joint.parent_name.replace(old_name, new_name),
                               joint.child_name.replace(old_name, new_name),
                               transform=joint._transform)
            for joint in assembly.joints.values()
        ]

        new_assembly = lib_assembly.Assembly(links=new_links,
                                             joints=new_joints)
        return new_assembly

    eq_class_vocab = []
    assembly_eq_classes = []
    for vocab_index, assembly in enumerate(assembly_vocab):
        for new_name, old_names in part_symmetries.items():
            for old_name in old_names:
                assembly = renameLink(assembly, old_name, new_name)
        eq_class_index = utils.getIndex(assembly, eq_class_vocab)
        assembly_eq_classes.append(eq_class_index)

    assembly_eq_classes = np.array(assembly_eq_classes)

    eq_classes = np.zeros((len(assembly_vocab), len(eq_class_vocab)),
                          dtype=float)
    for vocab_index, eq_class_index in enumerate(assembly_eq_classes):
        eq_classes[vocab_index, eq_class_index] = 1
    eq_classes /= eq_classes.sum(axis=0)

    return eq_classes, assembly_eq_classes, eq_class_vocab
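
A toy illustration of the final normalization step: eq_classes is an
(assembly x equivalence-class) indicator matrix, and dividing by the column
sums turns each column into a uniform distribution over the assemblies that
belong to that class.

import numpy as np

assembly_eq_classes = np.array([0, 1, 1, 1])   # 4 vocab entries, 2 classes
eq_classes = np.zeros((4, 2))
eq_classes[np.arange(4), assembly_eq_classes] = 1
eq_classes /= eq_classes.sum(axis=0)
# Column 0: [1, 0, 0, 0]; column 1: [0, 1/3, 1/3, 1/3]
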
Example #8
0
    def get_assembly_label(edge_labels):
        all_edges_match = (edge_labels == assembly_vocab_edges).all(axis=1)
        if all_edges_match.any():
            assembly_labels = all_edges_match.nonzero()[0]
            if edge_labels.any() and assembly_labels.size != 1:
                raise AssertionError(
                    f"{assembly_labels.size} assemblies match these edges!")
            return assembly_labels[0]

        edges = tuple(edge_vocab[keys[edge_index]][edge_label - 1]
                      for edge_index, edge_label in enumerate(edge_labels)
                      if edge_label)
        assembly = lib_assembly.union(*edges,
                                      link_vocab=link_vocab,
                                      joint_vocab=joint_vocab,
                                      joint_type_vocab=joint_type_vocab)
        assembly_label = utils.getIndex(assembly, assembly_vocab)
        return assembly_label
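
A toy illustration of the row-matching step, assuming assembly_vocab_edges
stores one row of edge labels per assembly in the vocab: broadcasting
edge_labels against the table and requiring every edge to match selects the
vocab entry (if any) with exactly those edges.

import numpy as np

assembly_vocab_edges = np.array([
    [0, 0, 0],   # empty assembly
    [1, 0, 0],
    [1, 2, 0],
])
edge_labels = np.array([1, 2, 0])
all_edges_match = (edge_labels == assembly_vocab_edges).all(axis=1)
all_edges_match.nonzero()[0]   # -> array([2]); vocab entry 2 matches exactly
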
Example #9
0
def loadAssemblies(seq_id, vocab, data_dir):
    assembly_seq = utils.load(f"trial={seq_id}_assembly-seq", data_dir)
    # assembly_seq = joblib.load(os.path.join(data_dir, f"trial={seq_id}_assembly-seq.pkl"))
    labels = np.array(
        [utils.getIndex(assembly, vocab) for assembly in assembly_seq])
    return labels
Example #10
0
def make_labels(assembly_seq, vocab):
    labels = np.zeros(assembly_seq[-1].end_idx, dtype=int)
    for assembly in assembly_seq:
        i = utils.getIndex(assembly, vocab)
        labels[assembly.start_idx:assembly.end_idx] = i
    return labels
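
A usage sketch for make_labels with hypothetical assembly objects (only the
start_idx/end_idx fields and hashability matter here; utils.getIndex is again
assumed to append unseen assemblies to the vocab and return their index):

import collections

# Hypothetical stand-in for the real assembly class.
Assembly = collections.namedtuple('Assembly', ['joints', 'start_idx', 'end_idx'])

assembly_seq = [
    Assembly(joints=frozenset(), start_idx=0, end_idx=10),
    Assembly(joints=frozenset({('leg1', 'seat')}), start_idx=10, end_idx=25),
]
vocab = []
# labels = make_labels(assembly_seq, vocab)
# -> frames 0-9 labeled 0, frames 10-24 labeled 1; len(labels) == 25
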
Example #11
0
    def gen_assembly_vocab(final_assemblies, part_categories, ref_vocab=None):
        def get_parts(joints):
            return frozenset(p for joint in joints for p in joint)

        def remove_part(assembly, part):
            removed = frozenset(joint for joint in assembly
                                if part not in joint)
            # Renumber identical parts using part_to_class
            part_names = tuple(sorted(get_parts(removed)))
            part_classes = tuple(part_to_class.get(p, p) for p in part_names)
            rename = {}
            class_to_parts = collections.defaultdict(list)
            for part_name, part_class in zip(part_names, part_classes):
                class_to_parts[part_class].append(part_name)
                if part_class in part_categories:
                    i = len(class_to_parts[part_class]) - 1
                    new_part_name = part_categories[part_class][i]
                else:
                    new_part_name = part_name
                rename[part_name] = new_part_name
            removed = frozenset(
                frozenset(rename[part] for part in joint) for joint in removed)
            return removed

        def is_possible(assembly):
            def shelf_with_tabletop(assembly):
                def replace(joint, replacee, replacer):
                    replaced = set(joint)
                    replaced.remove(replacee)
                    replaced.add(replacer)
                    return frozenset(replaced)

                for joint in assembly:
                    # A shelf may only be attached if the corresponding
                    # tabletop joint is already present in the assembly.
                    if ('shelf' in joint
                            and replace(joint, 'shelf', 'table') not in assembly):
                        return False
                return True

            predicates = (shelf_with_tabletop, )
            return all(x(assembly) for x in predicates)

        part_to_class = {
            part: class_name
            for class_name, parts in part_categories.items() for part in parts
        }

        final_assemblies = tuple(
            frozenset(frozenset(j) for j in joints)
            for joints in final_assemblies)

        stack = list(final_assemblies)
        assembly_vocab = set(final_assemblies)
        while stack:
            assembly = stack.pop()
            for part in get_parts(assembly):
                child = remove_part(assembly, part)
                if is_possible(child) and child not in assembly_vocab:
                    assembly_vocab.add(child)
                    stack.append(child)

        assembly_vocab = tuple(
            tuple(sorted(tuple(sorted(j)) for j in joints))
            for joints in sorted(assembly_vocab, key=len))

        if ref_vocab is not None:
            v = list(ref_vocab)
            for a in assembly_vocab:
                utils.getIndex(a, v)
            assembly_vocab = tuple(v)

        return assembly_vocab
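
A toy illustration of the downward closure this search computes: starting from
the final assemblies and repeatedly removing one part yields every reachable
partial assembly. This sketch represents assemblies as sets of joints (each
joint a set of part names) and skips the symmetric-part renumbering and the
is_possible filter used above.

def get_parts(joints):
    return frozenset(p for joint in joints for p in joint)

def remove_part(assembly, part):
    return frozenset(joint for joint in assembly if part not in joint)

final = frozenset({
    frozenset({'seat', 'leg1'}),
    frozenset({'seat', 'leg2'}),
})
stack, vocab = [final], {final}
while stack:
    assembly = stack.pop()
    for part in get_parts(assembly):
        child = remove_part(assembly, part)
        if child not in vocab:
            vocab.add(child)
            stack.append(child)
# vocab now holds 4 assemblies: the full one, {seat-leg1}, {seat-leg2}, and the
# empty assembly (removing 'seat' deletes both joints at once).
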
Example #12
0
def main(out_dir=None,
         data_dir=None,
         model_name=None,
         gpu_dev_id=None,
         batch_size=None,
         learning_rate=None,
         model_params={},
         cv_params={},
         train_params={},
         viz_params={},
         load_masks_params={},
         kornia_tfs={},
         only_edge=None,
         num_disp_imgs=None,
         results_file=None,
         sweep_param_name=None):

    data_dir = os.path.expanduser(data_dir)
    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
    else:
        results_file = os.path.expanduser(results_file)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    io_dir = os.path.join(fig_dir, 'model-io')
    if not os.path.exists(io_dir):
        os.makedirs(io_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    def saveVariable(var, var_name, to_dir=out_data_dir):
        return utils.saveVariable(var, var_name, to_dir)

    trial_ids = utils.getUniqueIds(data_dir, prefix='trial=', to_array=True)

    vocab = [BlockAssembly()] + [
        make_single_block_state(i) for i in range(len(defn.blocks))
    ]
    for seq_id in trial_ids:
        assembly_seq = utils.loadVariable(f"trial={seq_id}_assembly-seq",
                                          data_dir)
        for assembly in assembly_seq:
            utils.getIndex(assembly, vocab)
    parts_vocab, part_labels = labels_lib.make_parts_vocab(
        vocab, lower_tri_only=True, append_to_vocab=True)

    if only_edge is not None:
        part_labels = part_labels[:, only_edge:only_edge + 1]

    logger.info(
        f"Loaded {len(trial_ids)} sequences; {len(vocab)} unique assemblies")

    saveVariable(vocab, 'vocab')
    saveVariable(parts_vocab, 'parts-vocab')
    saveVariable(part_labels, 'part-labels')

    device = torchutils.selectDevice(gpu_dev_id)

    if model_name == 'AAE':
        Dataset = sim2real.DenoisingDataset
    elif model_name == 'Resnet':
        Dataset = sim2real.RenderDataset
    elif model_name == 'Connections':
        Dataset = sim2real.ConnectionDataset
    elif model_name == 'Labeled Connections':
        Dataset = sim2real.LabeledConnectionDataset
    else:
        raise ValueError(f"Unrecognized model name: {model_name}")

    occlusion_masks = loadMasks(**load_masks_params)
    if occlusion_masks is not None:
        logger.info(f"Loaded {occlusion_masks.shape[0]} occlusion masks")

    def make_data(shuffle=True):
        dataset = Dataset(
            parts_vocab,
            part_labels,
            vocab,
            device=device,
            occlusion_masks=occlusion_masks,
            kornia_tfs=kornia_tfs,
        )
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  shuffle=shuffle)
        return dataset, data_loader

    for cv_index, cv_splits in enumerate(range(1)):
        cv_str = f"cvfold={cv_index}"

        train_set, train_loader = make_data(shuffle=True)
        test_set, test_loader = make_data(shuffle=False)
        val_set, val_loader = make_data(shuffle=True)

        if model_name == 'AAE':
            model = sim2real.AugmentedAutoEncoder(train_set.data_shape,
                                                  train_set.num_classes)
            criterion = torchutils.BootstrappedCriterion(
                0.25,
                base_criterion=torch.nn.functional.mse_loss,
            )
            metric_names = ('Reciprocal Loss', )
        elif model_name == 'Resnet':
            model = sim2real.ImageClassifier(train_set.num_classes,
                                             **model_params)
            criterion = torch.nn.CrossEntropyLoss()
            metric_names = ('Loss', 'Accuracy')
        elif model_name == 'Connections':
            model = sim2real.ConnectionClassifier(train_set.label_shape[0],
                                                  **model_params)
            criterion = torch.nn.BCEWithLogitsLoss()
            metric_names = ('Loss', 'Accuracy', 'Precision', 'Recall', 'F1')
        elif model_name == 'Labeled Connections':
            out_dim = int(part_labels.max()) + 1
            num_vertices = len(defn.blocks)
            edges = np.column_stack(np.tril_indices(num_vertices, k=-1))
            if only_edge is not None:
                edges = edges[only_edge:only_edge + 1]
            model = sim2real.LabeledConnectionClassifier(
                out_dim, num_vertices, edges, **model_params)
            if only_edge is not None:
                logger.info(f"Class freqs: {train_set.class_freqs}")
                # criterion = torch.nn.CrossEntropyLoss(weight=1 / train_set.class_freqs[:, 0])
                criterion = torch.nn.CrossEntropyLoss()
            else:
                criterion = torch.nn.CrossEntropyLoss()
            # criterion = torchutils.BootstrappedCriterion(
            #     0.25, base_criterion=torch.nn.functional.cross_entropy,
            # )
            metric_names = ('Loss', 'Accuracy', 'Precision', 'Recall', 'F1')

        model = model.to(device=device)

        optimizer_ft = torch.optim.Adam(model.parameters(),
                                        lr=learning_rate,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0,
                                        amsgrad=False)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=1,
                                                       gamma=1.00)

        train_epoch_log = collections.defaultdict(list)
        val_epoch_log = collections.defaultdict(list)
        metric_dict = {name: metrics.makeMetric(name) for name in metric_names}
        model, last_model_wts = torchutils.trainModel(
            model,
            criterion,
            optimizer_ft,
            lr_scheduler,
            train_loader,
            val_loader,
            device=device,
            metrics=metric_dict,
            train_epoch_log=train_epoch_log,
            val_epoch_log=val_epoch_log,
            **train_params)

        # Test model
        metric_dict = {name: metrics.makeMetric(name) for name in metric_names}
        test_io_batches = torchutils.predictSamples(
            model.to(device=device),
            test_loader,
            criterion=criterion,
            device=device,
            metrics=metric_dict,
            data_labeled=True,
            update_model=False,
            seq_as_batch=train_params['seq_as_batch'],
            return_io_history=True)
        metric_str = '  '.join(str(m) for m in metric_dict.values())
        logger.info('[TST]  ' + metric_str)
        utils.writeResults(results_file, metric_dict, sweep_param_name,
                           model_params)

        for pred_seq, score_seq, feat_seq, label_seq, trial_id in test_io_batches:
            trial_str = f"trial={trial_id}"
            saveVariable(pred_seq.cpu().numpy(), f'{trial_str}_pred-label-seq')
            saveVariable(score_seq.cpu().numpy(), f'{trial_str}_score-seq')
            saveVariable(label_seq.cpu().numpy(),
                         f'{trial_str}_true-label-seq')

        saveVariable(model, f'{cv_str}_model-best')

        if train_epoch_log:
            torchutils.plotEpochLog(train_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Training performance',
                                    fn=os.path.join(
                                        fig_dir, f'{cv_str}_train-plot.png'))

        if val_epoch_log:
            torchutils.plotEpochLog(val_epoch_log,
                                    subfig_size=(10, 2.5),
                                    title='Heldout performance',
                                    fn=os.path.join(fig_dir,
                                                    f'{cv_str}_val-plot.png'))

        if num_disp_imgs is not None:
            model.plotBatches(test_io_batches, io_dir, dataset=test_set)
Example #13
0
def main(out_dir=None,
         preds_dir=None,
         data_dir=None,
         metric_names=None,
         plot_output=None,
         results_file=None,
         sweep_param_name=None):

    if metric_names is None:
        metric_names = ('accuracy', 'edit_score', 'overlap_score')

    preds_dir = os.path.expanduser(preds_dir)

    out_dir = os.path.expanduser(out_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    out_data_dir = os.path.join(out_dir, 'data')
    if not os.path.exists(out_data_dir):
        os.makedirs(out_data_dir)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    logger.info(f"Writing to: {out_dir}")

    if results_file is None:
        results_file = os.path.join(out_dir, 'results.csv')
        if os.path.exists(results_file):
            os.remove(results_file)
    else:
        results_file = os.path.expanduser(results_file)

    def saveVariable(var, var_name):
        joblib.dump(var, os.path.join(out_data_dir, f'{var_name}.pkl'))

    def loadAll(seq_ids, var_name, data_dir):
        def loadOne(seq_id):
            fn = os.path.join(data_dir, f'trial-{seq_id}_{var_name}')
            return joblib.load(fn)

        return tuple(map(loadOne, seq_ids))

    # assembly_vocab = joblib.load(os.path.join(data_dir, 'assembly-vocab.pkl'))

    trial_ids = utils.getUniqueIds(preds_dir, prefix='trial-')
    pred_seqs = loadAll(trial_ids, 'pred-state-seq.pkl', preds_dir)
    true_seqs = loadAll(trial_ids, 'true-state-seq.pkl', preds_dir)

    action_vocab = []
    state_vocab = []
    for i, trial_id in enumerate(trial_ids):
        logger.info(f"VIDEO {trial_id}:")

        pred_state_segs = pred_seqs[i]
        pred_state_index_seq = np.array(
            [utils.getIndex(s, state_vocab) for s in pred_state_segs])
        pred_action_segs, pred_action_index_seq = actionsFromAssemblies(
            pred_seqs[i], action_vocab)

        true_state_segs = true_seqs[i]
        true_state_index_seq = np.array(
            [utils.getIndex(s, state_vocab) for s in true_state_segs])
        true_action_segs, true_action_index_seq = actionsFromAssemblies(
            true_seqs[i], action_vocab)

        metric_dict = {}
        for name in metric_names:
            key = f"{name}_action"
            value = getattr(LCTM.metrics, name)(pred_action_index_seq,
                                                true_action_index_seq) / 100
            metric_dict[key] = value
            logger.info(f"  {key}: {value * 100:.1f}%")

            key = f"{name}_state"
            value = getattr(LCTM.metrics, name)(pred_state_index_seq,
                                                true_state_index_seq) / 100
            metric_dict[key] = value
            logger.info(f"  {key}: {value * 100:.1f}%")

        utils.writeResults(results_file, metric_dict, sweep_param_name, {})

    assembly_fig_dir = os.path.join(fig_dir, 'assemblies')
    if not os.path.exists(assembly_fig_dir):
        os.makedirs(assembly_fig_dir)
    for i, assembly in enumerate(state_vocab):
        assembly.draw(assembly_fig_dir, i)

    action_fig_dir = os.path.join(fig_dir, 'actions')
    if not os.path.exists(action_fig_dir):
        os.makedirs(action_fig_dir)
    for i, action in enumerate(action_vocab):
        action.draw(action_fig_dir, i)

    assembly_paths_dir = os.path.join(fig_dir, 'assembly-seq-imgs')
    if not os.path.exists(assembly_paths_dir):
        os.makedirs(assembly_paths_dir)

    action_paths_dir = os.path.join(fig_dir, 'action-seq-imgs')
    if not os.path.exists(action_paths_dir):
        os.makedirs(action_paths_dir)

    for i, trial_id in enumerate(trial_ids):

        pred_state_segs = pred_seqs[i]
        pred_state_index_seq = np.array(
            [utils.getIndex(s, state_vocab) for s in pred_state_segs])
        pred_action_segs, pred_action_index_seq = actionsFromAssemblies(
            pred_seqs[i], action_vocab)
        assemblystats.drawPath(pred_action_index_seq, trial_id,
                               f"trial={trial_id}_pred-seq", action_paths_dir,
                               action_fig_dir)

        assemblystats.drawPath(pred_state_index_seq, trial_id,
                               f"trial={trial_id}_pred-seq",
                               assembly_paths_dir, assembly_fig_dir)

        true_state_segs = true_seqs[i]
        true_state_index_seq = np.array(
            [utils.getIndex(s, state_vocab) for s in true_state_segs])
        true_action_segs, true_action_index_seq = actionsFromAssemblies(
            true_seqs[i], action_vocab)
        assemblystats.drawPath(true_action_index_seq, trial_id,
                               f"trial={trial_id}_true-seq", action_paths_dir,
                               action_fig_dir)

        assemblystats.drawPath(true_state_index_seq, trial_id,
                               f"trial={trial_id}_true-seq",
                               assembly_paths_dir, assembly_fig_dir)
Example #14
0
    def makeEvent(assembly_action):
        return utils.getIndex(assembly_action, action_vocab)