Example #1
# Shared assumption for all snippets below: TensorFlow plus a TensorArtist-style
# `O` operator namespace, e.g. `import tensorflow as tf` and
# `from tartist.nn import opr as O` (import path assumed; not shown in the snippets).
def filterbank(img_h, img_w, att_dim, center_x, center_y, delta, var):
    """
    Get the filterbank matrix.

    :param img_h: image height
    :param img_w: image width
    :param att_dim: attention dim, the attention window is att_dim x att_dim
    :param center_x: attention center (x-axis): batch_size x 1
    :param center_y: attention center (y-axis): batch_size x 1
    :param delta: stride: batch_size x 1
    :param var: variance (sigma^2): batch_size x 1
    :return: filter_x, filter_y
    """

    with tf.name_scope('filterbank'):
        # Offsets of the att_dim filter centers relative to the attention center.
        rng = O.range(1, att_dim+1, dtype='float32') - att_dim / 2 + 0.5
        mu_x = center_x + rng * delta
        mu_y = center_y + rng * delta
        all_x = O.range(1, img_w+1, dtype='float32')
        all_y = O.range(1, img_h+1, dtype='float32')
        a = all_x - mu_x.add_axis(-1)
        b = all_y - mu_y.add_axis(-1)

        # Unnormalized Gaussian responses: batch_size x att_dim x img_w (resp. img_h).
        fx = O.exp(-O.sqr(a) / var.add_axis(-1) / 2.)
        fy = O.exp(-O.sqr(b) / var.add_axis(-1) / 2.)

        # Normalize each filter so its responses over the image axis sum to 1.
        fx /= fx.sum(axis=2, keepdims=True) + 1e-8
        fy /= fy.sum(axis=2, keepdims=True) + 1e-8

        fx = O.as_varnode(fx)

        return fx, fy
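For intuition, here is a minimal single-image NumPy sketch of the same filterbank construction (no batching, no `O` ops); applying the two filter matrices as `fy @ img @ fx.T` is how a DRAW-style attention read uses them. Everything below is illustrative and not part of the snippet's API.

import numpy as np

def np_filterbank(img_size, att_dim, center, delta, var):
    # NumPy mirror of filterbank() along one axis; returns att_dim x img_size.
    rng = np.arange(1, att_dim + 1, dtype='float32') - att_dim / 2 + 0.5
    mu = center + rng * delta                       # filter centers along this axis
    grid = np.arange(1, img_size + 1, dtype='float32')
    f = np.exp(-(grid[None, :] - mu[:, None]) ** 2 / var / 2.)
    return f / (f.sum(axis=1, keepdims=True) + 1e-8)

# Reading a 5 x 5 glimpse from a single 28 x 28 image:
img = np.random.rand(28, 28).astype('float32')
fy = np_filterbank(28, att_dim=5, center=14., delta=2., var=1.5)
fx = np_filterbank(28, att_dim=5, center=14., delta=2., var=1.5)
glimpse = fy @ img @ fx.T                           # shape (5, 5)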
Example #2
def split_att_params(img_h, img_w, att_dim, value):
    """
    Split attention params.

    :param img_h: image height
    :param img_w: image width
    :param att_dim: attention dim, the attention window is att_dim x att_dim
    :param value: attention params produced by the hidden layer: batch_size x 5
    :return: center_x, center_y, delta, variance, gamma
    """

    with tf.name_scope('split_att_params'):
        center_x, center_y, log_delta, log_var, log_gamma = O.split(value, 5, axis=1)

        delta, var, gamma = map(lambda x: O.exp(x).reshape(-1, 1), [log_delta, log_var, log_gamma])
        # Map the raw center outputs (treated as lying in [-1, 1]) to image coordinates,
        # and scale the stride so the att_dim x att_dim grid can span the larger image side.
        center_x = ((img_w + 1.) * (center_x + 1.) / 2.).reshape(-1, 1)
        center_y = ((img_h + 1.) * (center_y + 1.) / 2.).reshape(-1, 1)
        delta *= float((max(img_h, img_w) - 1) / (att_dim - 1))
        return center_x, center_y, delta, var, gamma
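Written out (our notation, not the snippet's), with raw outputs (g̃_x, g̃_y, log δ̃, log σ², log γ) split from `value`, the mapping above is:

    center_x = (img_w + 1) / 2 * (g̃_x + 1)
    center_y = (img_h + 1) / 2 * (g̃_y + 1)
    delta    = (max(img_h, img_w) - 1) / (att_dim - 1) * exp(log δ̃)
    var      = exp(log σ²),  gamma = exp(log γ)

so an att_dim x att_dim grid of filters with stride delta can cover the larger image dimension.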
Example #3
            def forward(x):
                if is_reconstruct or env.phase is env.Phase.TRAIN:
                    with env.variable_scope('encoder'):
                        _ = x
                        _ = O.fc('fc1', _, 500, nonlin=O.tanh)
                        _ = O.fc('fc2', _, 500, nonlin=O.tanh)
                        mu = O.fc('fc3_mu', _, code_length)
                        log_var = O.fc('fc3_sigma', _, code_length)
                        var = O.exp(log_var)
                        std = O.sqrt(var)
                        epsilon = O.random_normal([x.shape[0], code_length])
                        z_given_x = mu + std * epsilon
                else:
                    z_given_x = O.random_normal([1, code_length])

                with env.variable_scope('decoder'):
                    _ = z_given_x
                    _ = O.fc('fc1', _, 500, nonlin=O.tanh)
                    _ = O.fc('fc2', _, 500, nonlin=O.tanh)
                    _ = O.fc('fc3', _, 784, nonlin=O.sigmoid)
                    _ = _.reshape(-1, h, w, c)
                    x_given_z = _

                if env.phase is env.Phase.TRAIN:
                    with env.variable_scope('loss'):
                        content_loss = O.raw_cross_entropy_prob(
                            'raw_content', x_given_z.flatten2(), x.flatten2())
                        content_loss = content_loss.sum(axis=1).mean(
                            name='content')
                        # distrib_loss = 0.5 * (O.sqr(mu) + O.sqr(std) - 2. * O.log(std + 1e-8) - 1.0).sum(axis=1)
                        distrib_loss = -0.5 * (1. + log_var - O.sqr(mu) -
                                               var).sum(axis=1)
                        distrib_loss = distrib_loss.mean(name='distrib')

                        loss = content_loss + distrib_loss
                    dpc.add_output(loss, name='loss', reduce_method='sum')

                dpc.add_output(x_given_z, name='output')
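The `distrib_loss` above is the standard closed-form KL divergence between the encoder's diagonal Gaussian q(z|x) = N(mu, var) and the N(0, I) prior, averaged over the batch. A minimal NumPy version of the per-sample term (illustrative names only):

import numpy as np

def kl_to_standard_normal(mu, log_var):
    # Closed-form KL( N(mu, exp(log_var)) || N(0, 1) ), summed over code dimensions.
    return -0.5 * np.sum(1. + log_var - mu ** 2 - np.exp(log_var), axis=1)

mu = np.array([[0.3, -0.1]])
log_var = np.array([[0.2, -0.5]])
print(kl_to_standard_normal(mu, log_var))  # a small positive number per batch element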
Example #4
def make_rpredictor_network(env):
    is_train = env.phase is env.Phase.TRAIN

    with env.create_network() as net:
        h, w, c = get_input_shape()
        # Hack (MJY): forced RGB input (instead of a combination of history frames).
        c = 3

        dpc = env.create_dpcontroller()
        with dpc.activate():

            def inputs():
                state = O.placeholder('state', shape=(None, h, w, c))
                t1_state = O.placeholder('t1_state', shape=(None, h, w, c))
                t2_state = O.placeholder('t2_state', shape=(None, h, w, c))
                return [state, t1_state, t2_state]

            @O.auto_reuse
            def forward_conv(x):
                _ = x / 255.0
                with O.argscope(O.conv2d, nonlin=O.relu):
                    _ = O.conv2d('conv0', _, 32, 5)
                    _ = O.max_pooling2d('pool0', _, 2)
                    _ = O.conv2d('conv1', _, 32, 5)
                    _ = O.max_pooling2d('pool1', _, 2)
                    _ = O.conv2d('conv2', _, 64, 4)
                    _ = O.max_pooling2d('pool2', _, 2)
                    _ = O.conv2d('conv3', _, 64, 3)
                return _

            def forward(x, t1, t2):
                dpc.add_output(forward_conv(x), name='feature')
                dpc.add_output(forward_conv(t1), name='t1_feature')
                dpc.add_output(forward_conv(t2), name='t2_feature')

            dpc.set_input_maker(inputs).set_forward_func(forward)

        @O.auto_reuse
        def forward_fc(feature, action):
            action = O.one_hot(action, get_player_nr_actions())
            _ = O.concat([feature.flatten2(), action], axis=1)
            _ = O.fc('fc0', _, 512, nonlin=O.p_relu)
            reward = O.fc('fc_reward', _, 1)
            return reward

        action = O.placeholder('action', shape=(None, ), dtype='int64')
        net.add_output(forward_fc(dpc.outputs['feature'], action),
                       name='reward')

        if is_train:
            t1_action = O.placeholder('t1_action',
                                      shape=(None, ),
                                      dtype='int64')
            t1_reward_exp = O.exp(
                forward_fc(dpc.outputs['t1_feature'], t1_action).sum())
            t2_action = O.placeholder('t2_action',
                                      shape=(None, ),
                                      dtype='int64')
            t2_reward_exp = O.exp(
                forward_fc(dpc.outputs['t2_feature'], t2_action).sum())

            # `pref` in [0, 1] weights the two segments: p1 = 1 - pref for t1, p2 = pref for t2.
            pref = O.placeholder('pref')
            pref = O.callback_injector(pref)
            p1, p2 = 1 - pref, pref

            # Cross-entropy against a Bradley-Terry style preference probability
            # computed from the exponentiated summed rewards of the two segments.
            p_greater = t1_reward_exp / (t1_reward_exp + t2_reward_exp)
            loss = -p1 * O.log(p_greater) - p2 * O.log(1 - p_greater)

            net.set_loss(loss)
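The training loss above is a pairwise preference objective on two trajectory segments; a minimal NumPy sketch of the same computation, with scalar stand-ins for the summed predicted rewards (illustrative names only):

import numpy as np

def preference_loss(r1_sum, r2_sum, pref):
    # Bradley-Terry style: P(segment 1 preferred) from the exponentiated summed rewards.
    p_greater = np.exp(r1_sum) / (np.exp(r1_sum) + np.exp(r2_sum))
    p1, p2 = 1. - pref, pref
    return -p1 * np.log(p_greater) - p2 * np.log(1. - p_greater)

# Segment 1 is labelled preferred (pref=0) and gets the higher summed reward -> small loss (~0.37).
print(preference_loss(r1_sum=1.2, r2_sum=0.4, pref=0.0))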
Example #5
def make_network(env):
    with env.create_network() as net:
        net.dist = O.distrib.GaussianDistribution('policy',
                                                  size=get_action_shape()[0],
                                                  fixed_std=False)

        state = O.placeholder('state', shape=(None, ) + get_input_shape())
        batch_size = state.shape[0]

        # We have to define the variable scope here for later optimization.

        with env.variable_scope('policy'):
            _ = state

            _ = O.fc('fc1', _, 64, nonlin=O.relu)
            _ = O.fc('fc2', _, 64, nonlin=O.relu)
            mu = O.fc('fc_mu', _, net.dist.sample_size, nonlin=O.tanh)
            logstd = O.variable('logstd',
                                O.truncated_normal_initializer(stddev=0.01),
                                shape=(net.dist.sample_size, ),
                                trainable=True)

            logstd = O.tile(logstd.add_axis(0), [batch_size, 1])
            theta = O.concat([mu, logstd], axis=1)

            policy = net.dist.sample(batch_size=batch_size,
                                     theta=theta,
                                     process_theta=True)
            policy = O.clip_by_value(policy, -1, 1)

            net.add_output(theta, name='theta')
            net.add_output(policy, name='policy')

        if env.phase == env.Phase.TRAIN:
            theta_old = O.placeholder('theta_old',
                                      shape=(None, net.dist.param_size))
            action = O.placeholder('action',
                                   shape=(None, net.dist.sample_size))
            advantage = O.placeholder('advantage', shape=(None, ))
            entropy_beta = O.scalar('entropy_beta', g.entropy_beta)

            log_prob = net.dist.log_likelihood(action,
                                               theta,
                                               process_theta=True)
            log_prob_old = net.dist.log_likelihood(action,
                                                   theta_old,
                                                   process_theta=True)

            ratio = O.exp(log_prob - log_prob_old)
            epsilon = get_env('ppo.epsilon')
            surr1 = ratio * advantage  # surrogate from conservative policy iteration
            surr2 = O.clip_by_value(ratio, 1.0 - epsilon,
                                    1.0 + epsilon) * advantage
            policy_loss = -O.reduce_mean(O.min(
                surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
            entropy = net.dist.entropy(theta, process_theta=True).mean()
            entropy_loss = -entropy_beta * entropy

            net.add_output(policy_loss, name='policy_loss')
            net.add_output(entropy_loss, name='entropy_loss')

            summary.scalar('policy_entropy', entropy)

        with env.variable_scope('value'):
            _ = state
            _ = O.fc('fc1', _, 64, nonlin=O.relu)
            _ = O.fc('fc2', _, 64, nonlin=O.relu)
            value = O.fc('fcv', _, 1)
            value = value.remove_axis(1)
            net.add_output(value, name='value')

        if env.phase == env.Phase.TRAIN:
            value_label = O.placeholder('value_label', shape=(None, ))
            value_old = O.placeholder('value_old', shape=(None, ))

            value_surr1 = O.raw_l2_loss('raw_value_loss_surr1', value,
                                        value_label)
            value_clipped = value_old + O.clip_by_value(
                value - value_old, -epsilon, epsilon)
            value_surr2 = O.raw_l2_loss('raw_value_loss_surr2', value_clipped,
                                        value_label)
            value_loss = O.reduce_mean(O.max(value_surr1, value_surr2))
            net.add_output(value_loss, name='value_loss')

        if env.phase == env.Phase.TRAIN:
            loss = O.identity(policy_loss + entropy_loss + value_loss,
                              name='total_loss')
            net.set_loss(loss)
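For reference, the policy part of the loss above is PPO's clipped surrogate (L^CLIP); a minimal NumPy sketch with made-up numbers (the `epsilon=0.2` default is an assumption here, the snippet reads it from `get_env('ppo.epsilon')`):

import numpy as np

def ppo_clip_loss(log_prob, log_prob_old, advantage, epsilon=0.2):
    # Clipped surrogate objective (negated, so it is a loss to minimize).
    ratio = np.exp(log_prob - log_prob_old)
    surr1 = ratio * advantage
    surr2 = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    return -np.mean(np.minimum(surr1, surr2))

log_prob = np.array([-1.0, -0.5, -2.0])
log_prob_old = np.array([-1.2, -0.7, -1.5])
advantage = np.array([0.5, -0.3, 1.0])
print(ppo_clip_loss(log_prob, log_prob_old, advantage))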
Example #6
def normal_pdf(x, mu, var):
    # Gaussian density N(x; mu, var); the 1e-4 guards against a vanishing variance.
    exponent = ((x - mu) ** 2.) / (2. * (var + 1e-4))
    prob = (1. / O.sqrt(2. * np.pi * (var + 1e-4))) * O.exp(-exponent)
    return prob
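As a quick sanity check of the Gaussian density formula (pure NumPy, no `O` ops; names are illustrative):

import numpy as np

def np_normal_pdf(x, mu, var):
    return np.exp(-(x - mu) ** 2 / (2. * var)) / np.sqrt(2. * np.pi * var)

print(np_normal_pdf(0.0, 0.0, 1.0))  # ~0.3989, the standard normal density at 0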
Example #7
def make_network(env):
    use_linear_vr = get_env('trpo.use_linear_vr')

    with env.create_network() as net:
        net.dist = O.distrib.GaussianDistribution('policy',
                                                  size=get_action_shape()[0],
                                                  fixed_std=False)
        if use_linear_vr:
            from tartist.app.rl.utils.math import LinearValueRegressor
            net.value_regressor = LinearValueRegressor()

        state = O.placeholder('state', shape=(None, ) + get_input_shape())
        # state = O.moving_average(state)
        # state = O.clip_by_value(state, -10, 10)
        batch_size = state.shape[0]

        # We have to define the variable scope here for later optimization.

        with env.variable_scope('policy'):
            _ = state

            with O.argscope(O.fc):
                _ = O.fc('fc1', _, 64, nonlin=O.relu)
                _ = O.fc('fc2', _, 64, nonlin=O.relu)
                mu = O.fc('fc_mu', _, net.dist.sample_size, nonlin=O.tanh)
                logstd = O.variable(
                    'logstd',
                    O.truncated_normal_initializer(stddev=0.01),
                    shape=(net.dist.sample_size, ),
                    trainable=True)

            logstd = O.tile(logstd.add_axis(0), [batch_size, 1])
            theta = O.concat([mu, logstd], axis=1)

            policy = net.dist.sample(batch_size=batch_size,
                                     theta=theta,
                                     process_theta=True)
            policy = O.clip_by_value(policy, -1, 1)

            net.add_output(theta, name='theta')
            net.add_output(policy, name='policy')

        if env.phase == env.Phase.TRAIN:
            theta_old = O.placeholder('theta_old',
                                      shape=(None, net.dist.param_size))
            action = O.placeholder('action',
                                   shape=(None, net.dist.sample_size))
            advantage = O.placeholder('advantage', shape=(None, ))

            log_prob = net.dist.log_likelihood(action,
                                               theta,
                                               process_theta=True)
            log_prob_old = net.dist.log_likelihood(action,
                                                   theta_old,
                                                   process_theta=True)

            # Importance-sampled surrogate loss (L in the TRPO paper).
            ratio = O.exp(log_prob - log_prob_old)
            policy_loss = -O.reduce_mean(ratio * advantage)

            kl = net.dist.kl(theta_p=theta_old,
                             theta_q=theta,
                             process_theta=True).mean()
            kl_self = net.dist.kl(theta_p=O.zero_grad(theta),
                                  theta_q=theta,
                                  process_theta=True).mean()
            entropy = net.dist.entropy(theta, process_theta=True).mean()

            net.add_output(policy_loss, name='policy_loss')
            net.add_output(kl, name='kl')
            net.add_output(kl_self, name='kl_self')

            summary.scalar('policy_entropy',
                           entropy,
                           collections=[rl.train.ACGraphKeys.POLICY_SUMMARIES])

        if not use_linear_vr:
            with env.variable_scope('value'):
                value = O.fc('fcv', state, 1)
                net.add_output(value, name='value')

            if env.phase == env.Phase.TRAIN:
                value_label = O.placeholder('value_label', shape=(None, ))
                value_loss = O.raw_l2_loss('raw_value_loss', value,
                                           value_label).mean(name='value_loss')
                net.add_output(value_loss, name='value_loss')
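The `kl` / `kl_self` outputs above are KL divergences between the old and current diagonal Gaussian policies (used by TRPO's trust-region constraint); a minimal NumPy version of that closed form, parameterized by means and log-stds as in `theta` (illustrative names only):

import numpy as np

def diag_gaussian_kl(mu_p, logstd_p, mu_q, logstd_q):
    # Closed-form KL( N(mu_p, std_p^2) || N(mu_q, std_q^2) ), summed over action dims.
    var_p, var_q = np.exp(2. * logstd_p), np.exp(2. * logstd_q)
    return np.sum(logstd_q - logstd_p + (var_p + (mu_p - mu_q) ** 2) / (2. * var_q) - 0.5, axis=-1)

mu_old, logstd_old = np.zeros((1, 2)), np.zeros((1, 2))
mu_new, logstd_new = np.array([[0.1, -0.2]]), np.array([[0.05, 0.0]])
print(diag_gaussian_kl(mu_old, logstd_old, mu_new, logstd_new))  # small positive KL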
Example #8
            def forward(img=None):
                encoder = O.BasicLSTMCell(256)
                decoder = O.BasicLSTMCell(256)

                batch_size = img.shape[0] if is_train else 1

                canvas = O.zeros(shape=O.canonize_sym_shape([batch_size, h, w, c]), dtype='float32')
                enc_state = encoder.zero_state(batch_size, dtype='float32')
                dec_state = decoder.zero_state(batch_size, dtype='float32')
                enc_h, dec_h = enc_state[1], dec_state[1]

                def encode(x, state, reuse):
                    with env.variable_scope('read_encoder', reuse=reuse):
                        return encoder(x, state)

                def decode(x, state, reuse):
                    with env.variable_scope('write_decoder', reuse=reuse):
                        return decoder(x, state)

                all_sqr_mus, all_vars, all_log_vars = 0., 0., 0.

                for step in range(nr_glimpse):
                    reuse = (step != 0)
                    if is_reconstruct or env.phase is env.Phase.TRAIN:
                        img_hat = draw_opr.image_diff(img, canvas)  # eq. 3

                        # Note: here the input should be dec_h
                        with env.variable_scope('read', reuse=reuse):
                            read_param = O.fc('fc_param', dec_h, 5)

                        with env.name_scope('read_step{}'.format(step)):
                            cx, cy, delta, var, gamma = draw_opr.split_att_params(h, w, att_dim, read_param)
                            read_inp = O.concat([img, img_hat], axis=3)  # of shape: batch_size x h x w x (2c)
                            read_out = draw_opr.att_read(att_dim, read_inp, cx, cy, delta, var)  # eq. 4
                            enc_inp = O.concat([gamma * read_out.flatten2(), dec_h], axis=1)
                        enc_h, enc_state = encode(enc_inp, enc_state, reuse)  # eq. 5

                        with env.variable_scope('sample', reuse=reuse):
                            _ = enc_h
                            sample_mu = O.fc('fc_mu', _, code_length)
                            sample_log_var = O.fc('fc_sigma', _, code_length)

                        with env.name_scope('sample_step{}'.format(step)):
                            sample_var = O.exp(sample_log_var)
                            sample_std = O.sqrt(sample_var)
                            sample_epsilon = O.random_normal([batch_size, code_length])
                            z = sample_mu + sample_std * sample_epsilon  # eq. 6

                        # accumulate for losses
                        all_sqr_mus += sample_mu ** 2.
                        all_vars += sample_var
                        all_log_vars += sample_log_var
                    else:
                        z = O.random_normal([1, code_length])

                    # z = O.callback_injector(z)

                    dec_h, dec_state = decode(z, dec_state, reuse)  # eq. 7
                    with env.variable_scope('write', reuse=reuse):
                        write_param = O.fc('fc_param', dec_h, 5)
                        write_in = O.fc('fc', dec_h, (att_dim * att_dim * c)).reshape(-1, att_dim, att_dim, c)

                    with env.name_scope('write_step{}'.format(step)):
                        cx, cy, delta, var, gamma = draw_opr.split_att_params(h, w, att_dim, write_param)
                        write_out = draw_opr.att_write(h, w, write_in, cx, cy, delta, var)  # eq. 8

                    canvas += write_out

                    if env.phase is env.Phase.TEST:
                        dpc.add_output(O.sigmoid(canvas), name='canvas_step{}'.format(step))

                canvas = O.sigmoid(canvas)

                if env.phase is env.Phase.TRAIN:
                    with env.variable_scope('loss'):
                        img, canvas = img.flatten2(), canvas.flatten2()
                        content_loss = O.raw_cross_entropy_prob('raw_content', canvas, img)
                        content_loss = content_loss.sum(axis=1).mean(name='content')
                        # distrib_loss = 0.5 * (O.sqr(mu) + O.sqr(std) - 2. * O.log(std + 1e-8) - 1.0).sum(axis=1)
                        distrib_loss = -0.5 * (float(nr_glimpse) + all_log_vars - all_sqr_mus - all_vars).sum(axis=1)
                        distrib_loss = distrib_loss.mean(name='distrib')

                        summary.scalar('content_loss', content_loss)
                        summary.scalar('distrib_loss', distrib_loss)

                        loss = content_loss + distrib_loss
                    dpc.add_output(loss, name='loss', reduce_method='sum')

                dpc.add_output(canvas, name='output')
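Shape-wise, the `att_read` / `att_write` calls above apply the filterbank matrices from Example #1 to move between the full image and the att_dim x att_dim patch. The `draw_opr` implementation itself is not shown, so treat the following single-image NumPy sketch of the DRAW-style read/write pattern as an assumption about its behavior (all names are illustrative):

import numpy as np

def att_read_np(img, fx, fy, gamma):
    # img: H x W, fy: N x H, fx: N x W  ->  N x N glimpse, scaled by gamma.
    return gamma * (fy @ img @ fx.T)

def att_write_np(patch, fx, fy, gamma):
    # patch: N x N  ->  H x W canvas update, scaled by 1/gamma.
    return (fy.T @ patch @ fx) / gamma

img = np.random.rand(28, 28).astype('float32')
fy = np.eye(5, 28, dtype='float32')   # stand-in filter matrices, just for shape checking
fx = np.eye(5, 28, dtype='float32')
patch = att_read_np(img, fx, fy, gamma=1.0)        # (5, 5)
canvas_update = att_write_np(patch, fx, fy, 1.0)   # (28, 28)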