def forward(self, inp, weights, attention_mask, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
    if FLAGS.antialias:
        # Antialias the image to smooth local perturbations
        # (stride_3 is the module-level blur kernel).
        antialias = tf.tile(stride_3, (1, 1, tf.shape(inp)[3], tf.shape(inp)[3]))
        inp = tf.nn.conv2d(inp, antialias, [1, 2, 2, 1], padding='SAME')

    channels = self.channels

    if FLAGS.augment_vis:
        # Apply the standard data-augmentation transforms to the input.
        for transform in standard_transforms:
            inp = transform(inp)

    batch_size = tf.shape(inp)[0]

    if FLAGS.comb_mask:
        # Split the image into FLAGS.cond_func soft regions via the
        # attention mask; each region is scored as a separate example.
        attention_mask = tf.nn.softmax(attention_mask)
        inp = tf.reshape(
            tf.transpose(inp, (0, 3, 1, 2)),
            (tf.shape(inp)[0], channels, self.img_size, self.img_size, 1))
        attention_mask = tf.reshape(
            attention_mask,
            (tf.shape(attention_mask)[0], 1, self.img_size, self.img_size, FLAGS.cond_func))
        inp = tf.reshape(
            tf.transpose(inp * attention_mask, (0, 4, 1, 2, 3)),
            (tf.shape(inp)[0] * FLAGS.cond_func, 64, 64, channels))  # NOTE: assumes self.img_size == 64

    weights = weights.copy()

    if not FLAGS.cclass:
        label = None

    if stop_grad:
        # Detach all weights (including nested dicts) from the graph.
        for k, v in weights.items():
            if isinstance(v, dict):
                v = v.copy()
                weights[k] = v
                for k_sub, v_sub in v.items():
                    v[k_sub] = tf.stop_gradient(v_sub)
            else:
                weights[k] = tf.stop_gradient(v)

    if FLAGS.swish_act:
        act = swish
    else:
        act = tf.nn.leaky_relu

    # Convolutional stem followed by five residual blocks.
    inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
    hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, act=act)
    hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, adaptive=True, label=label, act=act)
    hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, adaptive=False, label=label, act=act)
    hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, act=act)
    hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', stop_batch=stop_batch, adaptive=True, label=label, act=act)
    hidden5 = act(hidden5)

    # Global average pool, then a linear head producing a scalar energy.
    hidden6 = tf.reduce_mean(hidden5, axis=[1, 2])
    energy = smart_fc_block(hidden6, weights, reuse, 'fc5')

    if FLAGS.comb_mask:
        # Sum the per-region energies back into one energy per image.
        energy = tf.reduce_sum(tf.reshape(energy, (batch_size, FLAGS.cond_func)), axis=1, keepdims=True)

    return energy
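# Hypothetical usage of the energy function above: score a batch of images
# and take the gradient of the energy with respect to the inputs, which is
# the quantity gradient-based MCMC sampling needs. The `model` / `weights`
# names and the placeholder shape are assumptions for illustration, not
# part of this file.
#
#     x = tf.placeholder(tf.float32, (None, 64, 64, 3))
#     energy = model.forward(x, weights, attention_mask=None, reuse=False)
#     energy_grad = tf.gradients(tf.reduce_sum(energy), [x])[0]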
def forward(self, inp, weights, attention_mask=None, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
    if FLAGS.swish_act:
        act = swish
    else:
        act = tf.nn.leaky_relu

    # Project the latent input to a 4x4 feature map, then upsample
    # through four residual blocks to the output resolution.
    hidden1 = act(smart_fc_block(inp, weights, reuse, 'fc_dense'))
    hidden1 = tf.reshape(hidden1, (tf.shape(inp)[0], 4, 4, 4 * self.dim_hidden))
    hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', label=label, act=act, upsample=True)
    hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', adaptive=False, label=label, act=act, upsample=True)
    hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', label=label, act=act, upsample=True)
    hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', label=label, adaptive=False, act=act, upsample=True)

    # Final convolution maps features back to image channels (no activation).
    output = smart_conv_block(hidden5, weights, reuse, 'c4_out', use_stride=False, activation=None)
    return output
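# Hypothetical usage of the generator forward pass: map a batch of latent
# codes through the dense stem and the 4x upsampling stack (4x4 -> 64x64).
# The latent size (128) and the `generator` / `gen_weights` names are
# assumptions for illustration.
#
#     z = tf.random_normal((16, 128))
#     samples = generator.forward(z, gen_weights, reuse=False)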
def forward(self, inp, weights, attention_mask=None, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False, latent=None):
    weights = weights.copy()
    batch = tf.shape(inp)[0]

    if not FLAGS.cclass:
        label = None

    if stop_grad:
        # Detach all weights (including nested dicts) from the graph.
        for k, v in weights.items():
            if isinstance(v, dict):
                v = v.copy()
                weights[k] = v
                for k_sub, v_sub in v.items():
                    v[k_sub] = tf.stop_gradient(v_sub)
            else:
                weights[k] = tf.stop_gradient(v)

    if FLAGS.swish_act:
        act = swish
    else:
        act = tf.nn.leaky_relu

    dropout = self.dropout
    train = self.train

    # Make sure gradients are modified a bit
    inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
    hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', label=label, dropout=dropout, train=train, downsample=True, adaptive=False)

    if FLAGS.use_attention:
        hidden1 = smart_atten_block(hidden1, weights, reuse, 'atten', stop_at_grad=stop_at_grad)

    hidden2 = smart_res_block(hidden1, weights, reuse, 'res_3', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act)
    hidden3 = smart_res_block(hidden2, weights, reuse, 'res_5', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act)
    hidden4 = smart_res_block(hidden3, weights, reuse, 'res_7', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=True)
    hidden5 = smart_res_block(hidden4, weights, reuse, 'res_9', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=False)
    hidden6 = smart_res_block(hidden5, weights, reuse, 'res_10', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=False, adaptive=False)

    if FLAGS.swish_act:
        hidden6 = act(hidden6)
    else:
        hidden6 = tf.nn.relu(hidden6)

    # Global sum pool, then a linear head producing a scalar energy.
    pooled = tf.reduce_sum(hidden6, [1, 2])
    energy = smart_fc_block(pooled, weights, reuse, 'fc5')
    return energy
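# Minimal sketch of Langevin-dynamics sampling against the energy model
# above: repeatedly perturb the sample with Gaussian noise and step down
# the energy gradient. This is the standard sampling procedure for such
# energy-based models; the `model` / `weights` names and all
# hyperparameter values are assumptions for illustration.
def langevin_sample(model, weights, x_init, num_steps=60, step_lr=10.0, noise_scale=0.005):
    x = x_init
    for _ in range(num_steps):
        # Inject noise, then descend the energy surface.
        x = x + tf.random_normal(tf.shape(x), stddev=noise_scale)
        energy = model.forward(x, weights, reuse=True)
        x = x - step_lr * tf.gradients(energy, [x])[0]
        # Keep samples inside the valid image range.
        x = tf.clip_by_value(x, 0.0, 1.0)
    return x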