예제 #1
0
파일: one_shot.py 프로젝트: sbos/gmn
    def generate(self, timestep, dummy=True):
        """Build the generative branch of the graph for observation ``timestep``.

        Two reads of the stored features are performed. The first uses
        ``self.prior_repr`` with a query/strength that depend only on the read
        state; its read-out and final state are handed to
        ``self.gen.generate_prior`` to build the prior latent node ``z_prior``.
        The second uses ``self.set_repr`` with a query that also sees
        ``z_prior``; its read-out and final state condition
        ``self.gen.generate``.

        Args:
            timestep: index of the observation being generated; forwarded to
                both ``recognize`` calls and used to name the hidden and
                observed nodes via ``VAE.hidden_name`` / ``VAE.observed_name``.
            dummy: forwarded unchanged to both ``recognize`` calls.

        Returns:
            Whatever ``self.gen.generate`` returns for this timestep.
        """
        # Pass 1: prior read-out; the query and strength see only the state.
        prior_out, prior_final = self.prior_repr.recognize(
            self.features,
            timestep,
            lambda s: self._prior_query(input=s),
            self.prior_steps,
            strength=lambda s: self._prior_strength(input=s),
            dummy=dummy)
        z_prior = self.gen.generate_prior(
            scg.concat([prior_out, prior_final]),
            VAE.hidden_name(timestep))

        # Pass 2: the attention query is additionally conditioned on z_prior.
        gen_out, gen_final = self.set_repr.recognize(
            self.features,
            timestep,
            lambda s: self._gen_query(input=scg.concat([s, z_prior])),
            self.num_steps,
            strength=lambda s: self._gen_strength(input=s),
            dummy=dummy)
        conditioning = scg.concat([gen_out, gen_final])
        return self.gen.generate(z_prior, conditioning,
                                 VAE.observed_name(timestep))
예제 #2
0
파일: one_shot.py 프로젝트: sbos/gmn
 def recognize(self, h, param, hidden_name):
     """Build the Normal latent node ``hidden_name`` from ``h`` and ``param``.

     ``h`` and ``param`` are concatenated. The joined node feeds ``self.mu``
     and ``self.sigma`` (named ``<hidden_name>_mu`` / ``<hidden_name>_sigma``),
     whose outputs parameterise an ``scg.Normal(self.hidden_dim)`` node.
     """
     joined = scg.concat([h, param])
     mean = self.mu(input=joined, name=hidden_name + '_mu')
     pre_sigma = self.sigma(input=joined, name=hidden_name + '_sigma')
     normal = scg.Normal(self.hidden_dim)
     return normal(mu=mean, pre_sigma=pre_sigma, name=hidden_name)
예제 #3
0
파일: one_shot.py 프로젝트: sbos/gmn
    def generate(self, z, param, observed_name):
        """Decode ``z`` (conditioned on ``param``) into a Bernoulli node.

        ``z`` and ``param`` are concatenated and passed through ``self.h0``,
        then ``self.h1``, ``self.h2`` and ``self.h3`` in turn. ``self.conv``
        then produces the logits node ``<observed_name>_logit``, which
        parameterises the returned ``scg.Bernoulli`` node named
        ``observed_name``.
        """
        hidden = self.h0(input=scg.concat([z, param]))
        for layer in (self.h1, self.h2, self.h3):
            hidden = layer(hidden)

        logits = self.conv(input=hidden, name=observed_name + '_logit')
        return scg.Bernoulli()(logit=logits, name=observed_name)
예제 #4
0
파일: utils.py 프로젝트: sbos/gmn
    def build(entries):
        """Stack ``entries`` into one memory node along a new axis 1.

        Each entry gets a new axis 1 via ``tf.expand_dims`` (wrapped with
        ``scg.apply``), and the expanded entries are concatenated on that axis.
        """
        # The parameter is named `input` to match the `input=` keyword passed
        # to scg.apply below.
        def transform(input=None):
            return tf.expand_dims(input, 1)

        expanded = [scg.apply(transform, input=item) for item in entries]
        return scg.concat(expanded, 1)
예제 #5
0
파일: utils.py 프로젝트: sbos/gmn
    def recognize(self,
                  obs,
                  timestep,
                  query,
                  num_steps,
                  dummy=True,
                  strength=lambda state: 1.):
        """Read a memory built from ``obs[:timestep]`` with attention.

        Two memories are built with ``Memory.build``:

        * a prototype memory from the entries ``obs[:timestep]``;
        * a matching memory from ``self.match`` applied to those entries.

        When ``dummy`` is true, the ``self.dummy_proto`` / ``self.dummy_match``
        entries (batch-repeated like the state) are appended to the respective
        memory.

        With ``num_steps == 0``, the read-out is the mean of the prototype
        memory over axis 1 and ``self.cell`` is applied once. Otherwise the
        cell is applied ``num_steps`` times. Each hop attends over the matching
        memory with key ``query(state)`` and strength ``strength(state)``, then
        reads the prototype memory with that attention.

        Args:
            obs: sequence of per-timestep entries; ``obs[0]`` must exist
                because it is used to batch-repeat the initial state.
            timestep: number of leading entries of ``obs`` to store.
            query: callable ``state -> key`` used for attention.
            num_steps: number of attention hops; must be non-negative.
            dummy: whether to append the dummy entries to both memories.
            strength: callable ``state -> attention strength``; defaults to
                the constant 1.

        Returns:
            ``(r, state)``: the last read-out and the final cell state.

        Raises:
            ValueError: if ``num_steps`` is negative. Without this check no
                hop would run and the read-out would be ``None``.
        """
        if num_steps < 0:
            raise ValueError('num_steps must be non-negative, got %r' %
                             (num_steps,))

        state = scg.batch_repeat(self.init_state, obs[0])

        # NOTE: the comprehension variables below must not be named `dummy`.
        # Under Python 2 (this module uses xrange), list-comprehension
        # variables leak into the function scope. The `dummy` flag would then
        # be overwritten by the last dummy entry before the second
        # `if dummy:` test.
        proto_entries = list(obs[:timestep])
        if dummy:
            proto_entries += [
                scg.batch_repeat(entry, state) for entry in self.dummy_proto
            ]
        proto_mem = Memory.build(proto_entries)

        match_entries = [self.match(input=obs[t]) for t in xrange(timestep)]
        if dummy:
            match_entries += [
                scg.batch_repeat(entry, state) for entry in self.dummy_match
            ]
        match_mem = Memory.build(match_entries)

        if num_steps == 0:
            # No attention hops: summarise the prototype memory by its mean
            # over axis 1, the axis Memory.build stacks entries along.
            def avg(input=None):
                return tf.reduce_mean(input, 1)

            r = scg.apply(avg, input=proto_mem)
            state = self.cell(input=scg.concat([r, state]), state=state)
            return r, state

        r = None
        for _ in xrange(num_steps):
            q = query(state)
            a = scg.Attention()(mem=match_mem, key=q, strength=strength(state))
            r = scg.AttentiveReader()(attention=a, mem=proto_mem)
            state = self.cell(input=scg.concat([r, state]), state=state)

        return r, state
예제 #6
0
파일: one_shot.py 프로젝트: sbos/gmn
 def gen_query(state):
     """Return the generative attention query for ``state`` joined with ``z_prior``."""
     query_input = scg.concat([state, z_prior])
     return self._gen_query(input=query_input)
예제 #7
0
파일: one_shot.py 프로젝트: sbos/gmn
 def rec_query(state):
     """Return the recognition attention query for ``state``.

     ``state`` is joined with ``self.features[timestep]``. ``timestep`` is read
     from the enclosing scope when the function is called, not when it is
     defined.
     """
     current = self.features[timestep]
     return self._rec_query(input=scg.concat([state, current]))
예제 #8
0
파일: one_shot.py 프로젝트: sbos/gmn
    def __init__(self, input_data, hidden_dim, gen, rec):
        """Build all sub-networks and the per-timestep graph for one episode.

        Reads the module-level ``args`` (``hops``, ``prior_hops``,
        ``no_dummy``) and ``episode_length``; neither is passed in.

        Args:
            input_data: indexed as ``input_data[:, t, :]`` for each
                ``t < episode_length``; presumably shaped
                (batch, episode_length, 28 * 28) -- TODO confirm with caller.
            hidden_dim: latent size handed to ``rec`` and ``gen``. It is also
                the ``z_prior`` part of the generative query input
                (``_gen_query`` expects ``state_dim + hidden_dim``).
            gen: factory called as ``gen(hidden_dim, conditioning_dim)``.
            rec: factory called as ``rec(hidden_dim, conditioning_dim)``; the
                result must expose ``features_dim``, ``get_features`` and
                ``recognize``.

        After construction:

        * ``self.z[t]`` holds the recognised latent for each observed
          timestep ``t < episode_length``;
        * ``self.x[t]`` holds the generated observation for every ``t`` in
          ``0..episode_length`` inclusive, one more entry than there are
          observations.
        """
        state_dim = 200
        self.num_steps = args.hops
        self.prior_steps = args.prior_hops
        self.matching_dim = 200

        # NOTE(review): layers are created inside tf.variable_scope blocks,
        # so their creation order presumably fixes their variable names (and
        # hence checkpoint compatibility) -- keep the statement order stable.
        with tf.variable_scope('recognition') as vs:
            # NOTE(review): 288 looks like a hard-coded features_dim.
            # rec.recognize is later given concat([read-out, state]), and the
            # read-out comes from a memory of features. features_dim is only
            # known once rec exists -- TODO confirm.
            self.rec = rec(hidden_dim, state_dim + 288)
            self.features_dim = self.rec.features_dim
            # Query input is concat([state, features[t]]).
            self._rec_query = scg.Affine(state_dim + self.features_dim,
                                         self.matching_dim,
                                         fun=None,
                                         init=scg.norm_init(scg.he_normal))
            self._rec_strength = scg.Affine(state_dim,
                                            1,
                                            init=scg.norm_init(scg.he_normal))

        with tf.variable_scope('generation') as vs:
            self.gen = gen(hidden_dim, state_dim + self.features_dim)
            # Query input is concat([state, z_prior]) (built in generate).
            self._gen_query = scg.Affine(state_dim + hidden_dim,
                                         self.matching_dim,
                                         fun=None,
                                         init=scg.norm_init(scg.he_normal))
            self._gen_strength = scg.Affine(state_dim,
                                            1,
                                            init=scg.norm_init(scg.he_normal))

            # Prior read-out: its query and strength see only the state.
            self._prior_query = scg.Affine(state_dim,
                                           self.matching_dim,
                                           fun=None,
                                           init=scg.norm_init(scg.he_normal))
            self._prior_strength = scg.Affine(state_dim,
                                              1,
                                              init=scg.norm_init(
                                                  scg.he_normal))
            self.prior_repr = SetRepresentation(self.features_dim,
                                                self.matching_dim, state_dim)

        # set_repr is used both by the recognition pass below and by generate.
        with tf.variable_scope('both') as vs:
            self.set_repr = SetRepresentation(self.features_dim,
                                              self.matching_dim, state_dim)

        # z: one latent per observation; x: one generated node per timestep
        # 0..episode_length inclusive.
        self.z = [None] * episode_length
        self.x = [None] * (episode_length + 1)

        # Allocate one scg.Constant node per observation, named
        # VAE.observed_name(t), each with shape [28 * 28].

        self.obs = [None] * episode_length
        for t in xrange(episode_length):
            current_data = input_data[:, t, :]
            self.obs[t] = scg.Constant(value=current_data,
                                       shape=[28 * 28
                                              ])(name=VAE.observed_name(t))

        # Pre-compute recognition features for every observation once; both
        # the recognition pass and generate read self.features.
        self.features = []
        for t in xrange(episode_length):
            self.features.append(self.rec.get_features(self.obs[t]))

        for timestep in xrange(episode_length + 1):
            # Dummy memory entries are always used at timestep 0, where the
            # memory of past observations is empty. With args.no_dummy they
            # are dropped for later timesteps.
            dummy = True
            if args.no_dummy and timestep > 0:
                dummy = False

            if timestep < episode_length:

                # rec_query reads `timestep` when called, not when defined.
                # That is fine as long as set_repr.recognize only calls it
                # during this iteration (as the visible
                # SetRepresentation.recognize does in its hop loop) and does
                # not keep the callable.
                def rec_query(state):
                    return self._rec_query(
                        input=scg.concat([state, self.features[timestep]]))

                def rec_strength(state):
                    return self._rec_strength(input=state)

                rec_response, rec_state = self.set_repr.recognize(
                    self.features,
                    timestep,
                    rec_query,
                    self.num_steps,
                    strength=rec_strength,
                    dummy=dummy)

                self.z[timestep] = self.rec.recognize(
                    self.features[timestep],
                    scg.concat([rec_response, rec_state]),
                    VAE.hidden_name(timestep))

            # Generation runs for every timestep, including
            # timestep == episode_length, which has no matching observation.
            self.x[timestep] = self.generate(timestep, dummy=dummy)