obj,
            open(
                data_dir + 'data_2020715/node2vec/models/' + self.city +
                '_distmult.pkl', 'wb'))


if __name__ == "__main__":
    # Project data root; machine-specific, adjust when running elsewhere.
    data_dir = 'E:/python-workspace/CityRoadPrediction/'
    # Shared location handed to both the tester and the trainer below, built once
    # so the two can never point at different directories.
    node2vec_dir = data_dir + 'data_2020715/node2vec/'
    # Embedding size used by both the tester and the trainer; they must agree.
    embed_dim = 50

    train = DataLoader(data_dir + 'data_2020715/train/')
    test = DataLoader(data_dir + 'data_2020715/test/')

    # Only cities present in both the train and test splits are processed;
    # sorting gives a deterministic processing order.
    cities = sorted(set(train.cities) & set(test.cities))
    for city in cities:
        print(city)
        # initialize() is called before every load_dir_datas(city); presumably it
        # clears the previous city's data — TODO confirm in DataLoader.
        train.initialize()
        train.load_dir_datas(city)
        test.initialize()
        test.load_dir_datas(city)
        tester = VecTester(embed_dim=embed_dim,
                           test_data=test,
                           city=city,
                           data_dir=node2vec_dir)
        trainer = Node2VecTrainer(embed_dim=embed_dim,
                                  train_data=train,
                                  city=city,
                                  tester=tester)
        trainer.prepare_train_embedding(node2vec_dir)
        # DistMult training is currently disabled. When re-enabling, note that
        # result_dir must live under the same 'data_2020715' tree as everything
        # else (a previous version of this line pointed at 'data_20200715').
        # trainer.train_distmult(data_dir=node2vec_dir,
        #                        result_dir=node2vec_dir + 'result/')
                    sample = self.embedding[ids][i + j]
                    edge = {'start': int(sample['start_id']), 'end': int(sample['end_id']),
                            'score': float(output[j][1]), 'target': sample['target']}
                    cand_edges.append(edge)
            cand_edges.sort(key=lambda e: e['score'], reverse=True)
            test_result[ids] = cand_edges
            for edge in cand_edges:
                if edge['score'] < np.log(0.5):
                    break
                #if is_valid(edge, existed_edges, self.id2node):
                existed_edges.append(edge)
                if {'start': edge['start'], 'end': edge['end']} in self.test_loader[ids]['target_edges'] or \
                        {'start': edge['end'], 'end': edge['start']} in self.test_loader[ids]['target_edges']:
                    right += 1
                else:
                    wrong += 1

            total += len(self.test_loader[ids]['target_edges'])
        precision = right / (right + wrong + 1e-9)
        recall = right / (total + 1e-9)
        f1 = 2 * precision * recall / (precision + recall + 1e-9)
        pickle.dump(test_result, open(result_dir + self.city + '_result.pkl', 'wb'))
        return right, wrong, total, precision, recall, f1


if __name__ == "__main__":
    # Minimal smoke check: load one city's test split, build a VecTester on
    # it, and print the 'source_edges' entry of the first loaded sample.
    sample_city = 'Akron'
    test = DataLoader('E:/python-workspace/CityRoadPrediction/data_20200610/test/')
    test.load_dir_datas(sample_city)
    tester = VecTester(embed_dim=50, test_data=test, city=sample_city)
    first_sample = test[0]
    print(first_sample['source_edges'])