def generate(self, timestep, dummy=True):
    """Build the generative path for step ``timestep``.

    Two attention passes are run over ``self.features``:

    1. ``self.prior_repr.recognize`` with the prior query/strength
       projections; its two outputs are concatenated and passed to
       ``self.gen.generate_prior`` to obtain ``z_prior``.
    2. ``self.set_repr.recognize`` with a query that is computed from the
       attention state concatenated with ``z_prior``; its two outputs are
       concatenated and passed, together with ``z_prior``, to
       ``self.gen.generate``.

    :param timestep: step index forwarded to both ``recognize`` calls and
        used to derive the hidden/observed node names.
    :param dummy: forwarded unchanged to both ``recognize`` calls.
    :return: the result of ``self.gen.generate`` for this step.
    """
    # Pass 1: summary used to parameterise the prior over z.
    prior_read, prior_final = self.prior_repr.recognize(
        self.features,
        timestep,
        lambda state: self._prior_query(input=state),
        self.prior_steps,
        strength=lambda state: self._prior_strength(input=state),
        dummy=dummy)
    prior_summary = scg.concat([prior_read, prior_final])
    z_prior = self.gen.generate_prior(prior_summary,
                                      VAE.hidden_name(timestep))

    # Pass 2: the query additionally depends on the sampled z_prior.
    gen_read, gen_final = self.set_repr.recognize(
        self.features,
        timestep,
        lambda state: self._gen_query(input=scg.concat([state, z_prior])),
        self.num_steps,
        strength=lambda state: self._gen_strength(input=state),
        dummy=dummy)
    gen_summary = scg.concat([gen_read, gen_final])

    return self.gen.generate(z_prior, gen_summary,
                             VAE.observed_name(timestep))
def recognize(self, h, param, hidden_name):
    """Create the latent node ``hidden_name`` from ``h`` and ``param``.

    ``h`` and ``param`` are concatenated and fed to ``self.mu`` and
    ``self.sigma``; the two results parameterise a
    ``scg.Normal(self.hidden_dim)`` node (as ``mu`` and ``pre_sigma``).

    :param h: first input to the concatenation.
    :param param: second input to the concatenation.
    :param hidden_name: name of the latent node; the parameter nodes are
        named ``hidden_name + '_mu'`` and ``hidden_name + '_sigma'``.
    :return: the node produced by the ``scg.Normal`` call.
    """
    joined = scg.concat([h, param])
    mean = self.mu(input=joined, name=hidden_name + '_mu')
    pre_sigma = self.sigma(input=joined, name=hidden_name + '_sigma')
    latent_dist = scg.Normal(self.hidden_dim)
    return latent_dist(mu=mean, pre_sigma=pre_sigma, name=hidden_name)
def generate(self, z, param, observed_name):
    """Create the observation node ``observed_name`` from ``z`` and ``param``.

    The concatenation of ``z`` and ``param`` goes through ``self.h0``, then
    ``self.h1``, ``self.h2`` and ``self.h3`` in that order, then
    ``self.conv``; the result is used as the logit of a ``scg.Bernoulli``
    node.

    :param z: latent input.
    :param param: conditioning input concatenated after ``z``.
    :param observed_name: name of the Bernoulli node; the logit node is
        named ``observed_name + '_logit'``.
    :return: the node produced by the ``scg.Bernoulli`` call.
    """
    hidden = self.h0(input=scg.concat([z, param]))
    # The three intermediate layers are applied positionally, in order.
    for layer in (self.h1, self.h2, self.h3):
        hidden = layer(hidden)
    logit = self.conv(input=hidden, name=observed_name + '_logit')
    return scg.Bernoulli()(logit=logit, name=observed_name)
def build(entries):
    """Stack ``entries`` into a single memory node.

    Each entry gets a new axis inserted at position 1 via
    ``tf.expand_dims`` and the expanded entries are then joined with
    ``scg.concat(..., 1)``.

    :param entries: iterable of nodes to place in the memory.
    :return: the concatenated memory node.
    """
    def transform(input=None):
        # Keyword name `input` is what scg.apply is called with below.
        return tf.expand_dims(input, 1)

    expanded = [scg.apply(transform, input=item) for item in entries]
    return scg.concat(expanded, 1)
def recognize(self, obs, timestep, query, num_steps, dummy=True,
              strength=lambda state: 1.):
    """Read from a memory built out of the first ``timestep`` entries of ``obs``.

    Two parallel memories are built with ``Memory.build``: one from the raw
    entries ``obs[:timestep]`` and one from ``self.match`` applied to each
    of those entries.  When ``dummy`` is true, the entries of
    ``self.dummy_proto`` / ``self.dummy_match`` (batch-repeated against the
    initial state) are appended to the respective memory.

    :param obs: sequence of nodes; ``obs[0]`` is also used as the reference
        for batch-repeating ``self.init_state``.
    :param timestep: number of leading entries of ``obs`` to store.
    :param query: callable ``state -> key`` used at every attention step.
    :param num_steps: number of attention steps.  With ``0`` the raw memory
        is averaged over axis 1 and a single cell update is performed.
    :param dummy: whether to append the dummy entries to both memories.
    :param strength: callable ``state -> strength`` passed to the attention.
    :return: ``(r, state)`` - the last read-out and the final cell state.
        ``r`` is ``None`` if ``num_steps`` is negative (the loop never runs).
    """
    # assert num_steps > 0
    state = scg.batch_repeat(self.init_state, obs[0])

    # NOTE: the comprehension variables below must NOT be called `dummy`.
    # Under Python 2 a list-comprehension variable leaks into the function
    # scope, so reusing the parameter name would make the second
    # `if dummy:` test a memory node instead of the caller's flag.
    proto_entries = list(obs[:timestep])  # copy: never extend caller's list
    if dummy:
        proto_entries += [
            scg.batch_repeat(proto, state) for proto in self.dummy_proto
        ]
    proto_mem = Memory.build(proto_entries)

    match_entries = [self.match(input=obs[t]) for t in xrange(timestep)]
    if dummy:
        match_entries += [
            scg.batch_repeat(key, state) for key in self.dummy_match
        ]
    match_mem = Memory.build(match_entries)

    if num_steps == 0:
        # No attention: fall back to the mean of the raw memory entries.
        def avg(input=None):
            return tf.reduce_mean(input, 1)

        r = scg.apply(avg, input=proto_mem)
        state = self.cell(input=scg.concat([r, state]), state=state)
        return r, state

    r = None
    for step in xrange(num_steps):
        q = query(state)
        a = scg.Attention()(mem=match_mem, key=q, strength=strength(state))
        r = scg.AttentiveReader()(attention=a, mem=proto_mem)
        state = self.cell(input=scg.concat([r, state]), state=state)

    return r, state
def gen_query(state):
    """Apply ``self._gen_query`` to ``state`` concatenated with ``z_prior``.

    ``self`` and ``z_prior`` are taken from the enclosing scope.
    """
    joined = scg.concat([state, z_prior])
    return self._gen_query(input=joined)
def rec_query(state):
    """Apply ``self._rec_query`` to ``state`` joined with the current features.

    ``self`` and ``timestep`` are taken from the enclosing scope; the
    features are looked up as ``self.features[timestep]`` at call time.
    """
    joined = scg.concat([state, self.features[timestep]])
    return self._rec_query(input=joined)
def __init__(self, input_data, hidden_dim, gen, rec):
    """Create all sub-modules and unroll the model over the episode.

    Reads the module-level ``args`` (``hops``, ``prior_hops``, ``no_dummy``)
    and ``episode_length``.

    :param input_data: indexed as ``input_data[:, t, :]`` to obtain the
        observation for step ``t``.
    :param hidden_dim: latent size passed to ``rec`` and ``gen`` and used
        for the generation query projection.
    :param gen: factory called as ``gen(hidden_dim, param_dim)``.
    :param rec: factory called as ``rec(hidden_dim, param_dim)``; the result
        must expose ``features_dim``, ``get_features`` and ``recognize``.
    """
    state_dim = 200

    def affine(in_dim, out_dim, **extra):
        # All projections here share the same initialiser recipe; a fresh
        # initialiser is still created for every layer.
        return scg.Affine(in_dim, out_dim,
                          init=scg.norm_init(scg.he_normal), **extra)

    self.num_steps = args.hops
    self.prior_steps = args.prior_hops
    self.matching_dim = 200

    with tf.variable_scope('recognition'):
        # NOTE(review): 288 is presumably the feature size of `rec`
        # (it is needed before `features_dim` can be read) - confirm.
        self.rec = rec(hidden_dim, state_dim + 288)
        self.features_dim = self.rec.features_dim
        self._rec_query = affine(state_dim + self.features_dim,
                                 self.matching_dim, fun=None)
        self._rec_strength = affine(state_dim, 1)

    # NOTE(review): the prior modules are assumed to live in the
    # 'generation' scope - confirm against the original layout.
    with tf.variable_scope('generation'):
        self.gen = gen(hidden_dim, state_dim + self.features_dim)
        self._gen_query = affine(state_dim + hidden_dim,
                                 self.matching_dim, fun=None)
        self._gen_strength = affine(state_dim, 1)
        self._prior_query = affine(state_dim, self.matching_dim, fun=None)
        self._prior_strength = affine(state_dim, 1)
        self.prior_repr = SetRepresentation(self.features_dim,
                                            self.matching_dim, state_dim)

    with tf.variable_scope('both'):
        self.set_repr = SetRepresentation(self.features_dim,
                                          self.matching_dim, state_dim)

    self.z = [None] * episode_length
    self.x = [None] * (episode_length + 1)

    # allocating observations
    self.obs = [
        scg.Constant(value=input_data[:, t, :],
                     shape=[28 * 28])(name=VAE.observed_name(t))
        for t in xrange(episode_length)
    ]

    # pre-computing features
    self.features = [self.rec.get_features(node) for node in self.obs]

    rec_strength = lambda state: self._rec_strength(input=state)

    # One extra iteration: x has episode_length + 1 slots, z only
    # episode_length, so the last step generates without recognising.
    for timestep in xrange(episode_length + 1):
        # Dummy memory entries are always used at step 0; afterwards only
        # when args.no_dummy is not set.
        use_dummy = not (args.no_dummy and timestep > 0)

        if timestep < episode_length:
            current = self.features[timestep]
            rec_query = lambda state, current=current: self._rec_query(
                input=scg.concat([state, current]))

            rec_read, rec_final = self.set_repr.recognize(
                self.features, timestep, rec_query, self.num_steps,
                strength=rec_strength, dummy=use_dummy)

            self.z[timestep] = self.rec.recognize(
                current,
                scg.concat([rec_read, rec_final]),
                VAE.hidden_name(timestep))

        self.x[timestep] = self.generate(timestep, dummy=use_dummy)