示例#1
0
    def setup(self):
        """Load the listed program graphs and build train/test index splits.

        Reads ``cmd_args.data_root + '/' + cmd_args.file_list``. Each line names
        a gzip-compressed pickle ``<data_root>/<name>.pickle`` whose contents are
        iterated as records ``x``:

        - ``x[0]`` is parsed with ``json.loads`` into a ``ProgramGraph`` appended
          to ``self.pg_list``.
        - ``x[1]`` is appended to ``self.ordered_pre_post``.

        Within each file, the first ``int(len(records) * cmd_args.train_frac)``
        records are indexed into ``self.train_indices`` and the rest into
        ``self.test_indices``. An index is the record's position in
        ``self.pg_list``.

        When ``cmd_args.single_sample`` is not None, only the file-list line at
        that 0-based position is loaded. ``self.train_indices`` and
        ``self.test_indices`` are then replaced by that sample's own split.

        Afterwards the method:

        - builds the node-type dictionary;
        - wraps every graph in a ``GraphSample``;
        - selects the index list for ``cmd_args.phase``;
        - shuffles that list in place and resets ``self.sample_pos``.

        Raises:
            ValueError: if ``single_sample`` is set but its train or test split
                ended up empty.
        """
        single = cmd_args.single_sample
        list_path = cmd_args.data_root + '/' + cmd_args.file_list
        with open(list_path, 'r') as list_file:
            for cur_sample_idx, row in enumerate(list_file):
                if single is not None and cur_sample_idx != single:
                    continue
                filename = cmd_args.data_root + '/' + row.strip() + '.pickle'
                # Use a distinct name so the outer list-file handle is not
                # shadowed.
                # NOTE(review): pickle.load can execute arbitrary code; only
                # load trusted data files.
                with gzip.open(filename, 'rb') as pickle_file:
                    loaded_object = pickle.load(pickle_file)
                num_train = int(len(loaded_object) * cmd_args.train_frac)
                for local_idx, x in enumerate(loaded_object):
                    # Index this record will occupy in self.pg_list.
                    graph_idx = len(self.pg_list)
                    if local_idx < num_train:
                        self.train_indices.append(graph_idx)
                        # Only the selected sample reaches this point when
                        # single is set.
                        if single is not None:
                            self.single_sample_train.append(graph_idx)
                    else:
                        self.test_indices.append(graph_idx)
                        if single is not None:
                            self.single_sample_test.append(graph_idx)
                    graph_json = json.loads(x[0])
                    self.pg_list.append(ProgramGraph(graph_json))
                    self.ordered_pre_post.append(x[1])

        if single is not None:
            # Explicit check instead of `assert`, which is stripped under -O.
            if not (self.single_sample_test and self.single_sample_train):
                raise ValueError(
                    'single_sample %r produced an empty train or test split'
                    % (single,))
            self.train_indices = self.single_sample_train
            self.test_indices = self.single_sample_test

        self.build_node_type_dict()

        for i, g in enumerate(self.pg_list):
            self.sample_graphs.append(
                GraphSample(i, self, g, self.node_type_dict))

        if cmd_args.phase == 'train':
            self.sample_idxes = self.train_indices
        else:
            self.sample_idxes = self.test_indices

        # sample_idxes aliases the chosen index list, so shuffling it in place
        # also reorders self.train_indices / self.test_indices.
        random.shuffle(self.sample_idxes)
        self.sample_pos = 0
示例#2
0
 def load_pg_list(self, fname):
     """Read ``<data_root>/graph/<fname>.json`` and append it to ``self.pg_list``.

     The parsed JSON is wrapped in a ``ProgramGraph`` before being appended.
     """
     graph_path = cmd_args.data_root + '/graph/' + fname + '.json'
     with open(graph_path, 'r') as graph_file:
         parsed = json.load(graph_file)
     self.pg_list.append(ProgramGraph(parsed))
示例#3
0
            lv += 1

        return cur_message_layer

if __name__ == '__main__':
    # Seed every RNG in use so runs are reproducible.
    random.seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)

    # Each non-empty line of list.txt names a JSON program graph
    # <data_root>/<name>.json.
    pg_graphs = []
    with open(cmd_args.data_root + '/list.txt', 'r') as f:
        for row in f:
            name = row.strip()
            if not name:
                # Skip blank lines instead of trying to open "<data_root>/.json".
                continue
            with open(cmd_args.data_root + '/' + name + '.json', 'r') as gf:
                pg_graphs.append(ProgramGraph(json.load(gf)))
    s2v_graphs = [S2VGraph(g) for g in pg_graphs]

    print(len(s2v_graphs))
    # The construction of `mf` was commented out, so `mf` was undefined and
    # the script crashed with NameError below; it is restored here.
    # NOTE(review): assumes `node_type_dict` is a module-level global populated
    # before this point (e.g. while building the S2VGraphs) -- confirm.
    mf = EmbedMeanField(128, len(node_type_dict))
    if cmd_args.ctx == 'gpu':
        mf = mf.cuda()

    # Compare embedding graphs 0 and 1 as one batch against embedding each
    # alone and concatenating the results. The printed diff is presumably
    # expected to be ~0 (batching should not change per-graph embeddings).
    embedding = mf(s2v_graphs[0:2])
    embed2 = mf(s2v_graphs[0:1])
    embed3 = mf(s2v_graphs[1:2])
    ee = torch.cat([embed2, embed3], dim=0)
    diff = torch.sum(torch.abs(embedding - ee))
    print(diff)
示例#4
0
        self.vc_list = vc_list


if __name__ == '__main__':
    random.seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)
    tic()
    params = []

    graph = None
    node_type_dict = {}
    vc_list = []

    with open(cmd_args.input_graph, 'r') as graph_file:
        graph = ProgramGraph(json.load(graph_file))
        for node in graph.node_list:
            if not node.node_type in node_type_dict:
                v = len(node_type_dict)
                node_type_dict[node.node_type] = v

    if graph is not None:
        if cmd_args.encoder_model == 'GNN':
            encoder = EmbedMeanField(cmd_args.embedding_size,
                                     len(node_type_dict),
                                     max_lv=cmd_args.s2v_level)
        elif cmd_args.encoder_model == 'LSTM':
            encoder = LSTMEmbed(cmd_args.embedding_size, len(node_type_dict))
        elif cmd_args.encoder_model == 'Param':
            g_list = GraphSample(graph, vc_list, node_type_dict)
            encoder = ParamEmbed(cmd_args.embedding_size,