def __call__(self, x, reuse=False):

    with tf.variable_scope(self.name) as scope:

        if reuse:
            scope.reuse_variables()

        #################
        # Generator
        #################

        # Initial dense multiplication
        x = layers.linear(x, 512 * 8 * 8)

        # Reshape to image format
        if FLAGS.data_format == "NCHW":
            target_shape = (-1, 512, 8, 8)
        else:
            target_shape = (-1, 8, 8, 512)
        x = layers.reshape(x, target_shape)
        x = tf.contrib.layers.batch_norm(x, fused=True)
        x = tf.nn.elu(x)

        # Conv2D + Phase shift blocks
        x = layers.conv2d_block(x, "G16_conv2D_1", 256, 3, 1, data_format=FLAGS.data_format)
        x = layers.conv2d_block(x, "G16_conv2D_2", 256, 3, 1, data_format=FLAGS.data_format)
        x = layers.phase_shift(x, upsampling_factor=2, name="PS_G16", data_format=FLAGS.data_format)
        x = layers.conv2d_block(x, "G16_conv2D_3", FLAGS.channels, 3, 1, bn=False,
                                activation_fn=None, data_format=FLAGS.data_format)
        x = tf.nn.tanh(x, name="x_G16")

        return x
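# ---------------------------------------------------------------------------
# `layers.phase_shift` is not shown in this file. Assuming it implements the
# sub-pixel (pixel-shuffle) upsampling of Shi et al. (2016), a minimal sketch
# in terms of tf.depth_to_space could look like the following. The signature
# mirrors the calls above, but this is an illustrative assumption, not the
# repo's actual implementation.
# ---------------------------------------------------------------------------
import tensorflow as tf

def phase_shift(x, upsampling_factor=2, name="phase_shift", data_format="NHWC"):
    # Rearranges (H, W, C * r^2) feature maps into (H * r, W * r, C):
    # each group of r^2 channels fills one r x r spatial block.
    with tf.name_scope(name):
        if data_format == "NCHW":
            x = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
        x = tf.depth_to_space(x, upsampling_factor)
        if data_format == "NCHW":
            x = tf.transpose(x, [0, 3, 1, 2])  # NHWC -> NCHW
    return x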
def __call__(self, x, reuse=False):

    with tf.variable_scope(self.name) as scope:

        if reuse:
            scope.reuse_variables()

        # Strided conv stack: each block halves the spatial resolution
        x = layers.conv2d_block(x, "D64_conv2D_1", 32, 3, 2, data_format=FLAGS.data_format, bn=False)
        x = layers.conv2d_block(x, "D64_conv2D_2", 64, 3, 2, data_format=FLAGS.data_format)
        x = layers.conv2d_block(x, "D64_conv2D_3", 128, 3, 2, data_format=FLAGS.data_format)
        x = layers.conv2d_block(x, "D64_conv2D_4", 256, 3, 2, data_format=FLAGS.data_format)

        # Flatten all non-batch dimensions
        x_shape = x.get_shape().as_list()
        flat_dim = 1
        for d in x_shape[1:]:
            flat_dim *= d
        x = layers.reshape(x, (-1, flat_dim))

        # Minibatch discrimination statistics, concatenated before the final logit
        x_mbd = layers.mini_batch_disc(x, num_kernels=100, dim_per_kernel=5, name="mbd64")
        x = tf.concat([x, x_mbd], axis=1)

        x = layers.linear(x, 1)

        return x
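# ---------------------------------------------------------------------------
# `layers.mini_batch_disc` presumably implements the minibatch discrimination
# of Salimans et al. (2016): project each flattened sample through a learned
# tensor, compute L1 distances to every other sample in the batch, and append
# the resulting closeness statistics so the discriminator can penalize mode
# collapse. A rough sketch under that assumption:
# ---------------------------------------------------------------------------
import tensorflow as tf

def mini_batch_disc(x, num_kernels=100, dim_per_kernel=5, name="mbd"):
    # x: (N, A) flattened features -> (N, num_kernels) batch statistics
    with tf.variable_scope(name):
        A = x.get_shape().as_list()[1]
        T = tf.get_variable("T", [A, num_kernels * dim_per_kernel],
                            initializer=tf.random_normal_initializer(stddev=0.02))
        M = tf.reshape(tf.matmul(x, T), [-1, num_kernels, dim_per_kernel])
        # Pairwise L1 distances between samples, per kernel: (N, K, N)
        diff = tf.expand_dims(M, 3) - tf.expand_dims(tf.transpose(M, [1, 2, 0]), 0)
        l1 = tf.reduce_sum(tf.abs(diff), axis=2)
        # Negative-exponential similarity, summed over the batch
        return tf.reduce_sum(tf.exp(-l1), axis=2)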
def __call__(self, x, x_feat, reuse=False):

    with tf.variable_scope(self.name) as scope:

        if reuse:
            scope.reuse_variables()

        #################
        # Generator
        #################

        # Upsample x_feat to the spatial resolution of x before concatenating.
        # Integer division: the upsampling factor must be an int in Python 3.
        up = x.get_shape().as_list()[2] // x_feat.get_shape().as_list()[2]
        x_feat = layers.phase_shift(x_feat, upsampling_factor=up, name="PS_G32_feat",
                                    data_format=FLAGS.data_format)

        # Concatenate along the channel axis
        if FLAGS.data_format == "NCHW":
            x = tf.concat([x, x_feat], axis=1)
        else:
            x = tf.concat([x, x_feat], axis=-1)

        x = layers.conv2d_block(x, "G32_conv2D_1", 256, 3, 1, data_format=FLAGS.data_format)
        x = layers.conv2d_block(x, "G32_conv2D_2", 256, 3, 1, data_format=FLAGS.data_format)
        x = layers.phase_shift(x, upsampling_factor=2, name="PS_G32", data_format=FLAGS.data_format)
        x = layers.conv2d_block(x, "G32_conv2D_3", FLAGS.channels, 3, 1, bn=False,
                                activation_fn=None, data_format=FLAGS.data_format)
        x = tf.nn.tanh(x, name="x_G32")

        return x
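# ---------------------------------------------------------------------------
# Hypothetical wiring of the progressive 16x16 -> 32x32 cascade: G16 produces
# a coarse image, D16 in feature mode yields an intermediate feature map, and
# G32 upsamples and concatenates both. All names below (G16, D16, G32,
# batch_size, z_dim) are illustrative, not taken from this file.
# ---------------------------------------------------------------------------
z = tf.random_normal([batch_size, z_dim])
x16 = G16(z)                              # 16x16 output of the first generator
x_feat, d16_logits = D16(x16, mode="G")   # intermediate D16 feature map
x32 = G32(x16, x_feat)                    # 32x32 output of the second generator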
def __call__(self, x, reuse=False):

    with tf.variable_scope(self.name) as scope:

        if reuse:
            scope.reuse_variables()

        # Fixed (non-learned) scattering transform as the first feature extractor
        M, N = x.get_shape().as_list()[-2:]
        x = scattering.Scattering(M=M, N=N, J=2)(x)
        x = tf.contrib.layers.batch_norm(x, data_format=FLAGS.data_format, fused=True, scope="scat_bn")

        # 1x1 conv to mix the scattering channels down to 64
        x = layers.conv2d_block("CONV2D", x, 64, 1, 1, p="SAME", data_format=FLAGS.data_format,
                                bias=True, bn=False, activation_fn=tf.nn.relu)

        # J=2 downsamples 28x28 inputs to 7x7, hence the 64 * 7 * 7 flatten
        x = layers.reshape(x, (-1, 64 * 7 * 7))

        x = layers.linear(x, 512, name="dense1")
        x = tf.nn.relu(x)
        x = layers.linear(x, 10, name="dense2")

        return x
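# ---------------------------------------------------------------------------
# Why the flatten above is 64 * 7 * 7: a 2D scattering transform with scale J
# downsamples by 2^J, and with L orientations (commonly 8 by default) it emits
# 1 + J*L + L^2 * J*(J-1)/2 coefficients per input channel. For J=2, L=8 on
# 28x28 MNIST that is 81 channels at 7x7, which the 1x1 conv maps to 64.
# A quick arithmetic check:
# ---------------------------------------------------------------------------
J, L = 2, 8
n_paths = 1 + J * L + (L ** 2) * J * (J - 1) // 2  # 81 scattering paths
spatial = 28 // 2 ** J                             # 7
print(n_paths, spatial)                            # 81 7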
def __call__(self, x, reuse=False):

    with tf.variable_scope(self.name) as scope:

        if reuse:
            scope.reuse_variables()

        for idx, (f, k, s, p) in enumerate(zip(self.list_filters, self.list_kernel_size,
                                               self.list_strides, self.list_padding)):
            # Skip batch norm on the first block only
            bn = idx != 0
            name = "conv2D_%s" % idx
            x = layers.conv2d_block(name, x, f, k, s, p=p, stddev=0.02, data_format=self.data_format,
                                    bias=True, bn=bn, activation_fn=layers.lrelu)

        x = layers.reshape(x, (self.batch_size, -1))

        # # Add MBD
        # x_mbd = layers.mini_batch_disc(x, num_kernels=100, dim_per_kernel=5)
        # # Concat
        # x = tf.concat([x, x_mbd], axis=1)

        x = layers.linear(x, 1)

        return x
def __call__(self, x, reuse=False, mode="D"):

    with tf.variable_scope(self.name) as scope:

        if reuse:
            scope.reuse_variables()

        x = layers.conv2d_block(x, "D16_conv2D_1", 32, 3, 2, data_format=FLAGS.data_format, bn=False)
        x = layers.conv2d_block(x, "D16_conv2D_2", 16, 3, 2, data_format=FLAGS.data_format)

        # Intermediate feature map, returned in feature-matching mode
        x_feat = tf.identity(x, "x_feat16")

        # Flatten all non-batch dimensions
        x_shape = x.get_shape().as_list()
        flat_dim = 1
        for d in x_shape[1:]:
            flat_dim *= d
        x = layers.reshape(x, (-1, flat_dim))

        # Minibatch discrimination first, then the final projection, so the
        # discriminator outputs a single logit (matching the D64 version above)
        x_mbd = layers.mini_batch_disc(x, num_kernels=100, dim_per_kernel=5, name="mbd16")
        x = tf.concat([x, x_mbd], axis=1)
        x = layers.linear(x, 1)

        if mode == "D":
            return x
        else:
            return x_feat, x
def __init__(self, ob_shape, ac_shape, reuse=False, **kwargs):
    self.sess = tf.get_default_session()

    # Assumes the caller passes the batch geometry as nbatch/nenvs keyword
    # arguments; indexing by name is safer than relying on kwargs ordering.
    nbatch, nenvs = kwargs["nbatch"], kwargs["nenvs"]

    obs_ph = tf.placeholder(tf.uint8, [nbatch, *ob_shape], 'obs_ph')

    with tf.variable_scope('model', reuse=reuse):
        # Observations are uint8 images: cast to float and rescale to [0, 1]
        x = tf.cast(obs_ph, tf.float32) / 255.

        with tf.variable_scope('cnn'):
            h = layers.conv2d_block(x)
        with tf.variable_scope('actor'):
            logits = layers.fc(h, ac_shape[-1], 'logits', activate=False, gain=0.01)
        with tf.variable_scope('critic'):
            vf = layers.fc(h, 1, 'vf', activate=False, gain=1.0)[:, 0]

    def sample():
        # Gumbel-max trick: equivalent to sampling from softmax(logits)
        u = tf.random_uniform(tf.shape(logits))
        return tf.argmax(logits - tf.log(-tf.log(u)), axis=-1)

    def entropy():
        a0 = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)

    def kl(other):
        # keepdims throughout (keep_dims is the deprecated spelling)
        a0 = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
        a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        z1 = tf.reduce_sum(ea1, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)

    def neglogp(x):
        one_hot_actions = tf.one_hot(x, logits.get_shape().as_list()[-1])
        return tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits, labels=one_hot_actions)

    self.a0 = sample()
    self.neglogp0 = neglogp(self.a0)
    self.vf0 = vf
    self.entropy = entropy
    self.kl = kl
    self.neglogp = neglogp
    self.obs_ph = obs_ph
    self.initial_state = None
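# ---------------------------------------------------------------------------
# sample() above relies on the Gumbel-max trick: adding independent Gumbel
# noise -log(-log(u)) to the logits and taking the argmax draws exactly from
# the categorical distribution softmax(logits), with no explicit
# normalization. A quick numpy check of the equivalence:
# ---------------------------------------------------------------------------
import numpy as np

rng = np.random.RandomState(0)
logits = np.array([1.0, 0.0, -1.0])
probs = np.exp(logits) / np.exp(logits).sum()

u = rng.uniform(size=(100000, 3))
samples = np.argmax(logits - np.log(-np.log(u)), axis=-1)
empirical = np.bincount(samples, minlength=3) / samples.size
print(probs)      # [0.665 0.245 0.090]
print(empirical)  # should closely match probs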
def __call__(self, x, reuse=False):

    with tf.variable_scope(self.name) as scope:

        if reuse:
            scope.reuse_variables()

        x = tf.contrib.layers.batch_norm(x, data_format=self.data_format, fused=True, scope="scat_bn")

        x = layers.conv2d_block("D_conv2D1", x, 32, 3, 2, p="SAME", data_format=self.data_format,
                                bias=False, bn=False, activation_fn=layers.lrelu)
        x = layers.conv2d_block("D_conv2D2", x, 64, 3, 2, p="SAME", data_format=self.data_format,
                                bias=False, bn=True, activation_fn=layers.lrelu)
        # x = layers.conv2d_block("D_conv2D3", x, 128, 3, 2, p="SAME", data_format=self.data_format,
        #                         bias=True, bn=True, activation_fn=layers.lrelu)

        # Flatten the channel and spatial dimensions
        x_shape = x.get_shape().as_list()
        x = layers.reshape(x, (-1, x_shape[-1] * x_shape[-2] * x_shape[-3]))

        x = layers.linear(x, 1, name='dense2')

        return x
def __init__(self, ob_shape, ac_shape, reuse=False, nlstm=256, **kwargs):
    self.sess = tf.get_default_session()

    # Assumes the caller passes the batch geometry as nbatch/nenvs keyword
    # arguments; indexing by name is safer than relying on kwargs ordering.
    nbatch, nenvs = kwargs["nbatch"], kwargs["nenvs"]

    obs_ph = tf.placeholder(tf.uint8, [nbatch, *ob_shape], 'obs')  # observations
    mask_ph = tf.placeholder(tf.float32, [nbatch], 'masks')        # mask (done at t-1)

    # The LSTM layer is shared between the policy and the value function
    state_ph = tf.placeholder(tf.float32, [nenvs, nlstm * 2], 'states')  # LSTM cell/hidden states

    with tf.variable_scope('model', reuse=reuse):
        # Observations are uint8 images: cast to float and rescale to [0, 1]
        x = tf.cast(obs_ph, tf.float32) / 255.

        # 3-layer conv2d + 2-layer fc
        h = layers.conv2d_block(x)

        # Split h and the mask into per-timestep sequences
        xs = layers.batch_to_seq(h, nenvs, nbatch // nenvs)
        ms = layers.batch_to_seq(mask_ph, 1, nbatch)

        # Long short-term memory layer
        h6, snew = layers.lstm(xs, ms, state_ph, 'lstm', nh=nlstm, use_ln=False, reuse=reuse)

        # Reshape the output back into a batch tensor
        h6 = layers.seq_to_batch(h6)

        with tf.variable_scope('actor'):
            logits = layers.fc(h6, ac_shape[-1], 'logits', activate=False, gain=0.01)
        with tf.variable_scope('critic'):
            vf = layers.fc(h6, 1, 'vf', activate=False, gain=1.0)[:, 0]

    self.snew = snew
    # Must match state_ph: one zero state vector per environment, not per batch row
    self.initial_state = np.zeros((nenvs, nlstm * 2), dtype=np.float32)

    def sample():
        # Gumbel-max trick: equivalent to sampling from softmax(logits)
        u = tf.random_uniform(tf.shape(logits))
        return tf.argmax(logits - tf.log(-tf.log(u)), axis=-1)

    def entropy():
        a0 = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)

    def kl(other):
        # keepdims throughout (keep_dims is the deprecated spelling)
        a0 = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
        a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        z1 = tf.reduce_sum(ea1, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)

    def neglogp(x):
        one_hot_actions = tf.one_hot(x, logits.get_shape().as_list()[-1])
        return tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits, labels=one_hot_actions)

    self.a0 = sample()
    self.neglogp0 = neglogp(self.a0)
    self.vf0 = vf
    self.entropy = entropy
    self.kl = kl
    self.neglogp = neglogp
    self.obs_ph = obs_ph
    self.mask_ph = mask_ph
    self.state_ph = state_ph
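# ---------------------------------------------------------------------------
# `layers.batch_to_seq` / `layers.seq_to_batch` are not shown here. Assuming
# they follow the OpenAI baselines convention (a flat (nenvs * nsteps, d)
# batch laid out env-major is split into a list of nsteps tensors of shape
# (nenvs, d) for the LSTM loop, then joined back), minimal sketches are:
# ---------------------------------------------------------------------------
import tensorflow as tf

def batch_to_seq(h, nenvs, nsteps):
    # (nenvs * nsteps, d) -> list of nsteps tensors of shape (nenvs, d)
    h = tf.reshape(h, [nenvs, nsteps, -1])
    return [tf.squeeze(v, [1]) for v in tf.split(h, nsteps, axis=1)]

def seq_to_batch(h):
    # list of nsteps tensors of shape (nenvs, d) -> (nenvs * nsteps, d)
    nh = h[0].get_shape().as_list()[-1]
    return tf.reshape(tf.concat(h, axis=1), [-1, nh])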