# Per-group output file: every rank of the sub-communicator scomm writes its
# share of the result matrix C into this one shared file.
cname = "out/group-%d.matrix" % gid

# Remove a stale file from a previous run. Only the group root does this, and
# everyone waits for it: if every rank deleted, a slower rank could unlink the
# file a faster rank had already created inside the collective Open below.
# Errors (e.g. no such file on a first run) are deliberately ignored.
if scomm.rank == 0:
    try:
        MPI.File.Delete(cname)
    except MPI.Exception:
        pass
scomm.barrier()

# Access-mode values are bit flags, so combine them with bitwise OR.
cfile = MPI.File.Open(scomm, cname, MPI.MODE_WRONLY | MPI.MODE_CREATE)

start = time.time()

# Split the flat (row-major) element indices of C across the ranks of this
# group; each rank receives one contiguous chunk.
indexes = np.array_split(range(num_columns * num_rows), scomm.size)
indexes = scomm.scatter(indexes)

# A rank gets an empty chunk when the group has more ranks than C has
# elements; indexes[0] would then raise IndexError, so compute nothing.
if len(indexes):
    first = int(indexes[0])
    c_elements = matrix.mul_matrix_partial(a, b, first, indexes[-1] + 1, num_columns)
else:
    first = 0
    c_elements = []

# Render this rank's elements as space-separated text, emitting a newline
# before every element that starts a new row of C (except global index 0).
pieces = []
for index, element in enumerate(c_elements, first):
    if index and index % num_columns == 0:
        pieces.append("\n")
    pieces.append("%s " % element)
s = "".join(pieces)

# Write_ordered is collective: every rank contributes its (possibly empty)
# text and the contributions are stored in rank order.
cfile.Write_ordered(s)
# Closing is collective too and releases the handle / flushes the data.
cfile.Close()
ans = scomm.gather(c_elements)

if scomm.rank == 0:
    i = 0
b = np.arange(ARRAY_SIZE, 2 * ARRAY_SIZE).reshape(ARRAY_DIM, ARRAY_DIM)
c = np.zeros((ARRAY_DIM, ARRAY_DIM), dtype=np.int32)


if rank == 0:
    # main process

    c_parts = []
    reqs = []
    for i in range(size-1):
        c_elements = np.empty(ARRAY_SIZE, dtype=np.int32)
        req = comm.Irecv(c_elements ,source=i+1, tag=10)
        reqs.append(req)
        c_parts.append(c_elements)
    start, stop = matrix.get_rank_indexes(rank,size, ARRAY_DIM)
    c_elements = matrix.mul_matrix_partial(a, b, start, stop, ARRAY_DIM)
    matrix.copy_matrix_slice(c, c_elements, start, stop, ARRAY_DIM)
    MPI.Request.Waitall(reqs)
    for i in range(size-1):
        start, stop = matrix.get_rank_indexes(i+1, size, ARRAY_DIM)
        matrix.copy_matrix_slice(c, c_parts[i], start, stop, ARRAY_DIM)

    print "result matrix\n" ,c
    print "true result\n", np.dot(a,b)

else:
    start, stop = matrix.get_rank_indexes(rank,size, ARRAY_DIM)
    c_elements = matrix.mul_matrix_partial(a, b, start, stop, ARRAY_DIM)
    c_elements = np.array(c_elements, dtype=np.int32)
    comm.Isend(c_elements, dest=0,tag=10)
groups = list(np.array_split(range(size),  args.groups))
dist =  {}

for gid, ranks in enumerate(groups):
    sgroup = MPI.Group.Incl(comm.Get_group(), ranks)
    dist.update({rank: (gid, sgroup) for rank in ranks})


gid, sgroup = dist[comm.rank]
scomm = comm.Create(sgroup)

start  = time.time()

indexes = np.array_split(range(ARRAY_DIM * ARRAY_DIM), scomm.size)
indexes = scomm.scatter(indexes)
c_elements = matrix.mul_matrix_partial(a,b,indexes[0], indexes[-1] + 1, ARRAY_DIM)
ans = scomm.gather(c_elements)
if scomm.rank == 0:
    i = 0
    for elements in ans:
        matrix.copy_matrix_slice(c, elements, i, i + len(elements), ARRAY_DIM)
        i += len(elements)
scomm.barrier()
end = time.time()
times = scomm.allgather(end-start)

print "Group %s, Global rank %s, Rank in group %s, Group Size: %s, time: %s" %(gid ,rank, scomm.rank, scomm.size, max(times))
comm.barrier()
if scomm.rank == 0:
    print c