def filterbank(img_h, img_w, att_dim, center_x, center_y, delta, var):
    """
    Get the Gaussian filterbank matrices of the attention window.

    One 1-D Gaussian filter is built per attention grid row/column; each filter
    is normalized to sum to one over the pixel axis.

    :param img_h: image height
    :param img_w: image width
    :param att_dim: attention dim, the attention window is att_dim x att_dim
    :param center_x: attention center (x-axis): batch_size x 1
    :param center_y: attention center (y-axis): batch_size x 1
    :param delta: stride: batch_size x 1
    :param var: variance (sigma^2): batch_size x 1
    :return: filter_x, filter_y (with the shapes documented above these are
        batch_size x att_dim x img_w and batch_size x att_dim x img_h)
    """
    with tf.name_scope('filterbank'):
        # Grid offsets (i - N/2 - 0.5) for i = 1..N: they are symmetric around
        # zero, so the window is centered on (center_x, center_y). Using +0.5
        # here would shift the whole window by one stride.
        offsets = O.range(1, att_dim + 1, dtype='float32') - att_dim / 2. - 0.5
        mu_x = center_x + offsets * delta
        mu_y = center_y + offsets * delta

        # Pixel coordinates are 1-based: 1..img_w and 1..img_h.
        pixels_x = O.range(1, img_w + 1, dtype='float32')
        pixels_y = O.range(1, img_h + 1, dtype='float32')

        # Signed distance from every pixel to every filter mean.
        dist_x = pixels_x - mu_x.add_axis(-1)
        dist_y = pixels_y - mu_y.add_axis(-1)

        fx = O.exp(-O.sqr(dist_x) / var.add_axis(-1) / 2.)
        fy = O.exp(-O.sqr(dist_y) / var.add_axis(-1) / 2.)

        # Normalize over the pixel axis; the epsilon guards against a filter
        # whose mass has fallen entirely outside the image.
        fx /= fx.sum(axis=2, keepdims=True) + 1e-8
        fy /= fy.sum(axis=2, keepdims=True) + 1e-8

        # Return both filters as the same node type.
        fx = O.as_varnode(fx)
        fy = O.as_varnode(fy)
        return fx, fy
def split_att_params(img_h, img_w, att_dim, value):
    """
    Split the raw attention params into their five named components.

    The last three components are predicted in log-space and exponentiated here;
    the centers are mapped onto image coordinates and the stride is rescaled by
    (max(img_h, img_w) - 1) / (att_dim - 1).

    :param img_h: image height
    :param img_w: image width
    :param att_dim: attention dim, the attention window is att_dim x att_dim;
        must be greater than 1 because the stride scale divides by (att_dim - 1)
    :param value: attention params produced by hidden layer: batch_size, 5
    :return: center_x, center_y, delta, variance, gamma
    """
    with tf.name_scope('split_att_params'):
        center_x, center_y, log_delta, log_var, log_gamma = O.split(value, 5, axis=1)
        delta, var, gamma = [
            O.exp(log_param).reshape(-1, 1)
            for log_param in (log_delta, log_var, log_gamma)
        ]

        center_x = ((img_w + 1.) * (center_x + 1.) / 2.).reshape(-1, 1)
        center_y = ((img_h + 1.) * (center_y + 1.) / 2.).reshape(-1, 1)

        # Convert to float BEFORE dividing so the ratio is never truncated by
        # integer division.
        delta *= float(max(img_h, img_w) - 1) / (att_dim - 1)
        return center_x, center_y, delta, var, gamma
def forward(x):
    """
    Build the variational auto-encoder graph and register its outputs.

    When reconstructing or training, the latent code is sampled from the
    encoder's Gaussian posterior with the reparameterization
    ``z = mu + std * epsilon``; otherwise a single code is drawn from the
    standard normal prior. The decoder maps the code to an image of shape
    (-1, h, w, c).

    Registered outputs: 'output' (always) and 'loss' (training phase only).

    :param x: input batch; only read when reconstructing or training
        (presumably batch_size x h x w x c -- confirm against the input maker)
    """
    if is_reconstruct or env.phase is env.Phase.TRAIN:
        with env.variable_scope('encoder'):
            _ = x
            _ = O.fc('fc1', _, 500, nonlin=O.tanh)
            _ = O.fc('fc2', _, 500, nonlin=O.tanh)
            mu = O.fc('fc3_mu', _, code_length)
            log_var = O.fc('fc3_sigma', _, code_length)
            var = O.exp(log_var)
            std = O.sqrt(var)
            epsilon = O.random_normal([x.shape[0], code_length])
            z_given_x = mu + std * epsilon
    else:
        z_given_x = O.random_normal([1, code_length])

    with env.variable_scope('decoder'):
        _ = z_given_x
        _ = O.fc('fc1', _, 500, nonlin=O.tanh)
        _ = O.fc('fc2', _, 500, nonlin=O.tanh)
        # Size the last layer from the image shape so the reshape below is
        # always consistent (h * w * c == 784 for 28 x 28 x 1 inputs).
        _ = O.fc('fc3', _, h * w * c, nonlin=O.sigmoid)
        _ = _.reshape(-1, h, w, c)
        x_given_z = _

    if env.phase is env.Phase.TRAIN:
        with env.variable_scope('loss'):
            # Reconstruction term: per-sample sum, then batch mean.
            content_loss = O.raw_cross_entropy_prob(
                'raw_content', x_given_z.flatten2(), x.flatten2())
            content_loss = content_loss.sum(axis=1).mean(name='content')

            # Closed-form KL(N(mu, var) || N(0, I)), written in terms of log_var.
            distrib_loss = -0.5 * (1. + log_var - O.sqr(mu) - var).sum(axis=1)
            distrib_loss = distrib_loss.mean(name='distrib')

            loss = content_loss + distrib_loss
        dpc.add_output(loss, name='loss', reduce_method='sum')

    dpc.add_output(x_given_z, name='output')
def make_rpredictor_network(env):
    """
    Build the reward-predictor network trained from pairwise preferences.

    A shared convolutional trunk (run through the data-parallel controller)
    produces features for three inputs: 'state', 't1_state' and 't2_state'.
    A shared fully-connected head maps (feature, one-hot action) to a scalar
    reward, exposed as the network output 'reward'.

    In the training phase, the rewards predicted for the two compared
    trajectories are summed, and the loss is the cross entropy between the
    'pref' placeholder and the probability that trajectory 1 is preferred:
    ``P[t1 > t2] = exp(r1) / (exp(r1) + exp(r2))``. ``pref == 0`` puts all the
    weight on trajectory 1, ``pref == 1`` on trajectory 2.

    :param env: the environment used to create the network, its data-parallel
        controller, and to query the current phase.
    """
    is_train = env.phase is env.Phase.TRAIN

    with env.create_network() as net:
        h, w, c = get_input_shape()
        # Hack(MJY):: forced RGB input (instead of combination of history frames)
        c = 3

        dpc = env.create_dpcontroller()
        with dpc.activate():
            def inputs():
                state = O.placeholder('state', shape=(None, h, w, c))
                t1_state = O.placeholder('t1_state', shape=(None, h, w, c))
                t2_state = O.placeholder('t2_state', shape=(None, h, w, c))
                return [state, t1_state, t2_state]

            @O.auto_reuse
            def forward_conv(x):
                # Scale pixel values from [0, 255] down to [0, 1].
                _ = x / 255.0
                with O.argscope(O.conv2d, nonlin=O.relu):
                    _ = O.conv2d('conv0', _, 32, 5)
                    _ = O.max_pooling2d('pool0', _, 2)
                    _ = O.conv2d('conv1', _, 32, 5)
                    _ = O.max_pooling2d('pool1', _, 2)
                    _ = O.conv2d('conv2', _, 64, 4)
                    _ = O.max_pooling2d('pool2', _, 2)
                    _ = O.conv2d('conv3', _, 64, 3)
                return _

            def forward(x, t1, t2):
                dpc.add_output(forward_conv(x), name='feature')
                dpc.add_output(forward_conv(t1), name='t1_feature')
                dpc.add_output(forward_conv(t2), name='t2_feature')

            dpc.set_input_maker(inputs).set_forward_func(forward)

        @O.auto_reuse
        def forward_fc(feature, action):
            action = O.one_hot(action, get_player_nr_actions())
            _ = O.concat([feature.flatten2(), action], axis=1)
            _ = O.fc('fc0', _, 512, nonlin=O.p_relu)
            reward = O.fc('fc_reward', _, 1)
            return reward

        action = O.placeholder('action', shape=(None, ), dtype='int64')
        net.add_output(forward_fc(dpc.outputs['feature'], action), name='reward')

        if is_train:
            t1_action = O.placeholder('t1_action', shape=(None, ), dtype='int64')
            t1_reward_sum = forward_fc(dpc.outputs['t1_feature'], t1_action).sum()
            t2_action = O.placeholder('t2_action', shape=(None, ), dtype='int64')
            t2_reward_sum = forward_fc(dpc.outputs['t2_feature'], t2_action).sum()

            pref = O.placeholder('pref')
            pref = O.callback_injector(pref)
            p1, p2 = 1 - pref, pref

            # P[t1 > t2] = exp(r1) / (exp(r1) + exp(r2)) = 1 / (1 + exp(r2 - r1)).
            # Working with the reward difference avoids exponentiating the raw
            # sums (inf / inf for large sums) and avoids log(1 - p), which
            # evaluates log(0) as soon as p rounds to 1.
            log_p_t1 = -O.log(O.exp(t2_reward_sum - t1_reward_sum) + 1.)
            log_p_t2 = -O.log(O.exp(t1_reward_sum - t2_reward_sum) + 1.)
            loss = -p1 * log_p_t1 - p2 * log_p_t2

            net.set_loss(loss)
def make_network(env):
    """
    Build the PPO actor-critic network.

    Network outputs: 'theta' (Gaussian policy parameters: mean and log-std
    concatenated), 'policy' (a sampled action clipped to [-1, 1]) and 'value'.
    In the training phase the clipped policy surrogate, the entropy bonus and
    the clipped value loss are added as 'policy_loss', 'entropy_loss' and
    'value_loss', and their sum is set as the network loss.

    :param env: the environment used to create the network and query the phase.
    """
    with env.create_network() as net:
        is_train = (env.phase == env.Phase.TRAIN)

        net.dist = O.distrib.GaussianDistribution(
            'policy', size=get_action_shape()[0], fixed_std=False)

        state = O.placeholder('state', shape=(None, ) + get_input_shape())
        batch_size = state.shape[0]

        # We have to define variable scope here for later optimization.
        with env.variable_scope('policy'):
            hidden = O.fc('fc1', state, 64, nonlin=O.relu)
            hidden = O.fc('fc2', hidden, 64, nonlin=O.relu)
            mu = O.fc('fc_mu', hidden, net.dist.sample_size, nonlin=O.tanh)

            # State-independent log-std, replicated for every sample in the batch.
            logstd = O.variable(
                'logstd', O.truncated_normal_initializer(stddev=0.01),
                shape=(net.dist.sample_size, ), trainable=True)
            logstd = O.tile(logstd.add_axis(0), [batch_size, 1])
            theta = O.concat([mu, logstd], axis=1)

            sampled = net.dist.sample(batch_size=batch_size, theta=theta, process_theta=True)
            policy = O.clip_by_value(sampled, -1, 1)

            net.add_output(theta, name='theta')
            net.add_output(policy, name='policy')

        if is_train:
            theta_old = O.placeholder('theta_old', shape=(None, net.dist.param_size))
            action = O.placeholder('action', shape=(None, net.dist.sample_size))
            advantage = O.placeholder('advantage', shape=(None, ))
            entropy_beta = O.scalar('entropy_beta', g.entropy_beta)

            log_prob = net.dist.log_likelihood(action, theta, process_theta=True)
            log_prob_old = net.dist.log_likelihood(action, theta_old, process_theta=True)

            ratio = O.exp(log_prob - log_prob_old)
            epsilon = get_env('ppo.epsilon')
            # Surrogate from conservative policy iteration.
            unclipped = ratio * advantage
            clipped = O.clip_by_value(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
            # PPO's pessimistic surrogate (L^CLIP).
            policy_loss = -O.reduce_mean(O.min(unclipped, clipped))

            entropy = net.dist.entropy(theta, process_theta=True).mean()
            entropy_loss = -entropy_beta * entropy

            net.add_output(policy_loss, name='policy_loss')
            net.add_output(entropy_loss, name='entropy_loss')

            summary.scalar('policy_entropy', entropy)

        with env.variable_scope('value'):
            hidden = O.fc('fc1', state, 64, nonlin=O.relu)
            hidden = O.fc('fc2', hidden, 64, nonlin=O.relu)
            value = O.fc('fcv', hidden, 1)
            # Drop the trailing unit axis so value lines up with the (None, ) labels.
            value = value.remove_axis(1)
            net.add_output(value, name='value')

        if is_train:
            value_label = O.placeholder('value_label', shape=(None, ))
            value_old = O.placeholder('value_old', shape=(None, ))

            # Clipped value objective: take the worse of the plain loss and the
            # loss of the prediction clipped to within epsilon of the old value.
            loss_plain = O.raw_l2_loss('raw_value_loss_surr1', value, value_label)
            value_clipped = value_old + O.clip_by_value(value - value_old, -epsilon, epsilon)
            loss_clipped = O.raw_l2_loss('raw_value_loss_surr2', value_clipped, value_label)
            value_loss = O.reduce_mean(O.max(loss_plain, loss_clipped))
            net.add_output(value_loss, name='value_loss')

            total = O.identity(policy_loss + entropy_loss + value_loss, name='total_loss')
            net.set_loss(total)
def normal_pdf(x, mu, var):
    """
    Evaluate the density of the univariate normal N(mu, var) at x.

    Computes ``exp(-(x - mu)^2 / (2 * var)) / sqrt(2 * pi * var)`` with a small
    constant added to the variance for numerical stability.

    :param x: point(s) at which the density is evaluated
    :param mu: mean of the distribution
    :param var: variance (sigma^2) of the distribution
    :return: the probability density, same (broadcast) shape as the inputs
    """
    # Apply the same stabilizer to every use of the variance.
    var = var + 1e-4
    exponent = ((x - mu) ** 2.) / (2. * var)
    normalizer = (2. * np.pi * var) ** 0.5
    prob = O.exp(-exponent) / normalizer
    return prob
def make_network(env):
    """
    Build the TRPO policy network (and, optionally, the value network).

    Network outputs: 'theta' (Gaussian policy parameters: mean and log-std
    concatenated) and 'policy' (a sampled action clipped to [-1, 1]). In the
    training phase 'policy_loss', 'kl' and 'kl_self' are added as well.

    When ``trpo.use_linear_vr`` is set, a LinearValueRegressor is attached to
    the network as ``net.value_regressor``; otherwise a linear value head is
    built, exposed as 'value', together with 'value_loss' during training.

    :param env: the environment used to create the network and query the phase.
    """
    use_linear_vr = get_env('trpo.use_linear_vr')

    with env.create_network() as net:
        net.dist = O.distrib.GaussianDistribution('policy', size=get_action_shape()[0], fixed_std=False)
        if use_linear_vr:
            from tartist.app.rl.utils.math import LinearValueRegressor
            net.value_regressor = LinearValueRegressor()

        state = O.placeholder('state', shape=(None, ) + get_input_shape())
        # state = O.moving_average(state)
        # state = O.clip_by_value(state, -10, 10)
        batch_size = state.shape[0]

        # We have to define variable scope here for later optimization.

        with env.variable_scope('policy'):
            _ = state

            with O.argscope(O.fc):
                _ = O.fc('fc1', _, 64, nonlin=O.relu)
                _ = O.fc('fc2', _, 64, nonlin=O.relu)
                mu = O.fc('fc_mu', _, net.dist.sample_size, nonlin=O.tanh)
                logstd = O.variable(
                    'logstd', O.truncated_normal_initializer(stddev=0.01),
                    shape=(net.dist.sample_size, ), trainable=True)

            # State-independent log-std, replicated for every sample in the batch.
            logstd = O.tile(logstd.add_axis(0), [batch_size, 1])
            theta = O.concat([mu, logstd], axis=1)

            policy = net.dist.sample(batch_size=batch_size, theta=theta, process_theta=True)
            policy = O.clip_by_value(policy, -1, 1)

            net.add_output(theta, name='theta')
            net.add_output(policy, name='policy')

        if env.phase == env.Phase.TRAIN:
            theta_old = O.placeholder('theta_old', shape=(None, net.dist.param_size))
            action = O.placeholder('action', shape=(None, net.dist.sample_size))
            advantage = O.placeholder('advantage', shape=(None, ))

            log_prob = net.dist.log_likelihood(action, theta, process_theta=True)
            log_prob_old = net.dist.log_likelihood(action, theta_old, process_theta=True)

            # Importance sampling of surrogate loss (L in paper).
            ratio = O.exp(log_prob - log_prob_old)
            policy_loss = -O.reduce_mean(ratio * advantage)

            kl = net.dist.kl(theta_p=theta_old, theta_q=theta, process_theta=True).mean()
            # KL of the current policy against a gradient-stopped copy of itself.
            kl_self = net.dist.kl(theta_p=O.zero_grad(theta), theta_q=theta, process_theta=True).mean()
            entropy = net.dist.entropy(theta, process_theta=True).mean()

            net.add_output(policy_loss, name='policy_loss')
            net.add_output(kl, name='kl')
            net.add_output(kl_self, name='kl_self')

            summary.scalar('policy_entropy', entropy, collections=[rl.train.ACGraphKeys.POLICY_SUMMARIES])

        if not use_linear_vr:
            with env.variable_scope('value'):
                value = O.fc('fcv', state, 1)
                net.add_output(value, name='value')

            if env.phase == env.Phase.TRAIN:
                value_label = O.placeholder('value_label', shape=(None, ))
                # The fc head emits a trailing unit axis while the label is
                # rank-1; subtracting them directly would broadcast to a
                # batch x batch matrix. Squeeze for the loss only, so the
                # exported 'value' output keeps its original shape.
                value_loss = O.raw_l2_loss(
                    'raw_value_loss', value.remove_axis(1), value_label).mean(name='value_loss')
                net.add_output(value_loss, name='value_loss')
def forward(img=None):
    """
    Build the DRAW recurrent attention graph and register its outputs.

    For ``nr_glimpse`` steps the model (when reconstructing or training) reads
    an attention patch from the image and the current error image, encodes it,
    samples a latent code, decodes it and writes an attention patch onto the
    canvas. When generating, the latent code is drawn from the standard normal
    prior instead and the image is not read.

    Registered outputs: 'output' (sigmoid of the final canvas, flattened in the
    training phase), 'canvas_step{i}' for every step in the test phase, and
    'loss' in the training phase.

    :param img: input image batch; only read when reconstructing or training
        (concatenated along axis 3 with the error image, so it is rank-4)
    """
    encoder = O.BasicLSTMCell(256)
    decoder = O.BasicLSTMCell(256)

    # The image is consumed exactly when we reconstruct or train, so the batch
    # size must follow the same condition; pure generation uses a batch of one.
    uses_img = is_reconstruct or env.phase is env.Phase.TRAIN
    batch_size = img.shape[0] if uses_img else 1

    canvas = O.zeros(shape=O.canonize_sym_shape([batch_size, h, w, c]), dtype='float32')
    enc_state = encoder.zero_state(batch_size, dtype='float32')
    dec_state = decoder.zero_state(batch_size, dtype='float32')
    enc_h, dec_h = enc_state[1], dec_state[1]

    def encode(x, state, reuse):
        with env.variable_scope('read_encoder', reuse=reuse):
            return encoder(x, state)

    def decode(x, state, reuse):
        with env.variable_scope('write_decoder', reuse=reuse):
            return decoder(x, state)

    # Running sums over glimpses of the posterior statistics, for the KL term.
    all_sqr_mus, all_vars, all_log_vars = 0., 0., 0.

    for step in range(nr_glimpse):
        # Variables are created on the first step and shared afterwards.
        reuse = (step != 0)

        if uses_img:
            img_hat = draw_opr.image_diff(img, canvas)  # eq. 3

            # Note: here the input should be dec_h
            with env.variable_scope('read', reuse=reuse):
                read_param = O.fc('fc_param', dec_h, 5)

            with env.name_scope('read_step{}'.format(step)):
                cx, cy, delta, var, gamma = draw_opr.split_att_params(h, w, att_dim, read_param)
                read_inp = O.concat([img, img_hat], axis=3)  # of shape: batch_size x h x w x (2c)
                read_out = draw_opr.att_read(att_dim, read_inp, cx, cy, delta, var)  # eq. 4
                enc_inp = O.concat([gamma * read_out.flatten2(), dec_h], axis=1)

            enc_h, enc_state = encode(enc_inp, enc_state, reuse)  # eq. 5

            with env.variable_scope('sample', reuse=reuse):
                _ = enc_h
                sample_mu = O.fc('fc_mu', _, code_length)
                sample_log_var = O.fc('fc_sigma', _, code_length)

            with env.name_scope('sample_step{}'.format(step)):
                sample_var = O.exp(sample_log_var)
                sample_std = O.sqrt(sample_var)
                sample_epsilon = O.random_normal([batch_size, code_length])
                z = sample_mu + sample_std * sample_epsilon  # eq. 6

            # accumulate for losses
            all_sqr_mus += sample_mu ** 2.
            all_vars += sample_var
            all_log_vars += sample_log_var
        else:
            z = O.random_normal([batch_size, code_length])

        # z = O.callback_injector(z)

        dec_h, dec_state = decode(z, dec_state, reuse)  # eq. 7

        with env.variable_scope('write', reuse=reuse):
            write_param = O.fc('fc_param', dec_h, 5)
            write_in = O.fc('fc', dec_h, (att_dim * att_dim * c)).reshape(-1, att_dim, att_dim, c)

        with env.name_scope('write_step{}'.format(step)):
            cx, cy, delta, var, gamma = draw_opr.split_att_params(h, w, att_dim, write_param)
            write_out = draw_opr.att_write(h, w, write_in, cx, cy, delta, var)  # eq. 8

        canvas += write_out

        if env.phase is env.Phase.TEST:
            dpc.add_output(O.sigmoid(canvas), name='canvas_step{}'.format(step))

    canvas = O.sigmoid(canvas)

    if env.phase is env.Phase.TRAIN:
        with env.variable_scope('loss'):
            img, canvas = img.flatten2(), canvas.flatten2()
            content_loss = O.raw_cross_entropy_prob('raw_content', canvas, img)
            content_loss = content_loss.sum(axis=1).mean(name='content')

            # Sum over glimpses of the closed-form KL(N(mu, var) || N(0, I)):
            # the constant 1 of each step adds up to nr_glimpse per latent unit.
            distrib_loss = -0.5 * (float(nr_glimpse) + all_log_vars - all_sqr_mus - all_vars).sum(axis=1)
            distrib_loss = distrib_loss.mean(name='distrib')

            summary.scalar('content_loss', content_loss)
            summary.scalar('distrib_loss', distrib_loss)

            loss = content_loss + distrib_loss
        dpc.add_output(loss, name='loss', reduce_method='sum')

    dpc.add_output(canvas, name='output')