Пример #1
0
def _sample_nodes(model: GraphSamplingModel,
                  graph,
                  node_feature_mapping,
                  generator=None):
    """Sample the next node operation from the model for the given graph.

    A node label is sampled first. If the model declares feature dimensions
    for the label's target type, the features are then sampled one position
    at a time, with the features chosen so far passed back into the model on
    every step.

    Parameters
    ----------
    model : GraphSamplingModel
        Model providing ``sample_node_label``, ``sample_entity_features``
        and ``feature_dimensions``.
    graph
        Graph representation forwarded to the model's sampling methods.
    node_feature_mapping
        Object whose ``features_from_index`` turns the sampled feature
        indices into the parameters of the node.
    generator : optional
        Forwarded unchanged to the model's sampling methods.

    Returns
    -------
    list
        A one-element list containing the sampled ``NodeOp``.

    Raises
    ------
    NotImplementedError
        If the model returns anything other than exactly one label.
    """
    labels_idx = model.sample_node_label(graph, generator)

    if len(labels_idx) != 1:
        raise NotImplementedError('Batch sampling is not implemented')

    label_idx = int(labels_idx[0])
    label = dataset.NODE_IDX_MAP_REVERSE[label_idx]

    target_type = target.TargetType.from_label(label)

    # Labels without declared feature dimensions carry no parameters.
    if target_type not in model.feature_dimensions:
        return [NodeOp(label)]

    fd = model.feature_dimensions[target_type]
    # One row of integer feature indices, allocated via ``new_zeros`` on the
    # sampled label container.
    numerical_features = labels_idx.new_zeros((1, len(fd)), dtype=int)

    for feat_idx in range(len(fd)):
        # Each call receives the partially filled feature row; only the
        # entry for the current position is kept from its result.
        sample_features = model.sample_entity_features(
            graph, target_type, numerical_features, generator)
        numerical_features[0, feat_idx] = sample_features[feat_idx]

    feature_indices = numerical_features[0].cpu().tolist()
    parameters = node_feature_mapping.features_from_index(
        feature_indices, target_type)

    return [NodeOp(label, parameters)]
Пример #2
0
    def _add_node(self, label, parameters=None):
        """Append a node op and, for subnodes, link it to the last entity.

        Parameters
        ----------
        label
            Label of the node to add. Labels that are instances of
            ``datalib.SubnodeType`` are treated as subnodes.
        parameters : dict, optional
            Parameters of the node; an empty dict is used when omitted.
        """
        node_parameters = {} if parameters is None else parameters

        self.seq.append(NodeOp(label, node_parameters))
        self.num_nodes += 1
        new_index = self.num_nodes - 1

        if isinstance(label, datalib.SubnodeType):
            # A subnode is tied to the most recently added entity node
            # through a Subnode edge.
            subnode_edge = EdgeOp(datalib.ConstraintType.Subnode,
                                  (new_index, self.last_entity_node))
            self.add_op(subnode_edge)
        else:
            # Any other node becomes the entity that later subnodes attach to.
            self.last_entity_node = new_index
Пример #3
0
def test_get_sequence_dof():
    """Check the sum of ``get_sequence_dof`` over a two-line sequence."""
    # The reference tuples below appear to index nodes by their position
    # among the NodeOps only (External = 0, first Line = 1, ...).
    opening = [NodeOp(label=EntityType.External)]

    # First line (node 1) with its start (2) and end (3) subnodes.
    first_line = [
        NodeOp(label=EntityType.Line),
        NodeOp(label=SubnodeType.SN_Start),
        EdgeOp(label=ConstraintType.Subnode, references=(2, 1)),
        NodeOp(label=SubnodeType.SN_End),
        EdgeOp(label=ConstraintType.Subnode, references=(3, 1)),
    ]

    # Second line (node 4), constrained against the first line, followed by
    # its start (5) and end (6) subnodes.
    second_line = [
        NodeOp(label=EntityType.Line),
        EdgeOp(label=ConstraintType.Parallel, references=(4, 1)),
        EdgeOp(label=ConstraintType.Horizontal, references=(4,)),
        EdgeOp(label=ConstraintType.Distance, references=(4, 1)),
        NodeOp(label=SubnodeType.SN_Start),
        EdgeOp(label=ConstraintType.Subnode, references=(5, 4)),
        NodeOp(label=SubnodeType.SN_End),
        EdgeOp(label=ConstraintType.Subnode, references=(6, 4)),
    ]

    closing = [NodeOp(label=EntityType.Stop)]

    ops = opening + first_line + second_line + closing

    total_dof = np.sum(get_sequence_dof(ops))
    assert total_dof == 5
Пример #4
0
def generate_sample(model: GraphSamplingModel,
                    max_iters,
                    node_feature_mapping,
                    edge_feature_mapping,
                    generator=None,
                    device=None):
    """Sample a sequence of node and edge operations from the model.

    Sampling alternates between two states. In ``'add_node'`` a node is
    sampled and appended; sampling ends if it is the ``Stop`` entity. In
    ``'add_edge'`` edges are sampled until the model yields ``None``; then
    the next pending subnode of the current entity is appended, or, when
    none remain, control returns to ``'add_node'``.

    Parameters
    ----------
    model : GraphSamplingModel
        Model used by ``_sample_nodes`` and ``_sample_edges``.
    max_iters
        Sampling stops once the sequence holds at least this many ops.
    node_feature_mapping
        Mapping used to build the graph and decode sampled node features.
    edge_feature_mapping
        Mapping used to build the graph and decode sampled edge features.
    generator : optional
        Forwarded unchanged to the sampling helpers.
    device : optional
        Forwarded to ``training.load_cuda_async`` when loading the graph.

    Returns
    -------
    list
        The sequence of ops accumulated by the builder.
    """
    builder = _SeqBuilder()
    state = 'add_node'
    subnodes_to_add = None

    while len(builder.seq) < max_iters:
        # The graph is rebuilt from the full sequence on every iteration so
        # the model always conditions on everything sampled so far.
        graph = dataset.graph_info_from_sequence(builder.seq,
                                                 node_feature_mapping,
                                                 edge_feature_mapping)
        graph = training.load_cuda_async(graph, device)

        if state == 'add_node':
            node_op, = _sample_nodes(model, graph, node_feature_mapping,
                                     generator)
            builder.add_op(node_op)
            if node_op.label == EntityType.Stop:
                break

            subnodes_to_add = list(_get_subnodes_for_entity(node_op))
            state = 'add_edge'
        elif state == 'add_edge':
            edge_op, = _sample_edges(model, graph, edge_feature_mapping,
                                     generator)
            if edge_op is not None:
                # Internal invariant: a sampled edge must involve the most
                # recently added node.
                assert max(edge_op.references) + 1 == builder.num_nodes
                builder.add_op(edge_op)
                continue

            # No more edges for the current node: move on to the next
            # pending subnode (staying in 'add_edge'), or to a new entity.
            if subnodes_to_add:
                builder.add_op(NodeOp(subnodes_to_add.pop()))
            else:
                state = 'add_node'
        else:
            # Explicit raise rather than ``assert False`` so the guard is
            # not stripped when Python runs with -O.
            raise AssertionError(
                'Unknown sampling state: {!r}'.format(state))
    return builder.seq
Пример #5
0
 def __init__(self):
     """Start a sequence that contains only the external node."""
     # The sequence opens with the External entity, so the node count
     # starts at one.
     external_node = NodeOp(datalib.EntityType.External)
     self.seq = [external_node]
     self.num_nodes = 1
     # No entity node has been recorded yet.
     self.last_entity_node = None