コード例 #1
0
        def spatial_loss(truth_features, predicted_features, space_desc):
            """Average the per-feature losses of one spatial observation space.

            Truth tensors, predicted tensors and ``space_desc.features`` are
            walked in lockstep. Categorical features are scored with softmax
            cross-entropy, every other feature with mean squared error. For
            each feature an image summary (truth next to prediction) and a
            scalar summary of its loss are written under the feature's name.

            Args:
                truth_features: iterable of ground-truth tensors.
                predicted_features: iterable of predicted tensors, in the
                    same order as ``truth_features``.
                space_desc: descriptor exposing ``features`` and ``index``.

            Returns:
                Scalar tensor holding the mean of the individual losses.
            """
            # Moves axis 1 to the last position (presumably NCHW -> NHWC;
            # confirm against whatever produces these tensors).
            channels_last = (0, 2, 3, 1)
            losses = []
            triples = zip(truth_features, predicted_features,
                          space_desc.features)
            for target, output, feature in triples:
                if feature.type == FeatureType.CATEGORICAL:
                    target = tf.transpose(target, channels_last)
                    output = tf.transpose(output, channels_last)
                    loss = tf.losses.softmax_cross_entropy(target, output)
                    losses.append(loss)

                    # Put truth and prediction next to each other along
                    # axis 2, reduce the last axis to class ids, then map
                    # the ids through this feature's palette entry.
                    class_ids = tf.argmax(tf.concat([target, output], 2), 3)
                    colours = palette[space_desc.index][feature.index]
                    tf.summary.image(feature.name,
                                     tf.gather(colours, class_ids))
                else:
                    loss = tf.losses.mean_squared_error(target, output)
                    losses.append(loss)

                    # Concatenate on axis 3 first, then move axis 1 last for
                    # the image summary.
                    side_by_side = tf.concat([target, output], 3)
                    tf.summary.image(
                        feature.name,
                        tf.transpose(side_by_side, channels_last))

                tf.summary.scalar(feature.name, loss)

            return tf.reduce_mean(tf.stack(losses))
コード例 #2
0
  def _build_target_distribution(self):
    """Builds the projected target distribution for the sampled replay batch.

    The support is shifted by the reward and scaled by the (terminal-aware)
    discount, the next-state distribution of the arg-max action is selected
    for every batch row, and both are passed to `project_distribution`
    together with the original support.

    Returns:
      Whatever `project_distribution` returns for the shifted support and the
      selected next-state probabilities.
    """
    self._reshape_networks()
    replay = self._replay
    batch_size = tf.shape(replay.rewards)[0]

    # rewards: batch_size x 1
    rewards = replay.rewards[:, None]

    # tiled_support: batch_size x num_atoms (one copy of the support per row).
    tiled_support = tf.reshape(
        tf.tile(self.support, [batch_size]), [batch_size, self.num_atoms])

    # Fold the terminal flag into the discount: the factor is 0 wherever
    # `terminals` is set.
    # gamma_with_terminal: batch_size x 1
    not_terminal = 1. - tf.cast(replay.terminals, tf.float32)
    gamma_with_terminal = (self.cumulative_gamma * not_terminal)[:, None]

    # target_support: batch_size x num_atoms
    target_support = rewards + gamma_with_terminal * tiled_support

    # next_probabilities: batch_size x num_actions x num_atoms
    next_probabilities = tf.contrib.layers.softmax(self._replay_next_logits)

    # next_qt: batch_size x num_actions (expectation over the atom axis).
    next_qt = tf.reduce_sum(self.support * next_probabilities, 2)

    # greedy_actions: batch_size x 1. `next_legal_actions` is added before the
    # argmax; presumably it penalizes illegal actions -- verify with the
    # replay buffer.
    greedy_actions = tf.argmax(
        next_qt + replay.next_legal_actions, axis=1)[:, None]
    batch_indices = tf.range(tf.to_int64(batch_size))[:, None]

    # gather_indices: batch_size x 2, rows of (batch index, action index).
    gather_indices = tf.concat([batch_indices, greedy_actions], axis=1)

    # chosen_probabilities: batch_size x num_atoms
    chosen_probabilities = tf.gather_nd(next_probabilities, gather_indices)
    return project_distribution(target_support, chosen_probabilities,
                                self.support)
コード例 #3
0
    def _build_networks(self):
        """Creates the online/target networks and wires up their outputs.

        Attributes set by this method:
          self.online_convnet: network created under the name 'Online'.
          self.target_convnet: network created under the name 'Target'.
          self._net_outputs: online network applied to `self.state_ph`.
          self._q_argmax: index of the largest entry of
            `self._net_outputs.q_values` along axis 1, first row only.
          self._replay_net_outputs: online network on the replayed states.
          self._replay_next_net_outputs: online network on the replayed next
            states.
          self._replay_target_net_outputs: target network on the replayed
            states.
          self._replay_next_target_net_outputs: target network on the replayed
            next states (see Mnih et al., 2015 for details).
        """
        # Each network object is callable; the original comments state that
        # repeated calls reuse the same parameters.
        online = self.online_convnet = self._create_network(name='Online')
        target = self.target_convnet = self._create_network(name='Target')

        # Acting path: outputs for the single state fed through the placeholder.
        self._net_outputs = online(self.state_ph)
        self._q_argmax = tf.argmax(self._net_outputs.q_values, axis=1)[0]

        # Training path: both networks on the sampled transitions.
        replay = self._replay
        self._replay_net_outputs = online(replay.states)
        self._replay_next_net_outputs = online(replay.next_states)

        self._replay_target_net_outputs = target(replay.states)
        self._replay_next_target_net_outputs = target(replay.next_states)
コード例 #4
0
ファイル: full_slate_q_agent.py プロジェクト: yokaiAG/recsim
 def _build_networks(self):
   """Builds the network outputs used for acting and for training.

   Attributes set by this method:
     self._replay_net_outputs: 'Online' network on the replayed states.
     self._replay_next_target_net_outputs: 'Target' network on the replayed
       next states; this is what the bootstrap target is built from.
     self._net_outputs: 'Online' network on `self.state_ph`.
     self._q_argmax: arg-max over axis 1 of `self._net_outputs.q_values`,
       first row only.
   """
   with tf.compat.v1.name_scope('networks'):
     self._replay_net_outputs = self._network_adapter(self._replay.states,
                                                      'Online')
     # The "next target" outputs must come from the *next* states: a
     # Q-learning target bootstraps from Q_target(s', .), not Q_target(s, .).
     # The previous code passed `self._replay.states` here.
     self._replay_next_target_net_outputs = self._network_adapter(
         self._replay.next_states, 'Target')
     self._net_outputs = self._network_adapter(self.state_ph, 'Online')
     self._q_argmax = tf.argmax(input=self._net_outputs.q_values, axis=1)[0]
コード例 #5
0
ファイル: base.py プロジェクト: pandeydeep9/meta-dataset
  def compute_accuracy(self, onehot_labels, predictions):
    """Scores every prediction as correct (1.0) or incorrect (0.0).

    Args:
      onehot_labels: A `tf.Tensor` containing the class labels; each vector
        along the (last) class dimension is expected to contain only a single
        `1`.
      predictions: A `tf.Tensor` containing the class predictions represented
        as unnormalized log probabilities.

    Returns:
      A float32 `tf.Tensor` that is 1.0 where the arg-max over the last axis
      of `predictions` equals that of `onehot_labels` and 0.0 elsewhere; use
      `tf.reduce_mean(...)` to obtain the average accuracy.
    """
    true_classes = tf.argmax(onehot_labels, -1)
    predicted_classes = tf.argmax(predictions, -1)
    is_correct = tf.equal(true_classes, predicted_classes)
    return tf.cast(is_correct, tf.float32)
コード例 #6
0
ファイル: base.py プロジェクト: Milkigit/meta-dataset
def class_specific_data(onehot_labels, data, num_classes, axis=0):
    """Splits `data` into one tensor per class.

    Args:
      onehot_labels: tensor whose arg-max over the last axis gives the class
        index of each example.
      data: tensor to split; entries are gathered along `axis`.
      num_classes: number of classes; class ids 0..num_classes-1 are used.
      axis: axis of `data` along which examples are gathered.

    Returns:
      A list of `num_classes` tensors. Entry `i` holds the slices of `data`
      whose label is `i`, reshaped to `[-1] + <shape of data without axis>`.
    """
    # TODO(eringrant): Deal with case of no data for a class in [1...num_classes].
    # Shape of a single example: every dimension of `data` except `axis`.
    example_shape = []
    for dim, size in enumerate(data.shape):
        if dim != axis:
            example_shape.append(size)

    labels = tf.argmax(onehot_labels, axis=-1)

    # Positions of the examples belonging to each class.
    per_class_indices = []
    for class_id in range(num_classes):
        per_class_indices.append(tf.where(tf.equal(labels, class_id)))

    per_class_data = []
    for indices in per_class_indices:
        gathered = tf.gather(data, indices, axis=axis)
        per_class_data.append(tf.reshape(gathered, [-1] + example_shape))
    return per_class_data
コード例 #7
0
    def compute_logits_for_episode(self, support_embeddings, query_embeddings,
                                   data):
        """Compute CrossTransformer logits for one episode.

        Support and query embeddings are projected by
        `functional_backbones.conv` (the `[1, 1]` argument is presumably the
        kernel size) into keys/queries and into values; the query-side call
        reuses the parameters returned by the support-side call. The
        support-side tensors and labels go through
        `distribute_utils.aggregate` before the distances are computed.

        Args:
          support_embeddings: embeddings of the support examples.
          query_embeddings: embeddings of the query examples.
          data: episode object exposing `onehot_support_labels`.

        Returns:
          The negated transpose of the distances; also stored in
          `self.test_logits`.
        """
        decay = self.tformer_weight_decay

        with tf.variable_scope('tformer_keys', reuse=tf.AUTO_REUSE):
            support_keys, key_params = functional_backbones.conv(
                support_embeddings, [1, 1], self.query_dim, 1,
                weight_decay=decay)
            # Queries are projected with the key parameters.
            query_queries, _ = functional_backbones.conv(
                query_embeddings, [1, 1], self.query_dim, 1,
                params=key_params, weight_decay=decay)

        with tf.variable_scope('tformer_values', reuse=tf.AUTO_REUSE):
            support_values, value_params = functional_backbones.conv(
                support_embeddings, [1, 1], self.val_dim, 1,
                weight_decay=decay)
            # Query values are projected with the support value parameters.
            query_values, _ = functional_backbones.conv(
                query_embeddings, [1, 1], self.val_dim, 1,
                params=value_params, weight_decay=decay)

        onehot_support_labels = distribute_utils.aggregate(
            data.onehot_support_labels)
        support_keys = distribute_utils.aggregate(support_keys)
        support_values = distribute_utils.aggregate(support_values)

        labels = tf.argmax(onehot_support_labels, axis=1)

        # Both distance routines take the same arguments; pick one by flag.
        distance_fn = (self._get_dist_rematerialize
                       if self.rematerialize else self._get_dist)
        distances = distance_fn(query_queries, query_values, support_keys,
                                support_values, labels)

        self.test_logits = -tf.transpose(distances)
        return self.test_logits
コード例 #8
0
  def _reshape_networks(self):
    """Reinterprets the network outputs as logits and derives Q-values.

    After this call `self._logits` holds the former `self._q`, `self._q`
    holds the support-weighted sum of the softmax probabilities over axis 2,
    and the replay tensors are available as `self._replay_logits` /
    `self._replay_next_logits`. The attributes `self._replay_qs` and
    `self._replay_next_qt` are deleted.
    """
    # What was built as `self._q` is actually logits; the original comments
    # give its size as 1 x num_actions x num_atoms.
    logits = self._q
    self._logits = logits

    # Same shape as the logits.
    probabilities = tf.contrib.layers.softmax(logits)
    self._probabilities = probabilities

    # Reducing the atom axis (axis 2) leaves one value per action.
    expected_q = tf.reduce_sum(self.support * probabilities, axis=2)
    self._q = expected_q

    # Recompute argmax from q values. Ignore illegal actions.
    self._q_argmax = tf.argmax(expected_q + self.legal_actions_ph, axis=1)[0]

    # The replay tensors are logits as well; expose them under the new names
    # and drop the old attributes so they cannot be used by mistake.
    self._replay_logits = self._replay_qs
    self._replay_next_logits = self._replay_next_qt
    del self._replay_qs, self._replay_next_qt
コード例 #9
0
def select_slate_optimal(slate_size, s_no_click, s, q):
    """Selects the slate using exhaustive search.

    This algorithm corresponds to the method "OS" in
    Ie et al. https://arxiv.org/abs/1905.12767.

    Args:
      slate_size: int, the size of the recommendation slate.
      s_no_click: float tensor, the score for not clicking any document.
      s: [num_of_documents] tensor, the scores for clicking documents.
      q: [num_of_documents] tensor, the predicted q values for documents.

    Returns:
      [slate_size] tensor, the selected slate.
    """
    num_docs = s.shape.as_list()[0]

    # Enumerate every ordered tuple of `slate_size` document indices.
    doc_ids = list(range(num_docs))
    grids = tf.meshgrid(*([doc_ids] * slate_size))
    candidates = tf.reshape(tf.stack(grids, axis=-1),
                            shape=(-1, slate_size))

    def _has_no_duplicates(slate):
        # A slate is duplicate-free when it is as long as its unique values.
        return tf.equal(tf.size(input=slate),
                        tf.size(input=tf.unique(slate)[0]))

    # Keep only slates in which each document appears at most once.
    keep = tf.map_fn(_has_no_duplicates, candidates, dtype=tf.bool)
    candidates = tf.boolean_mask(tensor=candidates, mask=keep)

    # Per-document score * q and score, looked up for every slate position.
    weighted_q = tf.gather(s * q, candidates)
    click_scores = tf.gather(s, candidates)

    # Normalize by the total score of the slate plus the no-click score.
    normalizer = tf.reduce_sum(input_tensor=click_scores,
                               axis=1) + s_no_click
    weighted_q = weighted_q / tf.expand_dims(normalizer, 1)

    slate_values = tf.reduce_sum(input_tensor=weighted_q, axis=1)
    best_slate = tf.argmax(input=slate_values)
    return tf.gather(candidates, best_slate, axis=0)
コード例 #10
0
    def inner_objective(self, onehot_labels, predictions, iteration_idx):
        """Compute the inner-loop objective.

        Args:
          onehot_labels: one-hot labels; the arg-max over the last axis picks
            the labelled class of each example.
          predictions: passed to `self.joint_log_likelihood` with the labels.
          iteration_idx: index into `self.gen_disc_interpolation` selecting
            the interpolation coefficient for this iteration.

        Returns:
          The negated, scaled mean over axis 0 of
          `numerator - interp * evidence`.
        """
        # p(z, y), joint log-likelihood.
        joint_log_probs = self.joint_log_likelihood(onehot_labels, predictions)

        # Pick the joint log-probability of the labelled class per example.
        class_ids = tf.argmax(input=onehot_labels, axis=-1)
        class_ids = tf.expand_dims(class_ids, axis=-1)
        numerator = tf.gather(joint_log_probs, class_ids, axis=-1,
                              batch_dims=1)

        # p(z), normalization constant.
        evidence = tf.reduce_logsumexp(input_tensor=joint_log_probs,
                                       axis=-1,
                                       keepdims=True)

        # p(y | z) if interpolation coefficient > 0 else p(z, y).
        # TODO(eringrant): This assumes that `interp` is either 1 or 0.
        # Adapt to a hybridized approach.
        interp = tf.gather(self.gen_disc_interpolation, iteration_idx)

        def _unit_scale():
            return 1.0

        def _generative_scale():
            return self.generative_scaling

        scale = tf.cond(pred=interp > 0.0,
                        true_fn=_unit_scale,
                        false_fn=_generative_scale)

        return -scale * tf.reduce_mean(
            input_tensor=numerator - interp * evidence, axis=0)
コード例 #11
0
def _get_class_labels_and_predictions(labels, logits, num_classes,
                                      multi_label):
    """Returns list of per-class-labels and list of per-class-predictions.

    Args:
      labels: A `Tensor` of size [n, k]. In the multi-label case, values are
        either 0 or 1 and k = num_classes. Otherwise, k = 1 and values are in
        [0, num_classes).
      logits: A `Tensor` of size [n, `num_classes`] representing the logits of
        each pixel and semantic class.
      num_classes: Number of classes.
      multi_label: Boolean which defines if we are in a multi_label setting,
        where pixels can have multiple labels, or not.

    Returns:
      class_labels: List of size num_classes with the ground-truth labels of
        each class. In the multi-label case these are the slices produced by
        `tf.split`; otherwise float tensors with values of 0 or 1.
      class_predictions: List of size num_classes, each entry a float `Tensor`
        with values of 0 or 1 representing the predicted labels.
      NOTE(review): the previous docstring gave the entry shape as
      [batch_size, height, width], which does not match the [n, k] inputs
      documented above -- confirm the intended shapes with the caller.
    """
    if multi_label:
        # One column per class; a class is predicted when its logit is > 0.
        class_labels = tf.split(labels, num_or_size_splits=num_classes, axis=1)
        class_logits = tf.split(logits, num_or_size_splits=num_classes, axis=1)
        class_predictions = [
            tf.cast(tf.greater(class_logits[c], 0), dtype=tf.float32)
            for c in range(num_classes)
        ]
        return class_labels, class_predictions

    # Single-label case: the predicted class is the arg-max logit, and both
    # labels and predictions are turned into per-class 0/1 indicators.
    predicted_class = tf.argmax(logits, 1)
    class_labels = []
    class_predictions = []
    for c in range(num_classes):
        class_labels.append(tf.cast(tf.equal(labels, c), dtype=tf.float32))
        class_predictions.append(
            tf.cast(tf.equal(predicted_class, c), dtype=tf.float32))
    return class_labels, class_predictions
コード例 #12
0
 def argmax(v, mask):
     """Returns the arg-max of `v` along axis 0 after applying `mask`.

     `v` is first shifted so that its smallest entry becomes 1, then
     multiplied by `mask`. Assuming `mask` holds 0/1 values (not visible
     here -- confirm with the caller), masked-out entries become 0 and can
     never beat an unmasked one.
     """
     shifted = v - tf.reduce_min(input_tensor=v) + 1
     return tf.argmax(input=shifted * mask, axis=0)
コード例 #13
0
    def __init__(self,
                 num_actions=None,
                 observation_size=None,
                 num_players=None,
                 gamma=0.99,
                 update_horizon=1,
                 min_replay_history=500,
                 update_period=4,
                 stack_size=1,
                 target_update_period=500,
                 epsilon_fn=linearly_decaying_epsilon,
                 epsilon_train=0.02,
                 epsilon_eval=0.001,
                 epsilon_decay_period=1000,
                 graph_template=dqn_template,
                 tf_device='/cpu:*',
                 use_staging=True,
                 optimizer=tf.train.RMSPropOptimizer(learning_rate=.0025,
                                                     decay=0.95,
                                                     momentum=0.0,
                                                     epsilon=1e-6,
                                                     centered=True)):
        """Initializes the agent and constructs its graph.

        Besides storing the hyper-parameters, this builds the online and
        target network templates, the placeholders, the replay memory, the
        train and sync ops, creates a session, runs the global variable
        initializer and creates a saver.

        Args:
          num_actions: int, number of actions the agent can take at any state.
          observation_size: int, size of observation vector.
          num_players: int, number of players playing this game.
          gamma: float, discount factor as commonly used in the RL literature.
          update_horizon: int, horizon at which updates are performed, the 'n'
            in n-step update.
          min_replay_history: int, number of stored transitions before
            training.
          update_period: int, period between DQN updates.
          stack_size: int, number of observations to use as state.
          target_update_period: Update period for the target network.
          epsilon_fn: Function expecting 4 parameters: (decay_period, step,
            warmup_steps, epsilon), and which returns the epsilon value used
            for exploration during training.
          epsilon_train: float, final epsilon for training.
          epsilon_eval: float, epsilon during evaluation.
          epsilon_decay_period: int, number of steps for epsilon to decay.
          graph_template: function for building the neural network graph.
          tf_device: str, Tensorflow device on which to run computations.
          use_staging: bool, when True use a staging area to prefetch the next
            sampling batch.
          optimizer: Optimizer instance used for learning.
            NOTE(review): the default is constructed once, when this function
            is defined, so every agent created without an explicit optimizer
            shares the same optimizer object.
        """

        self.partial_reload = False

        # Log the configuration this agent is being created with.
        tf.logging.info('Creating %s agent with the following parameters:',
                        self.__class__.__name__)
        tf.logging.info('\t gamma: %f', gamma)
        tf.logging.info('\t update_horizon: %f', update_horizon)
        tf.logging.info('\t min_replay_history: %d', min_replay_history)
        tf.logging.info('\t update_period: %d', update_period)
        tf.logging.info('\t target_update_period: %d', target_update_period)
        tf.logging.info('\t epsilon_train: %f', epsilon_train)
        tf.logging.info('\t epsilon_eval: %f', epsilon_eval)
        tf.logging.info('\t epsilon_decay_period: %d', epsilon_decay_period)
        tf.logging.info('\t tf_device: %s', tf_device)
        tf.logging.info('\t use_staging: %s', use_staging)
        tf.logging.info('\t optimizer: %s', optimizer)

        # Global variables.
        self.num_actions = num_actions
        self.observation_size = observation_size
        self.num_players = num_players
        self.gamma = gamma
        self.update_horizon = update_horizon
        # Discount applied across an n-step return: gamma ** update_horizon.
        self.cumulative_gamma = math.pow(gamma, update_horizon)
        self.min_replay_history = min_replay_history
        self.target_update_period = target_update_period
        self.epsilon_fn = epsilon_fn
        self.epsilon_train = epsilon_train
        self.epsilon_eval = epsilon_eval
        self.epsilon_decay_period = epsilon_decay_period
        self.update_period = update_period
        self.eval_mode = False
        self.training_steps = 0
        self.batch_staged = False
        self.optimizer = optimizer

        with tf.device(tf_device):
            # Calling online_convnet will generate a new graph as defined in
            # graph_template using whatever input is passed, but will always share
            # the same weights.
            online_convnet = tf.make_template('Online', graph_template)
            target_convnet = tf.make_template('Target', graph_template)
            # The state of the agent. The last axis is the number of past observations
            # that make up the state.
            states_shape = (1, observation_size, stack_size)
            self.state = np.zeros(states_shape)
            self.state_ph = tf.placeholder(tf.uint8,
                                           states_shape,
                                           name='state_ph')
            # One float per action; it is added to the Q-values before the
            # argmax below. Presumably it holds a large negative value for
            # illegal actions -- confirm with the code that feeds it.
            self.legal_actions_ph = tf.placeholder(tf.float32,
                                                   [self.num_actions],
                                                   name='legal_actions_ph')
            # Online network on the acting state.
            self._q = online_convnet(state=self.state_ph,
                                     num_actions=self.num_actions)
            # Replay memory plus the online/target networks on sampled
            # states and next states. The train and sync ops are built from
            # attributes set above, so this order must be preserved.
            self._replay = self._build_replay_memory(use_staging)
            self._replay_qs = online_convnet(self._replay.states,
                                             self.num_actions)
            self._replay_next_qt = target_convnet(self._replay.next_states,
                                                  self.num_actions)
            self._train_op = self._build_train_op()
            self._sync_qt_ops = self._build_sync_op()

            # Action with the largest (Q + legal_actions_ph) value; [0]
            # selects the single row of the size-1 leading axis.
            self._q_argmax = tf.argmax(self._q + self.legal_actions_ph,
                                       axis=1)[0]

        # Set up a session and initialize variables.
        self._sess = tf.Session(
            '', config=tf.ConfigProto(allow_soft_placement=True))
        self._init_op = tf.global_variables_initializer()
        self._sess.run(self._init_op)

        # Keep at most the three most recent checkpoints.
        self._saver = tf.train.Saver(max_to_keep=3)

        # This keeps tracks of the observed transitions during play, for each
        # player.
        # NOTE(review): num_players must be an int here; with the default of
        # None, range() raises TypeError.
        self.transitions = [[] for _ in range(num_players)]
コード例 #14
0
def train(hparams, num_epoch, tuning):
    """Train, validate and test the ResNet model; return the test accuracy.

    Each epoch runs one pass over the training set and one pass over the
    validation set, and writes accuracy and a validation confusion matrix to
    the summary writer. After the last epoch the model is evaluated on the
    test set. Outside of tuning runs, checkpoints are restored and saved under
    './tf_ckpts' and an activation map is plotted for the first test example.

    Args:
        hparams: mapping providing 'HP_BS' (batch size) and 'HP_LR'
            (learning rate).
        num_epoch: number of training epochs. With 0 epochs the untrained (or
            restored) model is still evaluated on the test set.
        tuning: when truthy, checkpointing and the final visualization are
            skipped.

    Returns:
        The result of the test-set accuracy metric.
    """
    log_dir = './results/'
    test_batch_size = 8
    # Load dataset
    training_set, valid_set = make_dataset(BATCH_SIZE=hparams['HP_BS'],
                                           file_name='train_tf_record',
                                           split=True)
    test_set = make_dataset(BATCH_SIZE=test_batch_size,
                            file_name='test_tf_record',
                            split=False)
    class_names = ['NRDR', 'RDR']

    # Model
    model = ResNet()

    # set optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams['HP_LR'])
    # set metrics
    # Training accuracy is fed raw model outputs (sparse categorical);
    # validation/test accuracy are fed arg-max class ids.
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.Accuracy()
    valid_con_mat = ConfusionMatrix(num_class=2)
    test_accuracy = tf.keras.metrics.Accuracy()
    test_con_mat = ConfusionMatrix(num_class=2)

    # Save Checkpoint
    if not tuning:
        ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                   optimizer=optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)

    # Set up summary writers
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_log_dir = log_dir + current_time + '/train'
    summary_writer = tf.summary.create_file_writer(tb_log_dir)

    # Restore Checkpoint
    if not tuning:
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info('Restored from {}'.format(manager.latest_checkpoint))
        else:
            logging.info('Initializing from scratch.')

    @tf.function
    def train_step(train_img, train_label):
        # Optimize the model
        loss_value, grads = grad(model, train_img, train_label)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Second forward pass (after the update) feeds the accuracy metric.
        train_pred, _ = model(train_img)
        train_label = tf.expand_dims(train_label, axis=1)
        train_accuracy.update_state(train_label, train_pred)

    for epoch in range(num_epoch):

        begin = time()

        # Training loop
        for train_img, train_label, _ in training_set:
            train_img = data_augmentation(train_img)
            train_step(train_img, train_label)

        with summary_writer.as_default():
            tf.summary.scalar('Train Accuracy',
                              train_accuracy.result(),
                              step=epoch)

        # Validation loop: images are cast to float and divided by 255.
        for valid_img, valid_label, _ in valid_set:
            valid_img = tf.cast(valid_img, tf.float32)
            valid_img = valid_img / 255.0
            valid_pred, _ = model(valid_img, training=False)
            valid_pred = tf.cast(tf.argmax(valid_pred, axis=1), dtype=tf.int64)
            valid_con_mat.update_state(valid_label, valid_pred)
            valid_accuracy.update_state(valid_label, valid_pred)

        # Log the confusion matrix as an image summary
        cm_valid = valid_con_mat.result()
        figure = plot_confusion_matrix(cm_valid, class_names=class_names)
        cm_valid_image = plot_to_image(figure)

        with summary_writer.as_default():
            tf.summary.scalar('Valid Accuracy',
                              valid_accuracy.result(),
                              step=epoch)
            tf.summary.image('Valid ConfusionMatrix',
                             cm_valid_image,
                             step=epoch)

        end = time()
        logging.info(
            "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
            .format(epoch + 1, train_accuracy.result(),
                    valid_accuracy.result(), (end - begin)))
        # Metrics accumulate across calls, so reset them for the next epoch.
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        valid_con_mat.reset_states()
        if not tuning:
            # Save every fifth value of the checkpoint step counter.
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                logging.info('Saved checkpoint for epoch {}: {}'.format(
                    int(ckpt.step), save_path))
            ckpt.step.assign_add(1)

    # Step under which the test summaries are written: the index of the last
    # epoch. Computed explicitly because the loop variable `epoch` is unbound
    # when num_epoch == 0, which used to raise NameError below.
    final_step = max(num_epoch - 1, 0)

    for test_img, test_label, _ in test_set:
        test_img = tf.cast(test_img, tf.float32)
        test_img = test_img / 255.0
        test_pred, _ = model(test_img, training=False)
        test_pred = tf.cast(tf.argmax(test_pred, axis=1), dtype=tf.int64)
        test_accuracy.update_state(test_label, test_pred)
        test_con_mat.update_state(test_label, test_pred)

    cm_test = test_con_mat.result()
    # Log the confusion matrix as an image summary
    figure = plot_confusion_matrix(cm_test, class_names=class_names)
    cm_test_image = plot_to_image(figure)
    with summary_writer.as_default():
        tf.summary.scalar('Test Accuracy',
                          test_accuracy.result(),
                          step=final_step)
        tf.summary.image('Test ConfusionMatrix',
                         cm_test_image,
                         step=final_step)

    logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(
        test_accuracy.result()))

    # Visualization
    if not tuning:
        # Only the first example of the first test batch is visualized.
        for vis_img, vis_label, vis_name in test_set:
            vis_label = vis_label[0]
            vis_name = vis_name[0]
            vis_img = tf.cast(vis_img[0], tf.float32)
            vis_img = tf.expand_dims(vis_img, axis=0)
            vis_img = vis_img / 255.0
            # Only the forward pass needs to be recorded; the gradient and
            # the map construction run after the tape context is closed.
            with tf.GradientTape() as tape:
                vis_pred, conv_output = model(vis_img, training=False)
                pred_label = tf.argmax(vis_pred, axis=-1)
                vis_pred = tf.reduce_max(vis_pred, axis=-1)
            # Gradient of the top score w.r.t. the second model output,
            # averaged over axes 1 and 2 to get one weight per channel.
            grad_1 = tape.gradient(vis_pred, conv_output)
            weight = tf.reduce_mean(grad_1, axis=[1, 2]) / grad_1.shape[1]
            act_map0 = tf.nn.relu(
                tf.reduce_sum(weight * conv_output, axis=-1))
            # Resize the map to 256 x 256 for plotting.
            act_map0 = tf.squeeze(tf.image.resize(tf.expand_dims(act_map0,
                                                                 axis=-1),
                                                  (256, 256),
                                                  antialias=True),
                                  axis=-1)
            plot_map(vis_img, act_map0, vis_pred, pred_label, vis_label,
                     vis_name)
            break

    return test_accuracy.result()
コード例 #15
0
 def argmax(v, mask):
     """Returns the arg-max of `v` along axis 0 after applying `mask`.

     `v` is shifted so that its minimum becomes 1 and then multiplied by
     `mask`. Assuming `mask` is 0/1 valued (not visible here -- confirm with
     the caller), masked-out positions drop to 0 and cannot be selected over
     an unmasked one.
     """
     lowest = tf.reduce_min(v)
     masked = (v - lowest + 1) * mask
     return tf.argmax(masked, axis=0)