def tensor_rnn_with_feed_prev(cell,
                              inputs,
                              is_training,
                              config,
                              initial_states=None):
    """High Order Recurrent Neural Network Layer
    """
    #tuple of 2-d tensor (batch_size, s)
    outputs = []
    prev = None
    feed_prev = config.use_error_prop and not is_training
    is_sample = is_training and initial_states is not None

    if is_sample:
        print("Creating model @ training  --> Using scheduled sampling.")
    else:
        print("Creating model @ training  --> Not using scheduled sampling.")

    if feed_prev:
        print(' ' * 30 + " --> Feeding output back into input.")
    else:
        print(' ' * 30 + " --> Feeding ground truth into input.")

    with tf.variable_scope("trnn") as varscope:
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)

        inputs_shape = inputs.get_shape().with_rank_at_least(3)
        batch_size = tf.shape(inputs)[0]
        num_steps = inputs_shape[1]
        input_size = int(inputs_shape[2])
        output_size = cell.output_size
        burn_in_steps = config.burn_in_steps

        # Scheduled sampling
        dist = Bernoulli(probs=config.sample_prob)
        samples = dist.sample(sample_shape=num_steps)

        if initial_states is None:
            initial_states = []
            for lag in range(config.num_lags):
                initial_state = cell.zero_state(batch_size, dtype=tf.float32)
                initial_states.append(initial_state)

        states_list = initial_states  #list of high order states

        for time_step in range(num_steps):
            if time_step > 0:
                tf.get_variable_scope().reuse_variables()

            inp = inputs[:, time_step, :]

            if is_sample and time_step > 0:
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    # Scheduled sampling: keep the ground-truth input with probability
                    # sample_prob, otherwise feed back the previous prediction.
                    inp = tf.cond(tf.cast(samples[time_step], tf.bool),
                                  lambda: tf.identity(inp),
                                  lambda: fully_connected(cell_output, input_size,
                                                          activation_fn=tf.sigmoid))

            if feed_prev and prev is not None and time_step >= burn_in_steps:
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    inp = fully_connected(cell_output,
                                          input_size,
                                          activation_fn=tf.sigmoid)
                    #print("t", time_step, ">=", burn_in_steps, "--> feeding back output into input.")

            states = _list_to_states(states_list)
            """input tensor is [batch_size, num_steps, input_size]"""
            (cell_output, state) = cell(inp, states)

            # Dropout on the cell output; applied only while training.
            if is_training:
                cell_output = tf.nn.dropout(cell_output, keep_prob=0.5)

            states_list = _shift(states_list, state)

            prev = cell_output
            with tf.variable_scope(tf.get_variable_scope(), reuse=False):
                output = fully_connected(cell_output,
                                         input_size,
                                         activation_fn=tf.sigmoid)
                outputs.append(output)

    outputs = tf.stack(outputs, 1)
    return outputs, states_list
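

# --- Hedged illustration (not part of the original source): the feed_prev /
#     burn_in_steps schedule used above, in plain Python. During evaluation with
#     use_error_prop=True, the first burn_in_steps inputs are ground truth and the
#     remaining steps are fed the model's own predictions (step 0 always uses
#     ground truth in the loop above, since prev is None there). ---
def _feed_schedule_sketch(num_steps, burn_in_steps, is_training, use_error_prop):
    feed_prev = use_error_prop and not is_training
    return ["ground_truth" if (not feed_prev or t < burn_in_steps) else "prediction"
            for t in range(num_steps)]

# _feed_schedule_sketch(6, 3, is_training=False, use_error_prop=True)
# -> ['ground_truth', 'ground_truth', 'ground_truth', 'prediction', 'prediction', 'prediction']

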
def rnn_with_feed_prev(cell, inputs, is_training, config, initial_state=None):
    prev = None
    outputs = []
    sample_prob = config.sample_prob  # scheduled sampling probability

    feed_prev = config.use_error_prop and not is_training
    is_sample = is_training and initial_state is not None  # decoder

    if is_sample:
        print("Creating model @ training  --> Using scheduled sampling.")
    else:
        print("Creating model @ training  --> Not using scheduled sampling.")

    if feed_prev:
        print(' ' * 30 + " --> Feeding output back into input.")
    else:
        print(' ' * 30 + " --> Feeding ground truth into input.")

    with tf.variable_scope("rnn") as varscope:
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)

        inputs_shape = inputs.get_shape().with_rank_at_least(3)
        batch_size = tf.shape(inputs)[0]
        num_steps = inputs_shape[1]
        input_size = int(inputs_shape[2])
        burn_in_steps = config.burn_in_steps
        output_size = cell.output_size

        # phased lstm input
        inp_t = tf.expand_dims(tf.range(1, batch_size + 1), 1)

        dist = Bernoulli(probs=config.sample_prob)
        samples = dist.sample(sample_shape=num_steps)
        # with tf.Session() as sess:
        #     print('bernoulli',samples.eval())
        if initial_state is None:
            initial_state = cell.zero_state(batch_size, dtype=tf.float32)
        state = initial_state

        for time_step in range(num_steps):
            if time_step > 0:
                tf.get_variable_scope().reuse_variables()

            inp = inputs[:, time_step, :]

            if is_sample and time_step > 0:
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    # Scheduled sampling: keep the ground-truth input with probability
                    # sample_prob, otherwise feed back the previous prediction.
                    inp = tf.cond(tf.cast(samples[time_step], tf.bool),
                                  lambda: tf.identity(inp),
                                  lambda: fully_connected(cell_output, input_size,
                                                          activation_fn=tf.sigmoid))

            if feed_prev and prev is not None and time_step >= burn_in_steps:
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    inp = fully_connected(prev,
                                          input_size,
                                          activation_fn=tf.sigmoid)
                    #print("t", time_step, ">=", burn_in_steps, "--> feeding back output into input.")

            if isinstance(cell._cells[0], tf.contrib.rnn.PhasedLSTMCell):
                (cell_output, state) = cell((inp_t, inp), state)
            else:
                (cell_output, state) = cell(inp, state)

            prev = cell_output
            with tf.variable_scope(tf.get_variable_scope(), reuse=False):
                output = fully_connected(cell_output,
                                         input_size,
                                         activation_fn=tf.sigmoid)
                outputs.append(output)

    outputs = tf.stack(outputs, 1)
    return outputs, state
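

# --- Hedged illustration (not part of the original source): the scheduled-sampling
#     gate used in both layers above, isolated. A Bernoulli(sample_prob) draw per time
#     step decides whether the ground-truth input or the model's previous prediction is
#     fed to the cell. Assumes the same TF 1.x environment and the same `Bernoulli`
#     distribution class imported by the code above. ---
def _scheduled_sampling_gate_sketch(ground_truth, prediction, sample_prob, num_steps, t):
    # One 0/1 draw per time step; 1 means "keep the ground truth".
    samples = Bernoulli(probs=sample_prob).sample(sample_shape=num_steps)
    return tf.cond(tf.cast(samples[t], tf.bool),
                   lambda: tf.identity(ground_truth),
                   lambda: prediction)

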
class DDPG(RLAlgorithm):
    """
    Deep Deterministic Policy Gradient.
    """

    def __init__(
            self,
            env,
            policy,
            qf,
            es,
            batch_size=32,
            n_epochs=200,
            epoch_length=1000,
            min_pool_size=10000,
            replay_pool_size=1000000,
            replacement_prob=1.0,
            discount=0.99,
            max_path_length=250,
            qf_weight_decay=0.,
            qf_update_method='adam',
            qf_learning_rate=1e-3,
            policy_weight_decay=0,
            policy_update_method='adam',
            policy_learning_rate=1e-3,
            policy_updates_ratio=1.0,
            eval_samples=10000,
            soft_target=True,
            soft_target_tau=0.001,
            n_updates_per_sample=1,
            scale_reward=1.0,
            include_horizon_terminal_transitions=False,
            plot=False,
            pause_for_plot=False,
            **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :param qf: Q function
        :param es: Exploration strategy
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after each epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of a single rollout (path) before it is terminated.
        :param qf_weight_decay: Weight decay factor for parameters of the Q function.
        :param qf_update_method: Online optimization method for training Q function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of the policy.
        :param policy_update_method: Online optimization method for training the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating the policy.
        :param soft_target_tau: Interpolation parameter for doing the soft target update.
        :param n_updates_per_sample: Number of Q function and policy updates per new sample obtained
        :param scale_reward: The scaling factor applied to the rewards when training
        :param include_horizon_terminal_transitions: Whether to include transitions with terminal=True because the
        horizon was reached. This might make the Q-value backup less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each eval_interval.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :return:
        """
        self.env = env
        self.on_policy_env = Serializable.clone(env)
        self.policy = policy
        self.qf = qf
        self.es = es
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.replacement_prob = replacement_prob
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay
        # TODO: replace optimizers with conjugate gradient optimizers with leq constraint as in TRPO
        self.qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay
        self.policy_update_method = \
            FirstOrderOptimizer(
                update_method=policy_update_method,
                learning_rate=policy_learning_rate,
            )
        self.policy_learning_rate = policy_learning_rate
        self.policy_updates_ratio = policy_updates_ratio
        self.eval_samples = eval_samples
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []
        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0
        self.random_dist = Bernoulli(None, [.5])
        self.use_gated_sigma = kwargs.get('use_gated_sigma', True)

        self.scale_reward = scale_reward

        self.train_policy_itr = 0

        self.opt_info = None

    def start_worker(self):
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

    @overrides
    def train(self):
        gc_dump_time = time.time()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # This seems like a rather sequential method
            pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_dim=self.env.observation_space.flat_dim,
                action_dim=self.env.action_space.flat_dim,
                replacement_prob=self.replacement_prob,
            )


            on_policy_pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_dim=self.on_policy_env.observation_space.flat_dim,
                action_dim=self.on_policy_env.action_space.flat_dim,
            )


            self.start_worker()

            self.init_opt()
            # This initializes the optimizer parameters
            sess.run(tf.global_variables_initializer())
            itr = 0
            path_length = 0
            path_return = 0
            terminal = False
            initial = False
            observation = self.env.reset()
            on_policy_terminal = False
            on_policy_initial = False
            on_policy_path_length = 0
            on_policy_path_return = 0
            on_policy_observation = self.on_policy_env.reset()


            #with tf.variable_scope("sample_policy"):
                #with suppress_params_loading():
                #sample_policy = pickle.loads(pickle.dumps(self.policy))
            with tf.variable_scope("sample_policy"):
                sample_policy = Serializable.clone(self.policy)

            for epoch in range(self.n_epochs):
                logger.push_prefix('epoch #%d | ' % epoch)
                logger.log("Training started")
                train_qf_itr, train_policy_itr = 0, 0
                for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                    # Execute policy
                    if terminal:  # or path_length > self.max_path_length:
                        # Note that if the last time step ends an episode, the very
                        # last state and observation will be ignored and not added
                        # to the replay pool
                        observation = self.env.reset()
                        self.es.reset()
                        sample_policy.reset()
                        self.es_path_returns.append(path_return)
                        path_length = 0
                        path_return = 0
                        initial = True
                    else:
                        initial = False

                    if on_policy_terminal:  # or path_length > self.max_path_length:
                        # Note that if the last time step ends an episode, the very
                        # last state and observation will be ignored and not added
                        # to the replay pool
                        on_policy_observation = self.on_policy_env.reset()
                        sample_policy.reset()
                        on_policy_path_length = 0
                        on_policy_path_return = 0
                        on_policy_initial = True
                    else:
                        on_policy_initial = False

                    action = self.es.get_action(itr, observation, policy=sample_policy)  # qf=qf)
                    on_policy_action = self.get_action_on_policy(self.on_policy_env, on_policy_observation, policy=sample_policy)

                    next_observation, reward, terminal, _ = self.env.step(action)
                    on_policy_next_observation, on_policy_reward, on_policy_terminal, _ = self.on_policy_env.step(on_policy_action)

                    path_length += 1
                    path_return += reward
                    on_policy_path_length += 1
                    on_policy_path_return += on_policy_reward

                    if not terminal and path_length >= self.max_path_length:
                        terminal = True
                        # only include the terminal transition in this case if the flag was set
                        if self.include_horizon_terminal_transitions:
                            pool.add_sample(observation, action, reward * self.scale_reward, terminal, initial)
                    else:
                        pool.add_sample(observation, action, reward * self.scale_reward, terminal, initial)

                    if not on_policy_terminal and on_policy_path_length >= self.max_path_length:
                        on_policy_terminal = True
                        # only include the terminal transition in this case if the flag was set
                        if self.include_horizon_terminal_transitions:
                            on_policy_pool.add_sample(on_policy_observation, on_policy_action, on_policy_reward * self.scale_reward, on_policy_terminal, on_policy_initial)
                    else:
                        on_policy_pool.add_sample(on_policy_observation, on_policy_action, on_policy_reward * self.scale_reward,on_policy_terminal, on_policy_initial)


                    on_policy_observation = on_policy_next_observation
                    observation = next_observation

                    if pool.size >= self.min_pool_size:
                        for update_itr in range(self.n_updates_per_sample):
                            # Train policy
                            batch = pool.random_batch(self.batch_size)
                            on_policy_batch = on_policy_pool.random_batch(self.batch_size)
                            itrs = self.do_training(itr, on_policy_batch, batch)
                            train_qf_itr += itrs[0]
                            train_policy_itr += itrs[1]
                        sample_policy.set_param_values(self.policy.get_param_values())

                    itr += 1
                    if time.time() - gc_dump_time > 100:
                        gc.collect()
                        gc_dump_time = time.time()

                logger.log("Training finished")
                logger.log("Trained qf %d steps, policy %d steps"%(train_qf_itr, train_policy_itr))
                if pool.size >= self.min_pool_size:
                    self.evaluate(epoch)
                    params = self.get_epoch_snapshot(epoch)
                    logger.save_itr_params(epoch, params)

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                                  "continue...")
            self.env.terminate()
            self.policy.terminate()

    def get_action_on_policy(self, env_spec, observation, policy, **kwargs):
        action, _ = policy.get_action(observation)
        action_space = env_spec.action_space
        return np.clip(action, action_space.low, action_space.high)

    def init_opt(self):

        # First, create "target" policy and Q functions
        with tf.variable_scope("target_policy"):
            target_policy = Serializable.clone(self.policy)
        with tf.variable_scope("target_qf"):
            target_qf = Serializable.clone(self.qf)

        # y need to be computed first
        obs = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1,
        )

        # The yi values are computed separately as above and then passed to
        # the training functions below
        action = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )

        yvar = tensor_utils.new_tensor(
            'ys',
            ndim=1,
            dtype=tf.float32,
        )

        obs_offpolicy = self.env.observation_space.new_tensor_variable(
            'obs_offpolicy',
            extra_dims=1,
        )


        action_offpolicy = self.env.action_space.new_tensor_variable(
            'action_offpolicy',
            extra_dims=1,
        )


        yvar_offpolicy = tensor_utils.new_tensor(
            'ys_offpolicy',
            ndim=1,
            dtype=tf.float32,
        )

        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
                               sum([tf.reduce_sum(tf.square(param)) for param in
                                    self.qf.get_params(regularizable=True)])

        qval = self.qf.get_qval_sym(obs, action)
        qval_off = self.qf.get_qval_sym(obs_offpolicy, action_offpolicy)

        qf_loss = tf.reduce_mean(tf.square(yvar - qval))
        qf_loss_off = tf.reduce_mean(tf.square(yvar_offpolicy - qval_off))

        # TODO: penalize dramatic changes in gating_func
        # if PENALIZE_GATING_DISTRIBUTION_DIVERGENCE:


        policy_weight_decay_term = 0.5 * self.policy_weight_decay * \
                                   sum([tf.reduce_sum(tf.square(param))
                                        for param in self.policy.get_params(regularizable=True)])
        policy_qval = self.qf.get_qval_sym(
            obs, self.policy.get_action_sym(obs),
            deterministic=True
        )

        policy_qval_off = self.qf.get_qval_sym(
            obs_offpolicy, self.policy.get_action_sym(obs_offpolicy),
            deterministic=True
        )

        policy_surr = -tf.reduce_mean(policy_qval)
        policy_surr_off = -tf.reduce_mean(policy_qval_off)


        if self.use_gated_sigma:
            print("Using Gated Sigma!")

            input_to_gates = tf.concat([obs, obs_offpolicy], axis=1)

            assert input_to_gates.get_shape().as_list()[-1] == \
                obs.get_shape().as_list()[-1] + obs_offpolicy.get_shape().as_list()[-1]

            # TODO: right now this is a soft-gate, should make a hard-gate (options vs mixtures)
            gating_func = MLP(name="sigma_gate",
                              output_dim=1,
                              hidden_sizes=(64,64),
                              hidden_nonlinearity=tf.nn.relu,
                              output_nonlinearity=tf.nn.sigmoid,
                              input_var=input_to_gates,
                              input_shape=tuple(input_to_gates.get_shape().as_list()[1:])).output
        else:
            # sample a bernoulli random variable
            print("Using Bernoulli sigma!")
            gating_func = tf.cast(self.random_dist.sample(qf_loss.get_shape()), tf.float32)

        qf_inputs_list = [yvar, obs, action, yvar_offpolicy, obs_offpolicy, action_offpolicy]
        qf_reg_loss = qf_loss*gating_func + qf_loss_off * (1-gating_func) + qf_weight_decay_term

        policy_input_list = [obs, obs_offpolicy]
        policy_reg_surr = policy_surr*gating_func + policy_surr_off*(1-gating_func) + policy_weight_decay_term


        self.qf_update_method.update_opt(qf_reg_loss, target=self.qf, inputs=qf_inputs_list)

        self.policy_update_method.update_opt(policy_reg_surr, target=self.policy, inputs=policy_input_list)


        f_train_qf = tensor_utils.compile_function(
            inputs=qf_inputs_list,
            outputs=[qf_loss, qval, self.qf_update_method._train_op],
        )

        f_train_policy = tensor_utils.compile_function(
            inputs=policy_input_list,
            outputs=[policy_surr, self.policy_update_method._train_op],
        )

        self.opt_info = dict(
            f_train_qf=f_train_qf,
            f_train_policy=f_train_policy,
            target_qf=target_qf,
            target_policy=target_policy,
        )
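
    # --- Hedged illustration, not part of the original class: the gating assembled in
    #     init_opt is a convex mixture of the on-policy and off-policy losses,
    #     L = sigma * L_on + (1 - sigma) * L_off + weight_decay_term,
    #     where sigma comes either from the "sigma_gate" MLP or from a Bernoulli(0.5) draw. ---
    @staticmethod
    def _gated_loss_sketch(loss_on, loss_off, sigma, weight_decay_term=0.0):
        # e.g. _gated_loss_sketch(0.8, 1.4, 0.25) == 0.25 * 0.8 + 0.75 * 1.4 == 1.25
        return sigma * loss_on + (1.0 - sigma) * loss_off + weight_decay_term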

    def do_training(self, itr, batch, offpolicy_batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch,
            "observations", "actions", "rewards", "next_observations",
            "terminals"
        )

        obs_off, actions_off, rewards_off, next_obs_off, terminals_off = ext.extract(
            offpolicy_batch,
            "observations", "actions", "rewards", "next_observations",
            "terminals"
        )

        # compute the on-policy y values
        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)

        ys = rewards + (1. - terminals) * self.discount * next_qvals.reshape(-1)

        next_actions_off, _ = target_policy.get_actions(next_obs_off)
        next_qvals_off = target_qf.get_qval(next_obs_off, next_actions_off)

        ys_off = rewards_off + (1. - terminals_off) * self.discount * next_qvals_off.reshape(-1)

        f_train_qf = self.opt_info["f_train_qf"]
        f_train_policy = self.opt_info["f_train_policy"]

        qf_loss, qval, _ = f_train_qf(ys, obs, actions, ys_off, obs_off, actions_off)

        target_qf.set_param_values(
            target_qf.get_param_values() * (1.0 - self.soft_target_tau) +
            self.qf.get_param_values() * self.soft_target_tau)
        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys) #TODO: also add ys_off

        self.train_policy_itr += self.policy_updates_ratio
        train_policy_itr = 0
        while self.train_policy_itr > 0:
            f_train_policy = self.opt_info["f_train_policy"]
            policy_surr, _ = f_train_policy(obs, obs_off)
            target_policy.set_param_values(
                target_policy.get_param_values() * (1.0 - self.soft_target_tau) +
                self.policy.get_param_values() * self.soft_target_tau)
            self.policy_surr_averages.append(policy_surr)
            self.train_policy_itr -= 1
            train_policy_itr += 1

        return 1, train_policy_itr # number of itrs qf, policy are trained
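
    # --- Hedged illustration, not part of the original class: the two numeric updates in
    #     do_training, written out in NumPy (np is already used in this class). The targets
    #     bootstrap only from non-terminal transitions,
    #     ys = r + (1 - terminal) * gamma * Q_target(s', pi_target(s')),
    #     and the target networks track the online networks with a Polyak average,
    #     theta_target <- (1 - tau) * theta_target + tau * theta. ---
    @staticmethod
    def _bellman_targets_sketch(rewards, terminals, next_qvals, discount):
        # e.g. rewards=[1.0, 0.5], terminals=[0, 1], next_qvals=[2.0, 3.0], discount=0.99
        #      -> [2.98, 0.5]; the terminal transition drops its bootstrap term.
        return rewards + (1.0 - terminals) * discount * next_qvals.reshape(-1)

    @staticmethod
    def _soft_target_update_sketch(target_params, online_params, tau):
        return target_params * (1.0 - tau) + online_params * tau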

    def evaluate(self, epoch):
        paths = parallel_sampler.sample_paths(
            policy_params=self.policy.get_param_values(),
            max_samples=self.eval_samples,
            max_path_length=self.max_path_length,
        )

        average_discounted_return = np.mean(
            [special.discount_return(path["rewards"], self.discount) for path in paths]
        )

        returns = [sum(path["rewards"]) for path in paths]

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_q_loss = np.mean(self.qf_loss_averages)
        average_policy_surr = np.mean(self.policy_surr_averages)
        average_action = np.mean(np.square(np.concatenate(
            [path["actions"] for path in paths]
        )))

        policy_reg_param_norm = np.linalg.norm(
            self.policy.get_param_values(regularizable=True)
        )
        qfun_reg_param_norm = np.linalg.norm(
            self.qf.get_param_values(regularizable=True)
        )

        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Iteration', epoch)
        logger.record_tabular('AverageReturn', np.mean(returns))
        logger.record_tabular('StdReturn',
                              np.std(returns))
        logger.record_tabular('MaxReturn',
                              np.max(returns))
        logger.record_tabular('MinReturn',
                              np.min(returns))
        if len(self.es_path_returns) > 0:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.es_path_returns))
            logger.record_tabular('StdEsReturn',
                                  np.std(self.es_path_returns))
            logger.record_tabular('MaxEsReturn',
                                  np.max(self.es_path_returns))
            logger.record_tabular('MinEsReturn',
                                  np.min(self.es_path_returns))
        logger.record_tabular('AverageDiscountedReturn',
                              average_discounted_return)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AveragePolicySurr', average_policy_surr)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        logger.record_tabular('AverageAction', average_action)

        logger.record_tabular('PolicyRegParamNorm',
                              policy_reg_param_norm)
        logger.record_tabular('QFunRegParamNorm',
                              qfun_reg_param_norm)

        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)

        self.qf_loss_averages = []
        self.policy_surr_averages = []

        self.q_averages = []
        self.y_averages = []
        self.es_path_returns = []

    def update_plot(self):
        if self.plot:
            plotter.update_plot(self.policy, self.max_path_length)

    def get_epoch_snapshot(self, epoch):
        return dict(
            env=self.env,
            epoch=epoch,
            qf=self.qf,
            policy=self.policy,
            target_qf=self.opt_info["target_qf"],
            target_policy=self.opt_info["target_policy"],
            es=self.es,
        )
    def compute_gradients(self, loss, var_list=None,
                        gate_gradients=Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
        """Compute gradients of `loss` for the variables in `var_list`.
        This is the first part of `minimize()`.  It returns a list
        of (gradient, variable) pairs where "gradient" is the gradient
        for "variable".  Note that "gradient" can be a `Tensor`, an
        `IndexedSlices`, or `None` if there is no gradient for the
        given variable.
        Args:
        loss: A Tensor containing the value to minimize or a callable taking
            no arguments which returns the value to minimize. When eager execution
            is enabled it must be a callable.
        var_list: Optional list or tuple of `tf.Variable` to update to minimize
            `loss`.  Defaults to the list of variables collected in the graph
            under the key `GraphKeys.TRAINABLE_VARIABLES`.
        gate_gradients: How to gate the computation of gradients.  Can be
            `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
        aggregation_method: Specifies the method used to combine gradient terms.
            Valid values are defined in the class `AggregationMethod`.
        colocate_gradients_with_ops: If True, try colocating gradients with
            the corresponding op.
        grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
        Returns:
        A list of (gradient, variable) pairs. Variable is always present, but
        gradient can be `None`.
        Raises:
        TypeError: If `var_list` contains anything else than `Variable` objects.
        ValueError: If some arguments are invalid.
        RuntimeError: If called with eager execution enabled and `loss` is
            not callable.
        @compatibility(eager)
        When eager execution is enabled, `gate_gradients`, `aggregation_method`,
        and `colocate_gradients_with_ops` are ignored.
        @end_compatibility
        """
        if callable(loss):
            with backprop.GradientTape() as tape:
                if var_list is not None:
                    tape.watch(var_list)
                loss_value = loss()

            if var_list is None:
                var_list = tape.watched_variables()
            # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
            # to be executed.
            with ops.control_dependencies([loss_value]):
                grads = tape.gradient(loss_value, var_list, grad_loss)
            return list(zip(grads, var_list))

        # # Non-callable/Tensor loss case
        # if context.executing_eagerly():
        #     raise RuntimeError(
        #         "`loss` passed to Optimizer.compute_gradients should "
        #         "be a function when eager execution is enabled.")

        if gate_gradients not in [SPSA.GATE_NONE, SPSA.GATE_OP, SPSA.GATE_GRAPH]:
            raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                            "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                            gate_gradients)
        self._assert_valid_dtypes([loss])
        if grad_loss is not None:
            self._assert_valid_dtypes([grad_loss])
        if var_list is None:
            var_list = (
                variables.trainable_variables() +
                ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
        else:
            var_list = nest.flatten(var_list)
        # pylint: disable=protected-access
        var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
        # pylint: enable=protected-access
        processors = [_get_processor(v) for v in var_list]
        if not var_list:
            raise ValueError("No variables to optimize.")
        var_refs = [p.target() for p in processors]

        # print("var_refs:")
        # for vr in var_refs:
        #     print(vr)

        # ==================================================================================
        # grads = gradients.gradients(
        #     loss, var_refs, grad_ys=grad_loss,
        #     gate_gradients=(gate_gradients == SPSA.GATE_OP),
        #     aggregation_method=aggregation_method,
        #     colocate_gradients_with_ops=colocate_gradients_with_ops)
        
        # grads = [ tf.zeros(tf.shape(vrefs)) for vrefs in var_refs ]
        orig_graph_view = None
        trainable_vars = var_list
        # self.tvars = var_list
        self.tvars = [var.name.split(':')[0] for var in var_list]  # list of names of trainable variables
        self.global_step_tensor = tf.Variable(0, name='global_step', trainable=False)

        # Perturbations
        deltas = {}
        n_perturbations = {}
        p_perturbations = {}
        with tf.name_scope("Perturbator"):
            self.c_t = tf.div( self.c,  tf.pow(tf.add(tf.cast(self.global_step_tensor, tf.float32),
                                              tf.constant(1, dtype=tf.float32)), self.gamma), name = "SPSA_ct" )
            for var in trainable_vars:
                self.num_params += self._mul_dims(var.get_shape())
                var_name = var.name.split(':')[0]
                random = Bernoulli(tf.fill(var.get_shape(), 0.5), dtype=tf.float32)
                deltas[var] = tf.subtract( tf.constant(1, dtype=tf.float32),
                                    tf.scalar_mul(tf.constant(2, dtype=tf.float32),random.sample(1)[0]), name = "SPSA_delta" )
                c_t_delta = tf.scalar_mul( tf.reshape(self.c_t, []), deltas[var] )
                n_perturbations[var_name+'/read:0'] = tf.subtract( var, c_t_delta, name = "perturb_n" )
                p_perturbations[var_name+'/read:0'] = tf.add(var, c_t_delta, name = "perturb_p" )
        # print("{} parameters".format(self.num_params))

        # Evaluator
        with tf.name_scope("Evaluator"):
            orig_graph_view = ge.sgv(tf.get_default_graph())
            _, self.ninfo = self._clone_model(orig_graph_view, n_perturbations, 'N_Eval')
            _, self.pinfo = self._clone_model(orig_graph_view, p_perturbations, 'P_Eval')

        # Weight Updater
        optimizer_ops = []
        grads = []
        with tf.control_dependencies([loss]):
            with tf.name_scope('Updater'):
                a_t = self.a / (tf.pow(tf.add(tf.cast(self.global_step_tensor, tf.float32),
                                             tf.constant(1, dtype=tf.float32)), self.alpha))
                for var in trainable_vars:
                    l_pos = self.pinfo.transformed( loss )
                    l_neg = self.ninfo.transformed( loss )
                    ghat = (l_pos - l_neg) / (tf.constant(2, dtype=tf.float32) * self.c_t * deltas[var])
                    optimizer_ops.append(tf.assign_sub(var, a_t*ghat))
                    grads.append(ghat)
        print(tf.get_default_graph())

        print("grads")
        for g in grads:
            print(g)


        #===================================================================================
        if gate_gradients == SPSA.GATE_GRAPH:
            print("===================")
            grads = control_flow_ops.tuple(grads)
        grads_and_vars = list(zip(grads, var_list))
        self._assert_valid_dtypes(
            [v for g, v in grads_and_vars
            if g is not None and v.dtype != dtypes.resource])
        return grads_and_vars
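
    # --- Hedged illustration, not part of the original class: the SPSA estimate that
    #     compute_gradients builds symbolically, in NumPy. Every coordinate reuses the same
    #     two loss evaluations:
    #     ghat_i = (L(w + c*delta) - L(w - c*delta)) / (2 * c * delta_i),
    #     with delta_i drawn from {-1, +1} as 1 - 2 * Bernoulli(0.5). ---
    @staticmethod
    def _spsa_gradient_sketch(loss_fn, w, c):
        import numpy as np  # assumed available; kept local to this sketch
        delta = 1.0 - 2.0 * np.random.binomial(1, 0.5, size=w.shape)
        l_pos = loss_fn(w + c * delta)
        l_neg = loss_fn(w - c * delta)
        return (l_pos - l_neg) / (2.0 * c * delta)
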

# Example 5

    def minimize(self, loss, var_list=None, global_step=None):
        orig_graph_view = None
        trainable_vars = var_list if var_list is not None else tf.trainable_variables()
        if self.inputs is not None:
            seed_ops = [t.op for t in self.inputs]
            result = list(seed_ops)
            wave = set(seed_ops)
            while wave:  # stolen from graph_editor.select
                new_wave = set()
                for op in wave:
                    for new_t in op.outputs:
                        if new_t == loss:
                            continue
                        for new_op in new_t.consumers():
                            #if new_op not in result and is_within(new_op):
                            if new_op not in result:
                                new_wave.add(new_op)
                for op in new_wave:
                    if op not in result:
                        result.append(op)
                wave = new_wave
            orig_graph_view = ge.sgv(result)
        else:
            orig_graph_view = ge.sgv(self.work_graph)

        self.global_step_tensor = tf.Variable(
            0, name='global_step',
            trainable=False) if global_step is None else global_step

        # Perturbations
        deltas = {}
        n_perturbations = {}
        p_perturbations = {}
        with tf.name_scope("Perturbator"):
            self.c_t = tf.div(
                self.c,
                tf.pow(
                    tf.add(tf.cast(self.global_step_tensor, tf.float32),
                           tf.constant(1, dtype=tf.float32)), self.gamma),
                name="SPSA_ct")
            # self.c_t = 0.00 #MOD
            for var in trainable_vars:
                self.num_params += self._mul_dims(var.get_shape())
                var_name = var.name.split(':')[0]
                random = Bernoulli(tf.fill(var.get_shape(), 0.5),
                                   dtype=tf.float32)
                deltas[var] = tf.subtract(tf.constant(1, dtype=tf.float32),
                                          tf.scalar_mul(
                                              tf.constant(2, dtype=tf.float32),
                                              random.sample(1)[0]),
                                          name="SPSA_delta")
                c_t_delta = tf.scalar_mul(tf.reshape(self.c_t, []),
                                          deltas[var])
                n_perturbations[var_name + '/read:0'] = tf.subtract(
                    var, c_t_delta, name="perturb_n")
                p_perturbations[var_name + '/read:0'] = tf.add(
                    var, c_t_delta, name="perturb_p")
        print("{} parameters".format(self.num_params))

        # Evaluator
        with tf.name_scope("Evaluator"):
            _, self.ninfo = self._clone_model(orig_graph_view, n_perturbations,
                                              'N_Eval')
            _, self.pinfo = self._clone_model(orig_graph_view, p_perturbations,
                                              'P_Eval')

        # Weight Updater
        optimizer_ops = []
        with tf.control_dependencies([loss]):
            with tf.name_scope('Updater'):
                a_t = self.a / (tf.pow(
                    tf.add(tf.cast(self.global_step_tensor, tf.float32),
                           tf.constant(1, dtype=tf.float32)), self.alpha))
                # a_t = 0.00 #MOD
                for var in trainable_vars:
                    l_pos = self.pinfo.transformed(loss)
                    l_neg = self.ninfo.transformed(loss)
                    # print( "l_pos: ", l_pos)
                    # print( "l_neg: ", l_neg)
                    ghat = (l_pos - l_neg) / (tf.constant(2, dtype=tf.float32)
                                              * self.c_t * deltas[var])
                    # tf.Print is an identity op; keep its output so the debug values are printed.
                    ghat = tf.Print(ghat, [l_pos, l_neg], message="L+-:")
                    optimizer_ops.append(tf.assign_sub(var, a_t * ghat))
        grp = control_flow_ops.group(*optimizer_ops)
        with tf.control_dependencies([grp]):
            tf.assign_add(self.global_step_tensor,
                          tf.constant(1, dtype=self.global_step_tensor.dtype))

        return grp
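
    # --- Hedged illustration, not part of the original class: the gain sequences built in
    #     minimize() decay polynomially with the global step k,
    #     a_k = a / (k + 1)**alpha (step size), c_k = c / (k + 1)**gamma (perturbation size).
    #     Classic SPSA suggestions are alpha ~ 0.602 and gamma ~ 0.101; the actual a, c,
    #     alpha, gamma here come from the optimizer's constructor (not shown). ---
    @staticmethod
    def _spsa_gains_sketch(k, a, c, alpha, gamma):
        return a / (k + 1.0) ** alpha, c / (k + 1.0) ** gamma
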

# Example 6

    def build_model(self):
        with tf.variable_scope('Model', reuse=tf.AUTO_REUSE):
            # Model Feeds
            self.ratings = tf.placeholder(dtype=tf.float32,
                                          shape=[None, self.num_item],
                                          name='ratings')
            # self.output_mask = tf.placeholder(dtype=tf.bool, shape=[None, self.num_item], name='output_mask')
            self.uid = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='user_id')

            self.istraining = tf.placeholder(dtype=tf.bool,
                                             shape=[],
                                             name='training_flag')
            self.isnegsample = tf.placeholder(dtype=tf.bool,
                                              shape=[],
                                              name='negative_sample_flag')
            self.layer1_dropout_rate = tf.placeholder(
                dtype=tf.float32, shape=[], name='layer1_dropout_rate')

            # Add Noise to the input ratings
            bernoulli_generator = Bernoulli(probs=self.noise_keep_prob,
                                            dtype=self.ratings.dtype)
            corruption_mask = bernoulli_generator.sample(tf.shape(
                self.ratings))
            corrupted_input = tf.multiply(self.ratings, corruption_mask)

            # Decide the input of the auto-encoder (corrupted at training time and uncorrupted at testing time)
            input = tf.cond(self.istraining, lambda: corrupted_input,
                            lambda: self.ratings)

            # Encoder
            layer1_w = tf.get_variable(
                name='encoder_weights',
                shape=[self.num_item, self.num_factors],
                initializer=tf.truncated_normal_initializer(mean=0.0,
                                                            stddev=0.01))

            layer1_b = tf.get_variable(name='encoder_bias',
                                       shape=[self.num_factors],
                                       initializer=tf.zeros_initializer())

            user_embedding = tf.get_variable(
                name='user_embedding',
                shape=[self.num_user, self.num_factors],
                initializer=tf.truncated_normal_initializer(mean=0.0,
                                                            stddev=0.01))

            # Decoder
            layer2_w = tf.get_variable(
                name='decoder_weights',
                shape=[self.num_factors, self.num_item],
                initializer=tf.truncated_normal_initializer(mean=0.0,
                                                            stddev=0.01))

            layer2_b = tf.get_variable(name='decoder_bias',
                                       shape=[self.num_item],
                                       initializer=tf.zeros_initializer())

            user_node = tf.nn.embedding_lookup(user_embedding, self.uid)
            layer1 = tf.sigmoid(
                tf.matmul(input, layer1_w) + layer1_b + user_node)
            layer1_out = tf.cond(
                self.istraining,
                lambda: tf.layers.dropout(layer1,
                                          rate=self.layer1_dropout_rate,
                                          name='layer1_dropout'),
                lambda: layer1)
            # layer1 = tf.sigmoid(tf.matmul(input, layer1_w) + layer1_b)
            out_vector = tf.identity(
                tf.matmul(layer1_out, layer2_w) + layer2_b)

            # Output
            # Determine whether negative samples should be considered
            # mask = tf.cond(self.isnegsample,
            #                lambda : tf.cast(self.output_mask, dtype=out_vector.dtype),
            #                lambda : tf.sign(self.ratings))
            #
            # self.output = tf.cond(self.istraining,
            #                       lambda : tf.multiply(out_vector, mask),
            #                       lambda : out_vector)

            self.output = out_vector
            self.pred_y = tf.sigmoid(self.output)

            # Loss
            base_loss = tf.reduce_sum(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=input,
                                                        logits=self.output))
            # base_loss = tf.nn.l2_loss(self.pred_y - input)
            base_loss = base_loss / tf.cast(
                tf.shape(input)[0],
                dtype=base_loss.dtype)  # Average over the batches
            reg_loss = self.ae_regs[0] * tf.nn.l2_loss(layer1_w) + self.ae_regs[1] * tf.nn.l2_loss(layer1_b) +\
                       self.ae_regs[2] * tf.nn.l2_loss(layer2_w) + self.ae_regs[3] * tf.nn.l2_loss(layer2_b) +\
                       self.user_node_regs * tf.nn.l2_loss(user_embedding)

            # reg_loss = self.ae_regs[0] * tf.nn.l2_loss(layer1_w) + self.ae_regs[1] * tf.nn.l2_loss(layer1_b) + \
            #            self.ae_regs[2] * tf.nn.l2_loss(layer2_w) + self.ae_regs[3] * tf.nn.l2_loss(layer2_b)
            self.loss = base_loss + reg_loss

            # Optimizer
            self.opt = tf.train.AdagradOptimizer(self.lr).minimize(self.loss)

        print('Model Building Completed.')
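
    # --- Hedged usage sketch, not part of the original class: one training step against the
    #     graph assembled in build_model, feeding every placeholder it defines. `sess` is an
    #     open tf.Session, `ratings_batch` a (batch, num_item) interaction matrix and
    #     `uid_batch` the matching user indices; the argument names are illustrative. ---
    def _train_step_sketch(self, sess, ratings_batch, uid_batch, dropout_rate=0.1):
        _, loss_val = sess.run(
            [self.opt, self.loss],
            feed_dict={
                self.ratings: ratings_batch,
                self.uid: uid_batch,
                self.istraining: True,          # corrupt the input with the Bernoulli mask
                self.isnegsample: False,
                self.layer1_dropout_rate: dropout_rate,
            })
        return loss_val
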

# Example 7

    def fit(self,
            data,
            epochs=1000,
            max_seconds=600,
            activation=tf.nn.elu,
            batch_norm_decay=0.9,
            learning_rate=1e-5,
            batch_sz=1024,
            adapt_lr=False,
            print_progress=True,
            show_fig=True):

        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

        # static features
        X = data['X_train_static_mins']
        N, D = X.shape
        self.X = tf.placeholder(tf.float32, shape=(None, D), name='X')

        # timeseries features
        X_time = data['X_train_time_0']
        T1, N1, D1 = X_time.shape
        assert N == N1
        self.X_time = tf.placeholder(tf.float32,
                                     shape=(T1, None, D1),
                                     name='X_time')
        self.train = tf.placeholder(tf.bool, shape=(), name='train')
        self.rnn_keep_p_encode = tf.placeholder(tf.float32,
                                                shape=(),
                                                name='rnn_keep_p_encode')
        self.rnn_keep_p_decode = tf.placeholder(tf.float32,
                                                shape=(),
                                                name='rnn_keep_p_decode')
        adp_learning_rate = tf.placeholder(tf.float32,
                                           shape=(),
                                           name='adp_learning_rate')

        he_init = variance_scaling_initializer()
        bn_params = {
            'is_training': self.train,
            'decay': batch_norm_decay,
            'updates_collections': None
        }
        latent_size = self.encoder_layer_sizes[-1]

        inputs = self.X
        with tf.variable_scope('static_encoder'):
            for layer_size, keep_p in zip(self.encoder_layer_sizes[:-1],
                                          self.encoder_dropout[:-1]):
                inputs = dropout(inputs, keep_p, is_training=self.train)
                inputs = fully_connected(inputs,
                                         layer_size,
                                         weights_initializer=he_init,
                                         activation_fn=activation,
                                         normalizer_fn=batch_norm,
                                         normalizer_params=bn_params)

        if self.rnn_encoder_layer_sizes:
            with tf.variable_scope('rnn_encoder'):
                rnn_cell = MultiRNNCell([
                    LayerNormBasicLSTMCell(
                        s,
                        activation=tf.tanh,
                        dropout_keep_prob=self.rnn_encoder_dropout)
                    for s in self.rnn_encoder_layer_sizes
                ])
                time_inputs, states = tf.nn.dynamic_rnn(rnn_cell,
                                                        self.X_time,
                                                        swap_memory=True,
                                                        time_major=True,
                                                        dtype=tf.float32)
                time_inputs = tf.transpose(time_inputs, perm=(1, 0, 2))
                time_inputs = tf.reshape(
                    time_inputs,
                    shape=(-1, self.rnn_encoder_layer_sizes[-1] * T1))

            inputs = tf.concat([inputs, time_inputs], axis=1)

        with tf.variable_scope('latent_space'):
            inputs = dropout(inputs,
                             self.encoder_dropout[-1],
                             is_training=self.train)
            loc = fully_connected(inputs,
                                  latent_size,
                                  weights_initializer=he_init,
                                  activation_fn=None,
                                  normalizer_fn=batch_norm,
                                  normalizer_params=bn_params)
            scale = fully_connected(inputs,
                                    latent_size,
                                    weights_initializer=he_init,
                                    activation_fn=tf.nn.softplus,
                                    normalizer_fn=batch_norm,
                                    normalizer_params=bn_params)

            standard_normal = Normal(loc=np.zeros(latent_size,
                                                  dtype=np.float32),
                                     scale=np.ones(latent_size,
                                                   dtype=np.float32))
            e = standard_normal.sample(tf.shape(loc)[0])
            outputs = e * scale + loc

            static_output_size = self.decoder_layer_sizes[0]
            if self.rnn_decoder_layer_sizes:
                time_output_size = self.rnn_decoder_layer_sizes[0] * T1
                output_size = static_output_size + time_output_size
            else:
                output_size = static_output_size
            outputs = fully_connected(outputs,
                                      output_size,
                                      weights_initializer=he_init,
                                      activation_fn=activation,
                                      normalizer_fn=batch_norm,
                                      normalizer_params=bn_params)
            if self.rnn_decoder_layer_sizes:
                outputs, time_outputs = tf.split(
                    outputs, [static_output_size, time_output_size], axis=1)

        with tf.variable_scope('static_decoder'):
            for layer_size, keep_p in zip(self.decoder_layer_sizes,
                                          self.decoder_dropout[:-1]):
                outputs = dropout(outputs, keep_p, is_training=self.train)
                outputs = fully_connected(outputs,
                                          layer_size,
                                          weights_initializer=he_init,
                                          activation_fn=activation,
                                          normalizer_fn=batch_norm,
                                          normalizer_params=bn_params)
            outputs = dropout(outputs,
                              self.decoder_dropout[-1],
                              is_training=self.train)
            outputs = fully_connected(outputs,
                                      D,
                                      weights_initializer=he_init,
                                      activation_fn=None,
                                      normalizer_fn=batch_norm,
                                      normalizer_params=bn_params)

            X_hat = Bernoulli(logits=outputs)
            self.posterior_predictive = X_hat.sample()
            self.posterior_predictive_probs = tf.nn.sigmoid(outputs)

        if self.rnn_decoder_layer_sizes:
            with tf.variable_scope('rnn_decoder'):
                self.rnn_decoder_layer_sizes.append(D1)
                time_output_size = self.rnn_decoder_layer_sizes[0]
                time_outputs = tf.reshape(time_outputs,
                                          shape=(-1, T1, time_output_size))
                time_outputs = tf.transpose(time_outputs, perm=(1, 0, 2))
                rnn_cell = MultiRNNCell([
                    LayerNormBasicLSTMCell(
                        s,
                        activation=tf.tanh,
                        dropout_keep_prob=self.rnn_decoder_dropout)
                    for s in self.rnn_decoder_layer_sizes
                ])
                time_outputs, states = tf.nn.dynamic_rnn(rnn_cell,
                                                         time_outputs,
                                                         swap_memory=True,
                                                         time_major=True,
                                                         dtype=tf.float32)
                time_outputs = tf.transpose(time_outputs, perm=(1, 0, 2))
                time_outputs = tf.reshape(time_outputs, shape=(-1, T1 * D1))
                X_hat_time = Bernoulli(logits=time_outputs)
                posterior_predictive_time = X_hat_time.sample()
                posterior_predictive_time = tf.reshape(
                    posterior_predictive_time, shape=(-1, T1, D1))
                self.posterior_predictive_time = tf.transpose(
                    posterior_predictive_time, perm=(1, 0, 2))
                self.posterior_predictive_probs_time = tf.nn.sigmoid(
                    time_outputs)

        kl_div = -tf.log(scale) + 0.5 * (scale**2 + loc**2) - 0.5
        kl_div = tf.reduce_sum(kl_div, axis=1)

        expected_log_likelihood = tf.reduce_sum(X_hat.log_prob(self.X), axis=1)
        X_time_trans = tf.transpose(self.X_time, perm=(1, 0, 2))
        X_time_reshape = tf.reshape(X_time_trans, shape=(-1, T1 * D1))
        if self.rnn_decoder_layer_sizes:  # X_hat_time exists only when an RNN decoder was built
            expected_log_likelihood_time = tf.reduce_sum(
                X_hat_time.log_prob(X_time_reshape), axis=1)
            elbo = -tf.reduce_sum(expected_log_likelihood +
                                  expected_log_likelihood_time - kl_div)
        else:
            elbo = -tf.reduce_sum(expected_log_likelihood - kl_div)
        train_op = tf.train.AdamOptimizer(
            learning_rate=adp_learning_rate).minimize(elbo)

        tf.summary.scalar('elbo', elbo)
        if self.save_file:
            saver = tf.train.Saver()

        if self.tensorboard:
            for v in tf.trainable_variables():
                tf.summary.histogram(v.name, v)
            train_merge = tf.summary.merge_all()
            writer = tf.summary.FileWriter(self.tensorboard)

        self.init_op = tf.global_variables_initializer()
        n = 0
        n_batches = N // batch_sz
        costs = list()
        min_cost = np.inf

        t0 = dt.now()
        with tf.Session() as sess:
            sess.run(self.init_op)
            for epoch in range(epochs):
                idxs = shuffle(range(N))
                X_train = X[idxs]
                X_train_time = X_time[:, idxs]

                for batch in range(n_batches):
                    n += 1
                    X_batch = X_train[batch * batch_sz:(batch + 1) * batch_sz]
                    X_batch_time = X_train_time[:,
                                                batch * batch_sz:(batch + 1) *
                                                batch_sz]

                    sess.run(train_op,
                             feed_dict={
                                 self.X: X_batch,
                                 self.X_time: X_batch_time,
                                 self.rnn_keep_p_encode:
                                 self.rnn_encoder_dropout,
                                 self.rnn_keep_p_decode:
                                 self.rnn_decoder_dropout,
                                 self.train: True,
                                 adp_learning_rate: learning_rate
                             })
                    if n % 100 == 0 and print_progress:
                        cost = sess.run(elbo,
                                        feed_dict={
                                            self.X: X,
                                            self.X_time: X_time,
                                            self.rnn_keep_p_encode: 1.0,
                                            self.rnn_keep_p_decode: 1.0,
                                            self.train: False
                                        })
                        cost /= N
                        costs.append(cost)

                        if adapt_lr and epoch > 0:
                            if cost < min_cost:
                                min_cost = cost
                            elif cost > min_cost * 1.01:
                                learning_rate *= 0.75
                                if print_progress:
                                    print('Updating Learning Rate',
                                          learning_rate)

                        print('Epoch:', epoch, 'Batch:', batch, 'Cost:', cost)

                        if self.tensorboard:
                            train_sum = sess.run(train_merge,
                                                 feed_dict={
                                                     self.X: X,
                                                     self.X_time: X_time,
                                                     self.rnn_keep_p_encode:
                                                     1.0,
                                                     self.rnn_keep_p_decode:
                                                     1.0,
                                                     self.train: False
                                                 })
                            writer.add_summary(train_sum, n)

                seconds = (dt.now() - t0).total_seconds()
                if seconds > max_seconds:
                    if print_progress:
                        print('Breaking after', seconds, 'seconds')
                    break

            if self.save_file:
                saver.save(sess, self.save_file)

            if self.tensorboard:
                writer.add_graph(sess.graph)

        if show_fig:
            plt.plot(costs)
            plt.title('Costs and Scores')
            plt.show()
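
The KL term in the example above is the closed-form divergence between the approximate posterior N(loc, scale^2) and a standard normal prior. A minimal, standalone NumPy sketch (with illustrative scalar loc/scale values that are not part of the original example) that checks the analytic expression against a Monte Carlo estimate:

import numpy as np

def gaussian_kl(loc, scale):
    # KL( N(loc, scale^2) || N(0, 1) ) in closed form,
    # matching the expression used in the ELBO above.
    return -np.log(scale) + 0.5 * (scale**2 + loc**2) - 0.5

# Monte Carlo check of the analytic formula (illustrative values).
loc, scale = 0.7, 1.3
z = np.random.normal(loc, scale, size=1_000_000)
log_q = -0.5 * np.log(2 * np.pi * scale**2) - (z - loc)**2 / (2 * scale**2)
log_p = -0.5 * np.log(2 * np.pi) - z**2 / 2
print(gaussian_kl(loc, scale), np.mean(log_q - log_p))  # should agree up to MC noise
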
Example #8
    def _build(self, inpt, state):

        img_flat, canvas_flat, what_code, where_code, hidden_state, presence = state
        img = tf.reshape(img_flat, (-1, ) + tuple(self._img_size))

        inpt_encoding = img
        inpt_encoding = self._input_encoder(inpt_encoding)

        with tf.variable_scope('rnn_inpt'):
            rnn_inpt = tf.concat(
                (inpt_encoding, what_code, where_code, presence), -1)
            rnn_inpt = self._rnn_projection(rnn_inpt)
            hidden_output, hidden_state = self._transition(
                rnn_inpt, hidden_state)

        where_param = self._transform_estimator(hidden_output)
        where_distrib = NormalWithSoftplusScale(
            *where_param,
            validate_args=self._debug,
            allow_nan_stats=not self._debug)
        where_loc, where_scale = where_distrib.loc, where_distrib.scale
        where_code = where_distrib.sample()

        cropped = self._spatial_transformer(img, where_code)

        with tf.variable_scope('presence'):
            presence_prob = self._steps_predictor(hidden_output)

            if self._explore_eps is not None:
                clipped_prob = tf.clip_by_value(presence_prob,
                                                self._explore_eps,
                                                1. - self._explore_eps)
                # Straight-through clipping: the forward pass uses the clipped
                # probability while gradients flow through the unclipped one.
                presence_prob = tf.stop_gradient(clipped_prob -
                                                 presence_prob) + presence_prob

            if self._sample_presence:
                presence_distrib = Bernoulli(probs=presence_prob,
                                             dtype=tf.float32,
                                             validate_args=self._debug,
                                             allow_nan_stats=not self._debug)

                new_presence = presence_distrib.sample()
                presence *= new_presence

            else:
                presence = presence_prob

        what_params = self._glimpse_encoder(cropped)
        what_distrib = self._what_distrib(what_params)
        what_loc, what_scale = what_distrib.loc, what_distrib.scale
        what_code = what_distrib.sample()
        decoded = self._glimpse_decoder(
            tf.concat([what_code, tf.stop_gradient(where_code)], -1))
        inversed = self._inverse_transformer(decoded, where_code)

        with tf.variable_scope('rnn_outputs'):
            inversed_flat = tf.reshape(inversed, (-1, self._n_pix))

            canvas_flat = canvas_flat + presence * inversed_flat  # * novelty_flat
            decoded_flat = tf.reshape(decoded, (-1, np.prod(self._crop_size)))

        output = [
            canvas_flat, decoded_flat, what_code, what_loc, what_scale,
            where_code, where_loc, where_scale, presence_prob, presence
        ]
        state = [
            img_flat, canvas_flat, what_code, where_code, hidden_state,
            presence
        ]
        return output, state
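
The presence branch above clips the presence probability for exploration but uses tf.stop_gradient so that gradients still flow through the unclipped value. A minimal TF1-style sketch of that straight-through pattern, standalone and using a made-up placeholder rather than the example's tensors:

import tensorflow as tf

p = tf.placeholder(tf.float32, shape=[None, 1])  # hypothetical probability input
eps = 0.05
clipped = tf.clip_by_value(p, eps, 1.0 - eps)
# Forward pass uses the clipped value; the stop_gradient term has zero gradient,
# so backprop sees d(straight_through)/d(p) = 1 (the unclipped path).
straight_through = tf.stop_gradient(clipped - p) + p
grad = tf.gradients(tf.reduce_sum(straight_through), p)[0]

with tf.Session() as sess:
    val, g = sess.run([straight_through, grad], feed_dict={p: [[0.01], [0.5]]})
    print(val)  # [[0.05], [0.5]]  (clipped in the forward pass)
    print(g)    # [[1.0], [1.0]]   (gradient of the identity path)
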
Example #9
    def build_model(self):
        with tf.variable_scope('TDAE_Model', reuse=tf.AUTO_REUSE):
            self.rating = tf.placeholder(dtype=tf.float32,
                                         shape=[None, self.num_item],
                                         name='rating')
            self.trust = tf.placeholder(dtype=tf.float32,
                                        shape=[None, self.num_user],
                                        name='trust')
            Theta0 = tf.get_variable(
                name='Theta0',
                shape=[self.num_factors],
                initializer=tf.truncated_normal_initializer(mean=0,
                                                            stddev=0.03))
            Theta1 = tf.get_variable(
                name='Theta1',
                shape=[self.num_factors],
                initializer=tf.truncated_normal_initializer(mean=0,
                                                            stddev=0.03))
            # Corruption: keep each input entry with probability q (denoising-autoencoder input)
            R_berngen = Bernoulli(probs=self.q, dtype=self.rating.dtype)
            Rcorrupt_mask = R_berngen.sample(tf.shape(self.rating))
            Rcorrupt_input = tf.multiply(self.rating, Rcorrupt_mask)
            T_berngen = Bernoulli(probs=self.q, dtype=self.trust.dtype)
            Tcorrupt_mask = T_berngen.sample(tf.shape(self.trust))
            Tcorrupt_input = tf.multiply(self.trust, Tcorrupt_mask)
            # Encoder
            Rlayer1_w = tf.get_variable(
                name='Re_weights',
                shape=[self.num_item, self.num_factors],
                initializer=tf.truncated_normal_initializer(mean=0,
                                                            stddev=0.03))
            Rlayer1_b = tf.get_variable(name='Re_bias',
                                        shape=[self.num_factors],
                                        initializer=tf.zeros_initializer())

            Tlayer1_w = tf.get_variable(
                name='Te_weights',
                shape=[self.num_user, self.num_factors],
                initializer=tf.truncated_normal_initializer(mean=0,
                                                            stddev=0.03))
            Tlayer1_b = tf.get_variable(name='Te_bias',
                                        shape=[self.num_factors],
                                        initializer=tf.zeros_initializer())

            Rlayer1 = tf.sigmoid(
                tf.matmul(Rcorrupt_input, Rlayer1_w) + Rlayer1_b)
            Tlayer1 = tf.sigmoid(
                tf.matmul(Tcorrupt_input, Tlayer1_w) + Tlayer1_b)
            layerP = self.alpha * Rlayer1 + (1 - self.alpha) * Tlayer1
            # Decoder
            Rlayer2_w = tf.get_variable(
                name='Rd_weights',
                shape=[self.num_factors, self.num_item],
                initializer=tf.truncated_normal_initializer(mean=0.0,
                                                            stddev=0.03))
            Rlayer2_b = tf.get_variable(name='Rd_bias',
                                        shape=[self.num_item],
                                        initializer=tf.zeros_initializer())

            Tlayer2_w = tf.get_variable(
                name='Td_weights',
                shape=[self.num_factors, self.num_user],
                initializer=tf.truncated_normal_initializer(mean=0.0,
                                                            stddev=0.03))
            Tlayer2_b = tf.get_variable(name='Td_bias',
                                        shape=[self.num_user],
                                        initializer=tf.zeros_initializer())
            Rlayer2_logits = tf.matmul(layerP, Rlayer2_w) + Rlayer2_b
            Tlayer2_logits = tf.matmul(layerP, Tlayer2_w) + Tlayer2_b
            self.Pred_R = tf.sigmoid(Rlayer2_logits)
            self.Pred_T = tf.sigmoid(Tlayer2_logits)
            #loss
            # sigmoid_cross_entropy_with_logits expects pre-sigmoid logits, so
            # pass the raw affine outputs rather than the squashed predictions.
            loss1 = tf.reduce_sum(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=self.rating,
                                                        logits=Rlayer2_logits))
            loss2 = tf.reduce_sum(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=self.trust,
                                                        logits=Tlayer2_logits))
            loss3 = self.lambdaT * (
                tf.nn.l2_loss(Rlayer1_w) + tf.nn.l2_loss(Rlayer1_b) +
                tf.nn.l2_loss(Tlayer1_w) + tf.nn.l2_loss(Tlayer1_b) +
                tf.nn.l2_loss(Tlayer2_b) + tf.nn.l2_loss(Tlayer2_w) +
                tf.nn.l2_loss(Rlayer2_b) + tf.nn.l2_loss(Rlayer2_w))
            loss4 = tf.nn.l2_loss(Theta0) + tf.nn.l2_loss(Theta1)
            loss5 = 2 * (
                tf.nn.l2_loss(Rlayer1 - tf.multiply(Tlayer1, Theta0)) +
                tf.nn.l2_loss(Tlayer1 - tf.multiply(Rlayer1, Theta1)))
            self.loss = loss1 + loss2 + loss3 + loss4 + loss5
            self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss)
        return self
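
The corruption step in this example keeps each rating/trust entry with probability q by multiplying the input with a sampled Bernoulli mask, the usual denoising-autoencoder trick. A small standalone sketch of that masking idea, with an illustrative item count and keep probability (not taken from the original model):

import tensorflow as tf

# Bernoulli lives at tf.distributions in later TF1 releases
# (tf.contrib.distributions in earlier ones).
Bernoulli = tf.distributions.Bernoulli

ratings = tf.placeholder(tf.float32, shape=[None, 100])  # hypothetical num_item = 100
q = 0.8                                                   # illustrative keep probability
mask = Bernoulli(probs=q, dtype=ratings.dtype).sample(tf.shape(ratings))
corrupted = ratings * mask  # each entry is zeroed with probability 1 - q
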