# matrix_multiply.py
from pyspark import SparkConf, SparkContext
import sys, operator
def add_tuples(a, b):
    """Return the element-wise sum of two sequences as a new list.

    Serves as the reduce function that accumulates the per-row outer
    products. Because zip() stops at the shorter input, any extra trailing
    elements of the longer sequence are silently ignored.
    """
    paired = zip(a, b)
    return list(map(sum, paired))
def permutation(row):
    """Return the flattened outer product of ``row`` with itself.

    ``row`` is a sequence of numbers or numeric strings v[0..n-1]. The result
    is the row-major list of every product v[i] * v[j], with length n * n.
    Summing these lists over all rows of a matrix A yields A^T A. Despite the
    name, no permutations are involved.

    Raises ValueError if a field cannot be parsed by float().
    """
    # Convert each field once, rather than n times inside the double loop.
    values = [float(field) for field in row]
    # The outer loop is the left factor, so the layout is row-major (i, j).
    return [left * right for left in values for right in values]
def main():
    """Compute A^T A for a whitespace-separated matrix A and write it locally.

    Usage: matrix_multiply.py INPUT OUTPUT

    Each non-blank line of INPUT is one row of A. Every row contributes its
    outer product with itself (``permutation``). Summing those with
    ``add_tuples`` gives the ncol x ncol matrix A^T A, where ncol is the number
    of values in the first row. The result is written to the local file
    OUTPUT: one matrix row per line, each value followed by a single space.

    Exits with a message if the argument count is wrong or INPUT has no rows.
    """
    if len(sys.argv) != 3:
        sys.exit('usage: %s INPUT OUTPUT' % sys.argv[0])
    input_path = sys.argv[1]
    output_path = sys.argv[2]

    conf = SparkConf().setAppName('Matrix Multiplication')
    sc = SparkContext(conf=conf)
    # NOTE(review): this is a lexicographic string comparison, not a real
    # version comparison; it only happens to work for the versions seen so far.
    assert sc.version >= '1.5.1'

    # split() with no argument tolerates repeated or trailing whitespace.
    # Blank lines are dropped: they would otherwise either raise on
    # float('') or produce an empty product that zip() truncates the sum to.
    rows = (sc.textFile(input_path)
            .map(lambda line: line.split())
            .filter(lambda fields: len(fields) > 0)
            .cache())  # reused by take() and by the map/reduce below

    first = rows.take(1)
    if not first:
        sys.exit('no matrix rows found in %s' % input_path)
    ncol = len(first[0])
    # NOTE(review): rows whose length differs from the first row are not
    # rejected; add_tuples' zip() would silently truncate the result.

    # Flat, row-major list of ncol * ncol accumulated products.
    flat = rows.map(permutation).reduce(add_tuples)

    # Cut the flat list into rows of width ncol (previously hard-coded to 3,
    # which was wrong for any other column count).
    matrix = [flat[start:start + ncol] for start in range(0, len(flat), ncol)]

    with open(output_path, 'w') as out:
        for matrix_row in matrix:
            out.write(''.join(str(value) + ' ' for value in matrix_row) + '\n')
# Run the job only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()