def gen(Z, Y, w, w2, w3, wx):
    """Build the conditional generator graph that maps noise ``Z`` to images.

    The label matrix ``Y`` is re-attached to the input of every layer:
    dense inputs via ``T.concatenate(..., axis=1)``, feature maps via
    ``conv_cond_concat``. Layout: dense -> dense -> reshape -> deconv ->
    deconv. Hidden layers use batchnorm + ReLU; the output uses a sigmoid.

    Args:
        Z: 2-D noise matrix (concatenated with ``Y`` along axis 1).
        Y: 2-D label matrix.
        w, w2: weight matrices of the two dense layers.
        w3, wx: filters of the hidden and output deconvolutions.

    Returns:
        Symbolic sigmoid of the final deconvolution.

    Reads module-level ``ngf`` and ``npx_`` for the reshape.
    NOTE(review): assumes ``w2`` has ``ngf * 2 * npx_ * npx_`` output
    columns so the reshape is valid -- confirm against the weight init.
    """
    stride = (2, 2)
    pad = (2, 2)

    # Y (2-D) -> 4-D with two trailing broadcastable axes ('x').
    label_map = Y.dimshuffle(0, 1, 'x', 'x')

    # Dense layer 1 on [noise | labels]; labels are re-appended afterwards.
    joint_in = T.concatenate([Z, Y], axis=1)
    dense1 = relu(batchnorm(T.dot(joint_in, w)))
    dense1_y = T.concatenate([dense1, Y], axis=1)

    # Dense layer 2, reshaped into ngf*2 feature maps of size npx_ x npx_.
    dense2 = relu(batchnorm(T.dot(dense1_y, w2)))
    fmap = dense2.reshape((dense2.shape[0], ngf*2, npx_, npx_))
    fmap_y = conv_cond_concat(fmap, label_map)

    # Hidden deconvolution; labels are then attached to its feature maps.
    up = relu(batchnorm(deconv(fmap_y, w3, subsample=stride, border_mode=pad)))
    up_y = conv_cond_concat(up, label_map)

    return sigmoid(deconv(up_y, wx, subsample=stride, border_mode=pad))
def discrim(X, Y, w, w2, w3, wy):
    """Build the conditional discriminator graph scoring images ``X``.

    The label matrix ``Y`` is attached at every layer: as extra feature
    maps (``conv_cond_concat``) before each convolution, and as extra
    columns (``T.concatenate(..., axis=1)``) before each dense layer.
    Layout: conv (no batchnorm) -> conv + batchnorm -> flatten -> dense
    + batchnorm -> dense. Hidden activations are leaky ReLU; the output is
    a sigmoid.

    Args:
        X: 4-D image tensor.
        Y: 2-D label matrix.
        w, w2: convolution filters.
        w3, wy: weight matrices of the hidden and output dense layers.

    Returns:
        Symbolic sigmoid of the final dense layer.
    """
    stride = (2, 2)
    pad = (2, 2)
    label_map = Y.dimshuffle(0, 1, 'x', 'x')

    # First conv layer is deliberately left without batchnorm.
    conv1_in = conv_cond_concat(X, label_map)
    conv1 = lrelu(dnn_conv(conv1_in, w, subsample=stride, border_mode=pad))

    conv2_in = conv_cond_concat(conv1, label_map)
    conv2 = lrelu(batchnorm(dnn_conv(conv2_in, w2, subsample=stride, border_mode=pad)))

    # Collapse feature maps to one row per sample, then append labels.
    flat_y = T.concatenate([T.flatten(conv2, 2), Y], axis=1)
    hidden_y = T.concatenate([lrelu(batchnorm(T.dot(flat_y, w3))), Y], axis=1)

    return sigmoid(T.dot(hidden_y, wy))
def gen(Z, Y, w, w2, w3, wx):
    """Build the conditional generator graph (variant sized by ``temp``).

    Identical layout to the ``npx_`` variant: dense -> dense -> reshape ->
    deconv -> deconv, with ``Y`` re-attached before every layer, batchnorm
    + ReLU on hidden layers and a sigmoid on the output.

    Args:
        Z: 2-D noise matrix.
        Y: 2-D label matrix.
        w, w2: dense-layer weight matrices.
        w3, wx: deconvolution filters (hidden, output).

    Returns:
        Symbolic sigmoid of the final deconvolution.

    Reads module-level ``ngf`` and ``temp``; ``temp`` is the spatial size
    of the feature maps produced by the reshape.
    NOTE(review): assumes ``w2`` yields ``ngf * 2 * temp * temp`` columns.
    """
    two_by_two = (2, 2)
    y4d = Y.dimshuffle(0, 1, 'x', 'x')

    a1 = relu(batchnorm(T.dot(T.concatenate([Z, Y], axis=1), w)))
    a2 = relu(batchnorm(T.dot(T.concatenate([a1, Y], axis=1), w2)))

    maps = conv_cond_concat(
        a2.reshape((a2.shape[0], ngf * 2, temp, temp)), y4d)
    maps = conv_cond_concat(
        relu(batchnorm(deconv(maps, w3, subsample=two_by_two, border_mode=two_by_two))),
        y4d)

    out = sigmoid(deconv(maps, wx, subsample=two_by_two, border_mode=two_by_two))
    return out
def gen(Z, Y):
    """Build the conditional generator graph using module-level parameters.

    Unlike the weight-argument variants, every parameter is read from
    module globals: dense weight ``gw``; deconvolution filters ``gw2``,
    ``gw3``, ``gw4``; batchnorm gain/bias pairs ``(gg, gb)``,
    ``(gg2, gb2)``, ``(gg3, gb3)``; and width ``ngf``. Labels ``Y`` are
    concatenated to the noise once, then attached as feature maps before
    each deconvolution. The dense output is reshaped to ``ngf*4`` maps of
    4x4, followed by three deconvolutions; the output uses ``tanh``.

    Args:
        Z: 2-D noise matrix.
        Y: 2-D label matrix.

    Returns:
        Symbolic tanh of the final deconvolution.
    """
    stride = (2, 2)
    pad = (2, 2)
    label_map = Y.dimshuffle(0, 1, 'x', 'x')

    noise_y = T.concatenate([Z, Y], axis=1)
    projected = relu(batchnorm(T.dot(noise_y, gw), g=gg, b=gb))
    base = conv_cond_concat(
        projected.reshape((projected.shape[0], ngf * 4, 4, 4)), label_map)

    up1 = relu(batchnorm(deconv(base, gw2, subsample=stride, border_mode=pad),
                         g=gg2, b=gb2))
    up1 = conv_cond_concat(up1, label_map)

    up2 = relu(batchnorm(deconv(up1, gw3, subsample=stride, border_mode=pad),
                         g=gg3, b=gb3))
    up2 = conv_cond_concat(up2, label_map)

    return tanh(deconv(up2, gw4, subsample=stride, border_mode=pad))
def gen(Z, Y, w, w2, w3, wx):
    """Debug variant of the conditional generator that traces intermediates.

    Same layout as the plain generator (dense -> dense -> reshape ->
    deconv -> deconv, labels ``Y`` re-attached before each layer), but
    calls ``printVal`` on the inputs and several intermediate expressions.

    Args:
        Z: 2-D noise matrix.
        Y: 2-D label matrix.
        w, w2: dense-layer weight matrices.
        w3, wx: deconvolution filters (hidden, output).

    Returns:
        ``(x, h2)``: ``x`` is the sigmoid of the final deconvolution;
        ``h2`` is the second dense layer's activation *before* it is
        reshaped into feature maps (returned for inspection).

    Reads module-level ``GEN_NUM_FILTER``; the reshape hard-codes 7x7 maps.
    NOTE(review): assumes ``w2`` has ``GEN_NUM_FILTER * 2 * 7 * 7``
    output columns -- confirm against the weight init.
    """
    print('\n@@@@ gen()')
    printVal('Z', Z)  # matrix
    # printVal('Y', Y)  # matrix -- disabled debug trace
    printVal('w', w)  # matrix
    printVal('w2', w2)  # matrix
    printVal('w3', w3)  # tensor
    printVal('wx', wx)  # tensor

    # Rearrange Y's axes: numeric arguments are dimension indices, 'x' adds a
    # broadcastable axis. The total number of elements is unchanged.
    yb = Y.dimshuffle(0, 1, 'x', 'x')  # yb is a 4-D tensor
    # printVal('yb', yb)  -- disabled debug trace

    # Concatenate matrices Z and Y horizontally (along columns).
    Z = T.concatenate([Z, Y], axis=1)  # matrix

    # Fully connected Z*w, then batchnorm and ReLU. Reuse tmp_a instead of
    # building T.dot(Z, w) a second time, so the traced expression is the
    # same graph node that feeds batchnorm.
    tmp_a = T.dot(Z, w)  # dot(matrix, matrix) -> matrix
    printVal('dot(Z,w) -> tmp_a', tmp_a)
    h = relu(batchnorm(tmp_a))
    h = T.concatenate([h, Y], axis=1)
    printVal('h', h)  # matrix

    h2 = relu(batchnorm(T.dot(h, w2)))
    printVal('h2', h2)  # matrix
    h2r = h2.reshape((h2.shape[0], GEN_NUM_FILTER * 2, 7, 7))
    printVal('h2r', h2r)  # tensor
    h2ry = conv_cond_concat(h2r, yb)
    # printVal('h2ry', h2ry)  # tensor -- disabled debug trace

    # Deconvolution: according to the (DCGAN) paper, applied in place of
    # spatial pooling.
    d = deconv(h2ry, w3, subsample=(2, 2), border_mode=(2, 2))
    printVal('d', d)
    h3 = relu(batchnorm(d))
    h3 = conv_cond_concat(h3, yb)
    x = sigmoid(deconv(h3, wx, subsample=(2, 2), border_mode=(2, 2)))
    return x, h2
def gen(Z, Y, w, w2, w3, wx):
    """Debug variant of the conditional generator, annotated step by step.

    Shapes as annotated by the original author (nbatch=128, nz=100, ny=10,
    ngf=64, ngfc=1024, nc=1):
        Z:  (nbatch, nz)            = (128, 100)
        Y:  (nbatch, ny)            = (128, 10)
        w:  (nz+ny, ngfc)           = (110, 1024)
        w2: (ngfc+ny, ngf*2*7*7)    = (1034, 6272)
        w3: (ngf*2+ny, ngf, 5, 5)   = (138, 64, 5, 5)
        wx: (ngf+ny, nc, 5, 5)      = (74, 1, 5, 5)

    Steps (G1)-(G11) below mirror the author's numbering. ``printVal`` is
    called on the inputs and a few intermediates for debugging.

    Returns:
        Symbolic sigmoid of the final deconvolution.

    Reads module-level ``ngf``; the reshape hard-codes 7x7 maps.
    """
    print('\n@@@@ gen()')
    for name, val in (('Y', Y), ('w', w), ('w2', w2), ('w3', w3), ('wx', wx)):
        printVal(name, val)

    stride = (2, 2)
    pad = (2, 2)

    # (G1) Rearrange Y's axes: numeric arguments are dimension indices,
    # 'x' adds a broadcastable axis. Result is a 4-D tensor.
    label_map = Y.dimshuffle(0, 1, 'x', 'x')
    printVal('yb', label_map)

    # (G2) Concatenate Z and Y horizontally -- the conditional-GAN input.
    joint = T.concatenate([Z, Y], axis=1)  # (128, 110)

    # (G3) Fully connected layer, then batchnorm and ReLU.
    proj = T.dot(joint, w)  # (128, 1024)
    printVal('t1', proj)
    hidden = relu(batchnorm(proj))  # (128, 1024)

    # (G4) Re-append the labels.
    hidden = T.concatenate([hidden, Y], axis=1)  # (128, 1034)

    # (G5) Second fully connected layer, batchnorm, ReLU.
    dense2 = relu(batchnorm(T.dot(hidden, w2)))

    # (G6) Reshape into ngf*2 feature maps of 7x7.
    fmap = dense2.reshape((dense2.shape[0], ngf*2, 7, 7))

    # (G7) Attach labels to the feature maps.
    fmap_y = conv_cond_concat(fmap, label_map)
    printVal('h2', fmap)

    # (G8) Deconvolution: according to the (DCGAN) paper, applied in place
    # of spatial pooling.
    up = deconv(fmap_y, w3, subsample=stride, border_mode=pad)

    # (G9) Batchnorm + ReLU.
    up = relu(batchnorm(up))

    # (G10) Attach labels again.
    up_y = conv_cond_concat(up, label_map)

    # (G11) Output deconvolution with sigmoid.
    return sigmoid(deconv(up_y, wx, subsample=stride, border_mode=pad))