# Build, compile and train an Inception-style convolutional Graph model with
# Nesterov-momentum SGD.
#
# This section relies on names defined outside it: Graph, Convolution2D,
# MaxPooling2D, SGD, img_channels, img_rows, img_cols, X_train, X_test,
# Y_train.

# --- Input preprocessing -----------------------------------------------------
# Cast to float32, then divide by 255 (presumably raw 8-bit pixel values,
# giving a [0, 1] range — confirm against the data loader).
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

# --- Model definition --------------------------------------------------------
model = Graph()
model.input(shape=(img_channels, img_rows, img_cols), name='input1')

# Stem: 7x7 conv -> 3x3 conv -> 2x2 max-pool.
model.node(Convolution2D(32, 7, 7), name='conv11', input='input1', activation='relu')
model.node(Convolution2D(32, 3, 3), name='conv12', input='conv11', activation='relu')
model.node(MaxPooling2D(poolsize=(2, 2)), name='pool1', input='conv12')

# Inception-style block fed by pool1: three 1x1 projections plus a pooling
# branch, followed by 3x3 / 5x5 / 1x1 convolutions on those branches.
# NOTE(review): conv21 is not consumed by any node or output shown here, so its
# weights never reach the loss. It looks like the 1x1 branch was meant to be
# part of the 'output1' concat — confirm intent.
model.node(Convolution2D(32, 1, 1), name='conv21', input='pool1', activation='relu')
model.node(Convolution2D(32, 1, 1), name='conv22', input='pool1', activation='relu')
model.node(Convolution2D(32, 1, 1), name='conv23', input='pool1', activation='relu')
model.node(MaxPooling2D(poolsize=(2, 2)), name='pool21', input='pool1')

model.node(Convolution2D(32, 3, 3), name='conv31', input='conv22', activation='relu')
model.node(Convolution2D(32, 5, 5), name='conv32', input='conv23', activation='relu')
model.node(Convolution2D(32, 1, 1), name='conv33', input='pool21', activation='relu')

# NOTE(review): if these convolutions use unpadded ('valid') borders, conv31,
# conv32 and conv33 have different spatial sizes (3x3 vs 5x5 kernels, and an
# extra 2x2 pool before conv33), and the concat below would fail. TODO: confirm
# the border-mode default for this Keras version.
model.output(inputs=['conv31', 'conv32', 'conv33'], name='output1', merge_mode='concat')

# --- Training ----------------------------------------------------------------
# Nesterov-momentum SGD with learning-rate decay.
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

# Pass the configured `sgd` instance (not the string 'sgd') so the
# lr/decay/momentum/nesterov settings above actually take effect.
# NOTE(review): no per-output loss is named for 'output1'; confirm this matches
# the Graph.compile signature of the Keras version in use (some versions require
# a `loss` dict keyed by output name).
model.compile(loss_merge='sum', optimizer=sgd)
model.fit(train={'input1': X_train, 'output1': Y_train})