Example #1
def get_center_loss(features, labels, featuress, labelss, alpha, num_classes):
    len_features = features.get_shape()[1]
    # Create a Variable of shape [num_classes, len_features] that stores the class centers
    # for the whole network; trainable=False because the centers are not updated by gradients.
    centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,
        initializer=tf.constant_initializer(0), trainable=False)

    centerss = tf.get_variable('centerss', [num_classes, len_features], dtype=tf.float32,
                              initializer=tf.constant_initializer(0), trainable=False)
    # Flatten the labels to 1-D (a no-op if the input is already 1-D).
    labels = tf.reshape(labels, [-1])
    labelss = tf.reshape(labelss, [-1])



##############################################################
    centers0 = tf.unsorted_segment_mean(features, labels, num_classes)
    centers1 = tf.unsorted_segment_mean(featuress, labelss, num_classes)
    EdgeWeights = tf.ones((num_classes, num_classes)) - tf.eye(num_classes)
    margin = tf.constant(100, dtype="float32")
    norm = lambda x: tf.reduce_sum(tf.square(x), 1)
    center_pairwise_dist = tf.transpose(
        norm(tf.expand_dims(centers1, 2) - tf.transpose(centers0)))
    loss_0 = tf.reduce_sum(
        tf.multiply(
            tf.maximum(0.0, margin - tf.transpose(
                norm(tf.expand_dims(centers0, 2) - tf.transpose(centers0)))),
            EdgeWeights))
    # + tf.reduce_sum(tf.maximum(0.0, tf.pow((centers1 - centers0), 2))) \
    # + 0.01*tf.reduce_sum(tf.multiply(tf.maximum(0.0, tf.constant(200,dtype="float32")-center_pairwise_dist),EdgeWeights))

    # Look up the center of each sample in the mini-batch according to its label.
    centers_batch = tf.gather(centers, labels)
    # Difference between the mini-batch features and their corresponding centers.
    diff = centers_batch - features
    unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
    appear_times = tf.gather(unique_count, unique_idx)
    appear_times = tf.reshape(appear_times, [-1, 1])
    diff = diff / tf.cast((1 + appear_times), tf.float32)
    diff = alpha * diff

    # # Look up the center of each sample in the mini-batch according to its label.
    # centers_batch1 = tf.gather(centerss, labelss)
    # # Difference between the mini-batch features and their corresponding centers.
    # diff1 = centers_batch1 - featuress
    # unique_label1, unique_idx1, unique_count1 = tf.unique_with_counts(labelss)
    # appear_times1 = tf.gather(unique_count1, unique_idx1)
    # appear_times1 = tf.reshape(appear_times1, [-1, 1])
    # diff1 = diff1 / tf.cast((1 + appear_times1), tf.float32)
    # diff1 = alpha * diff1



    # Compute the loss.
    loss_1 = tf.nn.l2_loss(features - centers_batch)
    centers_update_op = tf.scatter_sub(centers, labels, diff)
    # centers_update_op1= tf.scatter_sub(centerss, labelss, diff1)


    return loss_0, loss_1, centers_update_op
Example #2
def grouped_pairwise_margin_loss(similarities, group_ids, labels, gamma=1.0):
    """
    Calculates the pairwise margin ranking loss between paragraphs of the same question in a batch.
    Assumes at least (and most likely exactly) 2 paragraphs per question, one positive and one negative.
    """
    _, group_segments = tf.unique(group_ids)
    num_segments = tf.reduce_max(group_segments) + 1
    # we take advantage of the fact that negative segment ids get dropped
    # note: if label==1, 2*label-1 == 1, but if label==0, 2*label-1==-1
    positive_ranks = tf.unsorted_segment_mean(
        similarities, (group_segments + 1) * (2 * labels - 1) - 1,
        num_segments=num_segments)
    negative_ranks = tf.unsorted_segment_mean(
        similarities, (group_segments + 1) * (1 - 2 * labels) - 1,
        num_segments=num_segments)
    return tf.maximum(gamma - positive_ranks + negative_ranks, 0.)
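A quick way to see the sign trick used above: with labels in {0, 1}, `(group_segments + 1) * (2 * labels - 1) - 1` leaves positive rows at their segment id and maps negative rows to negative ids, which `tf.unsorted_segment_mean` drops (as the inline comment notes). A minimal sketch with made-up values, assuming the function above is in scope (TF 1.x):

import tensorflow as tf

similarities = tf.constant([0.9, 0.2, 0.4, 0.7])
group_ids = tf.constant([10, 10, 42, 42])   # two questions
labels = tf.constant([1, 0, 0, 1])          # one positive paragraph per question

loss = grouped_pairwise_margin_loss(similarities, group_ids, labels, gamma=1.0)
with tf.Session() as sess:
    # positives: [0.9, 0.7], negatives: [0.2, 0.4]
    # per-question loss: max(1 - 0.9 + 0.2, 0) = 0.3 and max(1 - 0.7 + 0.4, 0) = 0.7
    print(sess.run(loss))  # ~[0.3, 0.7]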
Example #3
def get_center_loss(features, labels, alpha, num_classes):
    len_features = features.get_shape()[1]
    centers = tf.get_variable('centers', [num_classes, len_features],
                              dtype=tf.float32,
                              initializer=tf.constant_initializer(0),
                              trainable=False)
    labels = tf.reshape(labels, [-1])

    ##############################################################
    centers0 = tf.unsorted_segment_mean(features, labels, num_classes)
    EdgeWeights = tf.ones((num_classes, num_classes)) - tf.eye(num_classes)
    margin = tf.constant(100, dtype="float32")
    norm = lambda x: tf.reduce_sum(tf.square(x), 1)
    center_pairwise_dist = tf.transpose(
        norm(tf.expand_dims(centers0, 2) - tf.transpose(centers0)))
    loss_0 = tf.reduce_sum(
        tf.multiply(tf.maximum(0.0, margin - center_pairwise_dist),
                    EdgeWeights))
    ###########################################################################

    # Look up the center of each sample in the mini-batch according to its label.
    centers_batch = tf.gather(centers, labels)
    # Difference between the mini-batch features and their corresponding centers.
    diff = centers_batch - features
    unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
    appear_times = tf.gather(unique_count, unique_idx)
    appear_times = tf.reshape(appear_times, [-1, 1])
    diff = diff / tf.cast((1 + appear_times), tf.float32)
    diff = alpha * diff

    # Compute the loss.
    loss_1 = tf.nn.l2_loss(features - centers_batch)
    centers_update_op = tf.scatter_sub(centers, labels, diff)

    return loss_0, loss_1, centers_update_op, centers
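A hedged sketch of how the returned ops are typically wired into a TF 1.x training graph (the tensors `features`, `labels`, `logits` and the loss weights are placeholders for whatever the surrounding model defines, not part of the original example): the center update op must run alongside the optimizer step, for instance via a control dependency.

import tensorflow as tf

# assumed to come from the surrounding model / input pipeline:
# features: [batch, feat_dim] embeddings, labels: [batch] int class ids, logits: [batch, num_classes]
loss_0, loss_1, centers_update_op, centers = get_center_loss(
    features, labels, alpha=0.5, num_classes=10)

softmax_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
total_loss = softmax_loss + 0.01 * loss_1 + 0.001 * loss_0  # weights are illustrative

with tf.control_dependencies([centers_update_op]):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(total_loss)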
Example #4
        def get_variable_embeddings(all_sequence_embeddings):
            flat_sequence_embeddings = tf.reshape(all_sequence_embeddings, (-1, all_sequence_embeddings.get_shape()[-1]))  # B*max-len x D
            target_token_embeddings = tf.gather(params=flat_sequence_embeddings,
                                                indices=self.placeholders['variable_bound_token_ids'])

            return tf.unsorted_segment_mean(
                data=target_token_embeddings,
                segment_ids=self.placeholders['token_variable_ids'],
                num_segments=self.placeholders['num_variables']  # TODO: Do not depend in any way on the classes.
            ) # num-variables x H
Example #5
    def classification_task_graphb4classify(self, last_h,
                                            classification_layer):
        _input = last_h  # [v x h]

        graph_representations = tf.unsorted_segment_mean(
            data=_input,
            segment_ids=self.placeholders['graph_nodes_list'],
            num_segments=self.placeholders['num_graphs'])  # [g x h]
        output = classification_layer(graph_representations)  # [g x 2]
        self.output = output
        return output
Example #6
def create_label(label, segments):
    n_node = tf.reduce_max(segments) + 1
    node_label = tf.cast(tf.unsorted_segment_mean(label, segments, n_node),
                         dtype=tf.float32)
    label_mean = tf.reduce_mean(tf.cast(label, dtype=tf.float32))

    a = tf.ones_like(node_label, dtype=tf.int32)
    b = tf.zeros_like(node_label, dtype=tf.int32)
    condition = tf.less(label_mean, node_label)
    label_t = tf.where(condition, a, b)
    label_t = tf.squeeze(label_t, axis=1)
    return label_t
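A minimal usage sketch for the function above with made-up values; `label` is assumed to be shaped [m, 1] (so the `squeeze(..., axis=1)` works) and `segments` maps each row to a node id.

import tensorflow as tf

label = tf.constant([[1.], [1.], [0.], [0.], [0.], [1.]])
segments = tf.constant([0, 0, 1, 1, 2, 2])
label_t = create_label(label, segments)
with tf.Session() as sess:
    # per-node means are [1.0, 0.0, 0.5]; the global mean is 0.5,
    # so only node 0 is strictly above it
    print(sess.run(label_t))  # [1 0 0]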
Example #7
    def classification_task_org(self, last_h, classification_layer):
        _input = last_h  # [v x h]
        _output = classification_layer(_input)  # [v x 2]

        # Sum up all nodes per-graph
        #graph_representations = tf.unsorted_segment_sum(data=_output,
        graph_representations = tf.unsorted_segment_mean(
            data=_output,
            segment_ids=self.placeholders['graph_nodes_list'],
            num_segments=self.placeholders['num_graphs'])  # [g x 2]
        output = graph_representations
        self.output = output
        return output
Example #8
    def define_pooling(self, node_embeddings):
        if self.agg == 'sum':
            return tf.unsorted_segment_sum(
                data=node_embeddings,
                segment_ids=self.placeholders['node_graph_ids_list'],
                num_segments=self.placeholders['num_graphs'])
        elif self.agg == 'mean':
            return tf.unsorted_segment_mean(
                data=node_embeddings,
                segment_ids=self.placeholders['node_graph_ids_list'],
                num_segments=self.placeholders['num_graphs'])
        else:
            raise ValueError("Aggregation must be one of {'sum', 'mean'}")
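A tiny concrete illustration of the 'mean' branch above, with hypothetical values: five node embeddings spread over two graphs are pooled into one vector per graph.

import tensorflow as tf

node_embeddings = tf.constant([[1., 1.], [3., 3.], [2., 0.], [4., 0.], [6., 0.]])
node_graph_ids = tf.constant([0, 0, 1, 1, 1])   # graph id of each node
graph_repr = tf.unsorted_segment_mean(node_embeddings, node_graph_ids, num_segments=2)
with tf.Session() as sess:
    print(sess.run(graph_repr))  # [[2. 2.] [4. 0.]]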
Example #9
    def define_round(self, layer: int, time_step: int, node_embeddings):
        node_embeddings = tf.identity(node_embeddings)
        edge_weights = self.weights['edge_weights'][layer]
        src_node_ids, dst_node_ids, src_node_embeddings, dst_node_embeddings, messages = [], [], [], [], []

        for e_type, adj_list in enumerate(self.placeholders['adjacency_lists']):
            src_node_ids.append(adj_list[:, 0])
            dst_node_ids.append(adj_list[:, 1])

            src_node_embeddings.append(tf.nn.embedding_lookup(params=node_embeddings, ids=src_node_ids[-1]))
            dst_node_embeddings.append(tf.nn.embedding_lookup(params=node_embeddings, ids=dst_node_ids[-1]))

            messages.append(tf.matmul(src_node_embeddings[-1], edge_weights[e_type]))

        src_node_ids = tf.concat(src_node_ids, axis=0)
        src_node_embeddings = tf.concat(src_node_embeddings, axis=0)
        dst_node_ids = tf.concat(dst_node_ids, axis=0)
        dst_node_embeddings = tf.concat(dst_node_embeddings, axis=0)
        messages = tf.concat(messages, axis=0)

        #  Now weigh the messages using attention if configured to do so
        if self.use_propagation_attention:
            messages *= tf.expand_dims(self.define_message_attention(layer, src_node_ids, src_node_embeddings,
                                                                     dst_node_ids, dst_node_embeddings, messages), axis=-1)

        #  Accumulate all messages for a destination node
        if self.edge_msg_aggregation == 'avg':
            incoming_messages = tf.unsorted_segment_mean(data=messages,
                                                         segment_ids=dst_node_ids,
                                                         num_segments=self.placeholders['num_nodes'])
        elif self.edge_msg_aggregation == 'sum':
            incoming_messages = tf.unsorted_segment_sum(data=messages,
                                                        segment_ids=dst_node_ids,
                                                        num_segments=self.placeholders['num_nodes'])
        else:
            raise ValueError("Edge message aggregation type should be one of {'avg', 'sum'}")

        #  Compute new node states i.e. states for the next round of message passing (if any)
        return self.weights['rnn_cells'][layer](incoming_messages, [node_embeddings])[0]
Example #10
    def update_memory(self, solver):
        # update memory bank after solver
        with tf.control_dependencies([solver]):
            with tf.name_scope('update_point_memory'):
                feature = self.feature
                seg_num, point_id = self.flags.seg_num, self.point_id
                # point_mask = point_id > -1  # filter label -1
                point_id = point_id + (self.obj_segment * seg_num)
                # point_id = tf.boolean_mask(point_id, point_mask)
                # feature = tf.boolean_mask(feature, point_mask)

                batch_size = self.batch_size
                feature = tf.unsorted_segment_mean(feature, point_id,
                                                   seg_num * batch_size)
                feature = tf.nn.l2_normalize(feature, axis=1)
                feature = tf.reshape(feature, [batch_size, seg_num, -1])

                momentum = self.flags.momentum
                weight = tf.gather(self.memory, self.shape_id)
                weight = feature * momentum + weight * (1 - momentum)
                weight = tf.nn.l2_normalize(weight, axis=2)
                memory = tf.scatter_update(self.memory, self.shape_id, weight)
        return memory
Example #11
    def construct(self, v_dims=[2, 4, 8], e_dims=[2, 4, 8], g_dims=[4, 8, 16]):
        with self.session.graph.as_default():
            assert(len(v_dims) == len(e_dims))
            assert(len(v_dims) == len(g_dims))

            # Input graph order and size
            self.g_n = tf.placeholder(tf.int32, [], name="g_n")
            self.g_n_f = tf.cast(self.g_n, tf.float32)
            self.g_m = tf.placeholder(tf.int32, [], name="g_m")
            self.g_m_f = tf.cast(self.g_m, tf.float32)

            # Input graph edge sources and destinations
            self.edge_srcs = tf.placeholder(tf.int32, [None], name="edge_srcs")
            self.edge_dsts = tf.placeholder(tf.int32, [None], name="edge_dsts")

            # Initial values
            self.v_layer = tf.ones([self.g_n, 1])
            self.g_layer = tf.ones([1, 0])
            self.e_layer = tf.ones([self.g_m, 0])

            for i in range(len(v_dims)):

                # For every edge, the input and output vertex
                e_in_layer = tf.gather(self.v_layer, self.edge_srcs, axis=0)
                e_out_layer = tf.gather(self.v_layer, self.edge_dsts, axis=0)
                e_layer_new = self.edge_gadget(self.g_layer, self.e_layer, e_in_layer, e_out_layer, dim=e_dims[i])

                v_sum = tf.unsorted_segment_sum(e_layer_new, self.edge_dsts, self.g_n)
                v_max = tf.maximum(tf.unsorted_segment_max(e_layer_new, self.edge_dsts, self.g_n), 0.0)
                v_mean = tf.unsorted_segment_mean(e_layer_new, self.edge_dsts, self.g_n)
                v_layer_new = self.vertex_gadget(self.g_layer, self.v_layer, v_sum, v_mean, v_max, dim=v_dims[i])

                g_layer_new = self.graph_gadget(self.g_layer, self.v_layer, dim=g_dims[i])

                self.g_layer = g_layer_new
                self.v_layer = v_layer_new
                self.e_layer = e_layer_new
Example #12
def main(unused_args):
    assert len(unused_args) == 1, unused_args
    setup_experiment(logging, FLAGS, "critic_model")

    if FLAGS.validation:
        mnist_ds = mnist.read_data_sets(FLAGS.data_dir,
                                        dtype=tf.float32,
                                        reshape=False,
                                        validation_size=0)
        val_ds = mnist_ds.test
    else:
        mnist_ds = mnist.read_data_sets(FLAGS.data_dir,
                                        dtype=tf.float32,
                                        reshape=False,
                                        validation_size=FLAGS.validation_size)
        val_ds = mnist_ds.validation
    train_ds = mnist_ds.train
    test_ds = mnist_ds.test
    num_classes = FLAGS.num_classes

    img_shape = [None, 1, 28, 28]
    X = tf.placeholder(tf.float32, shape=img_shape, name='X')
    # placeholder to avoid recomputation of adversarial images for critic
    X_hat_h = tf.placeholder(tf.float32, shape=img_shape, name='X_hat')
    y = tf.placeholder(tf.int32, shape=[None], name='y')
    y_onehot = tf.one_hot(y, num_classes)
    reduce_ind = list(range(1, X.get_shape().ndims))
    # test/validation inputs
    X_v = tf.placeholder(tf.float32, shape=img_shape, name='X_v')
    y_v = tf.placeholder(tf.int32, shape=[None], name='y_v')
    y_v_onehot = tf.one_hot(y_v, num_classes)

    # classifier model
    model = create_model(FLAGS, name=FLAGS.model_name)

    def test_model(x, **kwargs):
        return model(x, train=False, **kwargs)

    # generator
    def generator(inputs, confidence, targets=None):
        return high_confidence_attack_unrolled(
            lambda x: model(x)['logits'],
            inputs,
            targets=targets,
            confidence=confidence,
            max_iter=FLAGS.attack_iter,
            over_shoot=FLAGS.attack_overshoot,
            attack_random=FLAGS.attack_random,
            attack_uniform=FLAGS.attack_uniform,
            attack_label_smoothing=FLAGS.attack_label_smoothing)

    def test_generator(inputs, confidence, targets=None):
        return high_confidence_attack(lambda x: test_model(x)['logits'],
                                      inputs,
                                      targets=targets,
                                      confidence=confidence,
                                      max_iter=FLAGS.df_iter,
                                      over_shoot=FLAGS.df_overshoot,
                                      random=FLAGS.attack_random,
                                      uniform=FLAGS.attack_uniform,
                                      clip_dist=FLAGS.df_clip)

    # discriminator
    critic = create_model(FLAGS, prefix='critic_', name='critic')

    # classifier outputs
    outs_x = model(X)
    outs_x_v = test_model(X_v)
    params = tf.trainable_variables()
    model_weights = [param for param in params if "weights" in param.name]
    vars = tf.model_variables()
    target_conf_v = [None]

    if FLAGS.attack_confidence == "same":
        # set the target confidence to the confidence of the original prediction
        target_confidence = outs_x['conf']
        target_conf_v[0] = target_confidence
    elif FLAGS.attack_confidence == "class_running_mean":
        # set the target confidence to the mean confidence of the specific target
        # use running mean estimate
        class_conf_mean = tf.Variable(np.ones(num_classes, dtype=np.float32))
        batch_conf_mean = tf.unsorted_segment_mean(outs_x['conf'],
                                                   outs_x['pred'], num_classes)
        # if batch does not contain predictions for the specific target
        # (zeroes), replace zeroes with stored class mean (previous batch)
        batch_conf_mean = tf.where(tf.not_equal(batch_conf_mean, 0),
                                   batch_conf_mean, class_conf_mean)
        # update class confidence mean
        class_conf_mean = assign_moving_average(class_conf_mean,
                                                batch_conf_mean, 0.5)
        # init class confidence during pre-training
        tf.add_to_collection("PREINIT_OPS", class_conf_mean)

        def target_confidence(targets_onehot):
            targets = tf.argmax(targets_onehot, axis=1)
            check_conf = tf.Assert(
                tf.reduce_all(tf.not_equal(class_conf_mean, 0)),
                [class_conf_mean])
            with tf.control_dependencies([check_conf]):
                t = tf.gather(class_conf_mean, targets)
            target_conf_v[0] = t
            return tf.stop_gradient(t)
    else:
        target_confidence = float(FLAGS.attack_confidence)
        target_conf_v[0] = target_confidence

    X_hat = generator(X, target_confidence)
    outs_x_hat = model(X_hat)
    # select examples for which attack succeeded (changed the prediction)
    X_hat_filter = tf.not_equal(outs_x['pred'], outs_x_hat['pred'])
    X_hat_f = tf.boolean_mask(X_hat, X_hat_filter)
    X_f = tf.boolean_mask(X, X_hat_filter)

    outs_x_f = model(X_f)
    outs_x_hat_f = model(X_hat_f)
    X_hatd = tf.stop_gradient(X_hat)
    X_rec = generator(X_hatd, outs_x['conf'], outs_x['pred'])
    X_rec_f = tf.boolean_mask(X_rec, X_hat_filter)

    # validation/test adversarial examples
    X_v_hat = test_generator(X_v, FLAGS.val_attack_confidence)
    X_v_hatd = tf.stop_gradient(X_v_hat)
    X_v_rec = test_generator(X_v_hatd,
                             outs_x_v['conf'],
                             targets=outs_x_v['pred'])
    X_v_hat_df = deepfool(lambda x: test_model(x)['logits'],
                          X_v,
                          y_v,
                          max_iter=FLAGS.df_iter,
                          clip_dist=FLAGS.df_clip)
    X_v_hat_df_all = deepfool(lambda x: test_model(x)['logits'],
                              X_v,
                              max_iter=FLAGS.df_iter,
                              clip_dist=FLAGS.df_clip)

    y_hat = outs_x['pred']
    y_adv = outs_x_hat['pred']
    y_adv_f = outs_x_hat_f['pred']
    tf.summary.histogram('y_data', y, collections=["model_summaries"])
    tf.summary.histogram('y_hat', y_hat, collections=["model_summaries"])
    tf.summary.histogram('y_adv', y_adv, collections=["model_summaries"])

    # critic outputs
    critic_outs_x = critic(X)
    critic_outs_x_hat = critic(X_hat_f)
    critic_params = list(set(tf.trainable_variables()) - set(params))
    critic_vars = list(set(tf.trainable_variables()) - set(vars))

    # binary logits for a specific target
    logits_data = critic_outs_x['logits']
    logits_data_flt = tf.reshape(logits_data, (-1, ))
    z_data = tf.gather(logits_data_flt,
                       tf.range(tf.shape(X)[0]) * num_classes + y)
    logits_adv = critic_outs_x_hat['logits']
    logits_adv_flt = tf.reshape(logits_adv, (-1, ))
    z_adv = tf.gather(logits_adv_flt,
                      tf.range(tf.shape(X_hat_f)[0]) * num_classes + y_adv_f)

    # classifier/generator losses
    nll = tf.reduce_mean(
        tf.losses.softmax_cross_entropy(y_onehot, outs_x['logits']))
    nll_v = tf.reduce_mean(
        tf.losses.softmax_cross_entropy(y_v_onehot, outs_x_v['logits']))
    # gan losses
    gan = tf.losses.sigmoid_cross_entropy(tf.ones_like(z_adv), z_adv)
    rec_l1 = tf.reduce_mean(
        tf.reduce_sum(tf.abs(X_f - X_rec_f), axis=reduce_ind))
    rec_l2 = tf.reduce_mean(tf.reduce_sum((X_f - X_rec_f)**2, axis=reduce_ind))

    weight_decay = slim.apply_regularization(slim.l2_regularizer(1.0),
                                             model_weights[:-1])
    pretrain_loss = nll + 5e-6 * weight_decay
    loss = nll + FLAGS.lmbd * gan
    if FLAGS.lmbd_rec_l1 > 0:
        loss += FLAGS.lmbd_rec_l1 * rec_l1
    if FLAGS.lmbd_rec_l2 > 0:
        loss += FLAGS.lmbd_rec_l2 * rec_l2
    if FLAGS.weight_decay > 0:
        loss += FLAGS.weight_decay * weight_decay

    # critic loss
    critic_gan_data = tf.losses.sigmoid_cross_entropy(tf.ones_like(z_data),
                                                      z_data)
    # use placeholder for X_hat to avoid recomputation of adversarial noise
    y_adv_h = model(X_hat_h)['pred']
    logits_adv_h = critic(X_hat_h)['logits']
    logits_adv_flt_h = tf.reshape(logits_adv_h, (-1, ))
    z_adv_h = tf.gather(logits_adv_flt_h,
                        tf.range(tf.shape(X_hat_h)[0]) * num_classes + y_adv_h)
    critic_gan_adv = tf.losses.sigmoid_cross_entropy(tf.zeros_like(z_adv_h),
                                                     z_adv_h)
    critic_gan = critic_gan_data + critic_gan_adv

    # Gulrajani discriminator regularizer (we do not interpolate)
    critic_grad_data = tf.gradients(z_data, X)[0]
    critic_grad_adv = tf.gradients(z_adv_h, X_hat_h)[0]
    critic_grad_penalty = norm_penalty(critic_grad_adv) + norm_penalty(
        critic_grad_data)
    critic_loss = critic_gan + FLAGS.lmbd_grad * critic_grad_penalty

    # classifier model_metrics
    err = 1 - slim.metrics.accuracy(outs_x['pred'], y)
    conf = tf.reduce_mean(outs_x['conf'])
    err_hat = 1 - slim.metrics.accuracy(
        test_model(X_hat)['pred'], outs_x['pred'])
    err_hat_f = 1 - slim.metrics.accuracy(
        test_model(X_hat_f)['pred'], outs_x_f['pred'])
    err_rec = 1 - slim.metrics.accuracy(
        test_model(X_rec)['pred'], outs_x['pred'])
    conf_hat = tf.reduce_mean(test_model(X_hat)['conf'])
    conf_hat_f = tf.reduce_mean(test_model(X_hat_f)['conf'])
    conf_rec = tf.reduce_mean(test_model(X_rec)['conf'])
    err_v = 1 - slim.metrics.accuracy(outs_x_v['pred'], y_v)
    conf_v_hat = tf.reduce_mean(test_model(X_v_hat)['conf'])
    l2_hat = tf.sqrt(tf.reduce_sum((X_f - X_hat_f)**2, axis=reduce_ind))
    tf.summary.histogram('l2_hat', l2_hat, collections=["model_summaries"])

    # critic model_metrics
    critic_err_data = 1 - binary_accuracy(
        z_data, tf.ones(tf.shape(z_data), tf.bool), 0.0)
    critic_err_adv = 1 - binary_accuracy(
        z_adv, tf.zeros(tf.shape(z_adv), tf.bool), 0.0)

    # validation model_metrics
    err_df = 1 - slim.metrics.accuracy(test_model(X_v_hat_df)['pred'], y_v)
    err_df_all = 1 - slim.metrics.accuracy(
        test_model(X_v_hat_df_all)['pred'], outs_x_v['pred'])
    l2_v_hat = tf.sqrt(tf.reduce_sum((X_v - X_v_hat)**2, axis=reduce_ind))
    l2_v_rec = tf.sqrt(tf.reduce_sum((X_v - X_v_rec)**2, axis=reduce_ind))
    l1_v_rec = tf.reduce_sum(tf.abs(X_v - X_v_rec), axis=reduce_ind)
    l2_df = tf.sqrt(tf.reduce_sum((X_v - X_v_hat_df)**2, axis=reduce_ind))
    l2_df_norm = l2_df / tf.sqrt(tf.reduce_sum(X_v**2, axis=reduce_ind))
    l2_df_all = tf.sqrt(
        tf.reduce_sum((X_v - X_v_hat_df_all)**2, axis=reduce_ind))
    l2_df_norm_all = l2_df_all / tf.sqrt(tf.reduce_sum(X_v**2,
                                                       axis=reduce_ind))
    tf.summary.histogram('l2_df', l2_df, collections=["adv_summaries"])
    tf.summary.histogram('l2_df_norm',
                         l2_df_norm,
                         collections=["adv_summaries"])

    # model_metrics
    pretrain_model_metrics = OrderedDict([('nll', nll),
                                          ('weight_decay', weight_decay),
                                          ('err', err)])
    model_metrics = OrderedDict([('loss', loss), ('nll', nll),
                                 ('l2_hat', tf.reduce_mean(l2_hat)),
                                 ('gan', gan), ('rec_l1', rec_l1),
                                 ('rec_l2', rec_l2),
                                 ('weight_decay', weight_decay), ('err', err),
                                 ('conf', conf), ('err_hat', err_hat),
                                 ('err_hat_f', err_hat_f),
                                 ('conf_t', tf.reduce_mean(target_conf_v[0])),
                                 ('conf_hat', conf_hat),
                                 ('conf_hat_f', conf_hat_f),
                                 ('err_rec', err_rec), ('conf_rec', conf_rec)])
    critic_metrics = OrderedDict([('c_loss', critic_loss),
                                  ('c_gan', critic_gan),
                                  ('c_gan_data', critic_gan_data),
                                  ('c_gan_adv', critic_gan_adv),
                                  ('c_grad_norm', critic_grad_penalty),
                                  ('c_err_adv', critic_err_adv),
                                  ('c_err_data', critic_err_data)])
    val_metrics = OrderedDict([('nll', nll_v), ('err', err_v)])
    adv_metrics = OrderedDict([('l2_df', tf.reduce_mean(l2_df)),
                               ('l2_df_norm', tf.reduce_mean(l2_df_norm)),
                               ('l2_df_all', tf.reduce_mean(l2_df_all)),
                               ('l2_df_all_norm',
                                tf.reduce_mean(l2_df_norm_all)),
                               ('l2_hat', tf.reduce_mean(l2_v_hat)),
                               ('conf_hat', conf_v_hat),
                               ('l1_rec', tf.reduce_mean(l1_v_rec)),
                               ('l2_rec', tf.reduce_mean(l2_v_rec)),
                               ('err_df', err_df), ('err_df_all', err_df_all)])

    pretrain_metric_mean, pretrain_metric_upd = register_metrics(
        pretrain_model_metrics, collections="pretrain_model_summaries")
    metric_mean, metric_upd = register_metrics(model_metrics,
                                               collections="model_summaries")
    critic_metric_mean, critic_metric_upd = register_metrics(
        critic_metrics, collections="critic_summaries")
    val_metric_mean, val_metric_upd = register_metrics(
        val_metrics, prefix="val_", collections="val_summaries")
    adv_metric_mean, adv_metric_upd = register_metrics(
        adv_metrics, collections="adv_summaries")
    metrics_reset = tf.variables_initializer(tf.local_variables())

    # training ops
    lr = tf.Variable(FLAGS.lr, trainable=False)
    critic_lr = tf.Variable(FLAGS.critic_lr, trainable=False)
    tf.summary.scalar('lr', lr, collections=["model_summaries"])
    tf.summary.scalar('critic_lr', critic_lr, collections=["critic_summaries"])

    optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)

    preinit_ops = tf.get_collection("PREINIT_OPS")
    with tf.control_dependencies(preinit_ops):
        pretrain_solver = optimizer.minimize(pretrain_loss, var_list=params)
    solver = optimizer.minimize(loss, var_list=params)
    critic_solver = (tf.train.AdamOptimizer(
        learning_rate=critic_lr, beta1=0.5).minimize(critic_loss,
                                                     var_list=critic_params))

    # train
    summary_images, summary_labels = select_balanced_subset(
        train_ds.images, train_ds.labels, num_classes, num_classes)
    summary_images = summary_images.transpose((0, 3, 1, 2))
    save_path = os.path.join(FLAGS.samples_dir, 'orig.png')
    save_images(summary_images, save_path)

    if FLAGS.gpu_memory < 1.0:
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory)
        config = tf.ConfigProto(gpu_options=gpu_options)
    else:
        config = None
    with tf.Session(config=config) as sess:
        try:
            # summaries
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
            summaries = tf.summary.merge_all("model_summaries")
            critic_summaries = tf.summary.merge_all("critic_summaries")
            val_summaries = tf.summary.merge_all("val_summaries")
            adv_summaries = tf.summary.merge_all("adv_summaries")

            # initialization
            tf.local_variables_initializer().run()
            tf.global_variables_initializer().run()

            # pretrain model
            if FLAGS.pretrain_niter > 0:
                logging.info("Model pretraining")
                for epoch in range(1, FLAGS.pretrain_niter + 1):
                    train_iterator = batch_iterator(train_ds.images,
                                                    train_ds.labels,
                                                    FLAGS.batch_size,
                                                    shuffle=True)
                    sess.run(metrics_reset)

                    start_time = time.time()
                    for ind, (images, labels) in enumerate(train_iterator):
                        sess.run([pretrain_solver, pretrain_metric_upd],
                                 feed_dict={
                                     X: images,
                                     y: labels
                                 })

                    str_bfr = six.StringIO()
                    str_bfr.write("Pretrain epoch [{}, {:.2f}s]:".format(
                        epoch,
                        time.time() - start_time))
                    print_results_str(str_bfr, pretrain_model_metrics.keys(),
                                      sess.run(pretrain_metric_mean))
                    print_results_str(str_bfr, critic_metrics.keys(),
                                      sess.run(critic_metric_mean))
                    logging.info(str_bfr.getvalue()[:-1])

            # training
            for epoch in range(1, FLAGS.niter + 1):
                train_iterator = batch_iterator(train_ds.images,
                                                train_ds.labels,
                                                FLAGS.batch_size,
                                                shuffle=True)
                sess.run(metrics_reset)

                start_time = time.time()
                for ind, (images, labels) in enumerate(train_iterator):
                    batch_index = (epoch - 1) * (train_ds.images.shape[0] //
                                                 FLAGS.batch_size) + ind
                    # train critic for several steps
                    X_hat_np = sess.run(X_hat, feed_dict={X: images})
                    for _ in range(FLAGS.critic_steps - 1):
                        sess.run([critic_solver],
                                 feed_dict={
                                     X: images,
                                     y: labels,
                                     X_hat_h: X_hat_np
                                 })
                    else:
                        summary = sess.run([
                            critic_solver, critic_metric_upd, critic_summaries
                        ],
                                           feed_dict={
                                               X: images,
                                               y: labels,
                                               X_hat_h: X_hat_np
                                           })[-1]
                        summary_writer.add_summary(summary, batch_index)
                    # train model
                    summary = sess.run([solver, metric_upd, summaries],
                                       feed_dict={
                                           X: images,
                                           y: labels
                                       })[-1]
                    summary_writer.add_summary(summary, batch_index)

                str_bfr = six.StringIO()
                str_bfr.write("Train epoch [{}, {:.2f}s]:".format(
                    epoch,
                    time.time() - start_time))
                print_results_str(str_bfr, model_metrics.keys(),
                                  sess.run(metric_mean))
                print_results_str(str_bfr, critic_metrics.keys(),
                                  sess.run(critic_metric_mean))
                logging.info(str_bfr.getvalue()[:-1])

                val_iterator = batch_iterator(val_ds.images,
                                              val_ds.labels,
                                              100,
                                              shuffle=False)
                for images, labels in val_iterator:
                    summary = sess.run([val_metric_upd, val_summaries],
                                       feed_dict={
                                           X_v: images,
                                           y_v: labels
                                       })[-1]
                    summary_writer.add_summary(summary, epoch)
                str_bfr = six.StringIO()
                str_bfr.write("Valid epoch [{}]:".format(epoch))
                print_results_str(str_bfr, val_metrics.keys(),
                                  sess.run(val_metric_mean))
                logging.info(str_bfr.getvalue()[:-1])

                # learning rate decay
                update_lr = lr_decay(lr, epoch)
                if update_lr is not None:
                    sess.run(update_lr)
                    logging.debug(
                        "learning rate was updated to: {:.10f}".format(
                            lr.eval()))
                critic_update_lr = lr_decay(critic_lr, epoch, prefix='critic_')
                if critic_update_lr is not None:
                    sess.run(critic_update_lr)
                    logging.debug(
                        "critic learning rate was updated to: {:.10f}".format(
                            critic_lr.eval()))

                if epoch % FLAGS.summary_frequency == 0:
                    samples_hat, samples_rec, samples_df, summary = sess.run(
                        [
                            X_v_hat, X_v_rec, X_v_hat_df, adv_summaries,
                            adv_metric_upd
                        ],
                        feed_dict={
                            X_v: summary_images,
                            y_v: summary_labels
                        })[:-1]
                    summary_writer.add_summary(summary, epoch)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch_orig-%d.png' % epoch)
                    save_images(summary_images, save_path)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch-%d.png' % epoch)
                    save_images(samples_hat, save_path)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch_rec-%d.png' % epoch)
                    save_images(samples_rec, save_path)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch_df-%d.png' % epoch)
                    save_images(samples_df, save_path)

                    str_bfr = six.StringIO()
                    str_bfr.write("Summary epoch [{}]:".format(epoch))
                    print_results_str(str_bfr, adv_metrics.keys(),
                                      sess.run(adv_metric_mean))
                    logging.info(str_bfr.getvalue()[:-1])

                if FLAGS.checkpoint_frequency != -1 and epoch % FLAGS.checkpoint_frequency == 0:
                    save_checkpoint(sess, vars, epoch=epoch)
                    save_checkpoint(sess,
                                    critic_vars,
                                    name="critic_model",
                                    epoch=epoch)
        except KeyboardInterrupt:
            logging.debug("Keyboard interrupt. Stopping training...")
        except NanError as e:
            logging.info(e)
        finally:
            sess.run(metrics_reset)
            save_checkpoint(sess, vars)
            save_checkpoint(sess, critic_vars, name="critic_model")

        # final accuracy
        test_iterator = batch_iterator(test_ds.images,
                                       test_ds.labels,
                                       100,
                                       shuffle=False)
        for images, labels in test_iterator:
            sess.run([val_metric_upd], feed_dict={X_v: images, y_v: labels})
        str_bfr = six.StringIO()
        str_bfr.write("Final epoch [{}]:".format(epoch))
        for metric_name, metric_value in zip(val_metrics.keys(),
                                             sess.run(val_metric_mean)):
            str_bfr.write(" {}: {:.6f},".format(metric_name, metric_value))
        logging.info(str_bfr.getvalue()[:-1])
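For reference, the `tf.unsorted_segment_mean` usage buried in the script above (the `class_running_mean` branch) reduces to the following pattern; the tensors here are hypothetical stand-ins, and a plain `tf.assign` with decay 0.5 stands in for `assign_moving_average`:

import tensorflow as tf

num_classes = 10
conf = tf.placeholder(tf.float32, [None])   # per-example confidence
pred = tf.placeholder(tf.int32, [None])     # per-example predicted class
class_conf_mean = tf.Variable(tf.ones([num_classes]))

# mean confidence per predicted class in the batch; classes absent from the
# batch come back as 0 and are replaced by the stored running mean
batch_conf_mean = tf.unsorted_segment_mean(conf, pred, num_classes)
batch_conf_mean = tf.where(tf.not_equal(batch_conf_mean, 0.0),
                           batch_conf_mean, class_conf_mean)
update_class_conf = tf.assign(class_conf_mean,
                              0.5 * class_conf_mean + 0.5 * batch_conf_mean)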
Example #13
    def create_prot_pooling(self, pInFeatures, pProtein, pLevel, pBNAFDO):
        """Method to create a protein pooling operation.

        Args:
            pInFeatures (float tensor nxf): Input features.
            pProtein (Protein): Protein.
            pLevel (int): Level we want to pool.
            pBNAFDO (BNAFDO): BNAFDO object.
        Returns:
            (float tensor n'xf): Pooled features to the level pLevel+1.
        """

        if pProtein.poolType_[pLevel - 1] == "GRA":

            poolFeatures = self.graphConvBuilder_.create_graph_aggregation(
                pInFeatures=pInFeatures,
                pGraph=pProtein.molObjects_[pLevel - 1].graph_,
                pNormalize=True,
                pSpectralApprox=False)
            poolFeatures = tf.gather(poolFeatures,
                                     pProtein.poolIds_[pLevel - 1])

        elif pProtein.poolType_[pLevel - 1] == "AVG":

            poolFeatures = tf.unsorted_segment_mean(
                pInFeatures, pProtein.poolIds_[pLevel - 1],
                tf.shape(pProtein.molObjects_[pLevel].batchIds_)[0])

        elif pProtein.poolType_[pLevel - 1].startswith("GRAPH_DROP"):

            maskValueBool, poolFeatures, newGraph = self.graphConvBuilder_.create_graph_node_pooling(
                "Graph_drop_pooling_" + str(pLevel),
                pProtein.molObjects_[pLevel - 1].batchIds_,
                pProtein.molObjects_[pLevel - 1].graph_, pInFeatures,
                pProtein.molObjects_[pLevel - 1].batchSize_, 0.5, pBNAFDO)

            newPos = tf.boolean_mask(pProtein.molObjects_[pLevel - 1].pc_.pts_,
                                     maskValueBool)
            newBatchIds = tf.boolean_mask(
                pProtein.molObjects_[pLevel - 1].batchIds_, maskValueBool)

            if pProtein.molObjects_[pLevel - 1].graph2_ is None:
                newGraph2 = Graph(None, None)
            else:
                newGraph2 = pProtein.molObjects_[
                    pLevel - 1].graph2_.pool_graph_drop_nodes(
                        maskValueBool,
                        tf.shape(newPos)[0])

            pProtein.molObjects_[pLevel] = Molecule(
                newPos, newGraph.neighbors_, newGraph.nodeStartIndexs_,
                newBatchIds, pProtein.molObjects_[pLevel - 1].batchSize_,
                newGraph2.neighbors_, newGraph2.nodeStartIndexs_)

            if pProtein.poolType_[pLevel - 1] == "GRAPH_DROP_AMINO":
                pProtein.poolIds_[pLevel] = tf.boolean_mask(
                    pProtein.atomAminoIds_, maskValueBool)

        elif pProtein.poolType_[pLevel - 1].startswith("GRAPH_EDGE"):

            newIndices, poolFeatures, newGraph = self.graphConvBuilder_.create_graph_edge_pooling(
                "Graph_edge_pooling_" + str(pLevel),
                pProtein.molObjects_[pLevel - 1].graph_, pInFeatures, pBNAFDO)

            newPos = tf.unsorted_segment_mean(
                pProtein.molObjects_[pLevel - 1].pc_.pts_, newIndices,
                tf.shape(poolFeatures)[0])
            newBatchIds = tf.unsorted_segment_max(
                pProtein.molObjects_[pLevel - 1].batchIds_, newIndices,
                tf.shape(poolFeatures)[0])

            if pProtein.molObjects_[pLevel - 1].graph2_ is None:
                newGraph2 = Graph(None, None)
            else:
                newGraph2 = pProtein.molObjects_[
                    pLevel - 1].graph2_.pool_graph_collapse_edges(
                        newIndices,
                        tf.shape(newPos)[0])

            pProtein.molObjects_[pLevel] = Molecule(
                newPos, newGraph.neighbors_, newGraph.nodeStartIndexs_,
                newBatchIds, pProtein.molObjects_[pLevel - 1].batchSize_,
                newGraph2.neighbors_, newGraph2.nodeStartIndexs_)

            if pProtein.poolType_[pLevel - 1] == "GRAPH_EDGE_AMINO":
                pProtein.poolIds_[pLevel] = tf.unsorted_segment_max(
                    pProtein.atomAminoIds_, newIndices,
                    tf.shape(poolFeatures)[0])

        return poolFeatures
Example #14
    def __bio_pooling__(self,
        pAminoInput,
        pAtomPos, 
        pBatchIds,
        pPoolIds,
        pGraph1Neighbors, 
        pGraph1NeighStartIds, 
        pGraph2Neighbors, 
        pGraph2NeighStartIds, 
        pBatchSize,
        pConfig,
        pAminoPos = None,
        pAtomAminoIds = None, 
        pAtomResidueIds = None,
        pNumResidues = None):
        """Biological inspired pooling.

        Args:
            pAminoInput (bool): Boolean that indicates if the protein is represented at
                aminoacid level.
            pAtomPos (float tensor nxd): List of atom positions.
            pBatchIds (int tensor n): List of batch ids.
            pPoolIds (list int tensor n): List of pool ids for each level.
            pGraph1Neighbors (list int tensor mx2): List of neighbor pairs. 
            pGraph1NeighStartIds (list int tensor n): List of starting indices for each atom
                in the neighboring list.
            pGraph2Neighbors (list int tensor mx2): List of neighbor pairs. 
            pGraph2NeighStartIds (list int tensor n): List of starting indices for each atom
                in the neighboring list.
            pBatchSize (int): Size of the batch.
            pConfig (dictionary): Dictionary with the config parameters.
            pAminoPos (float tensor n'x3): Aminoacid pos.
            pAtomAminoIds (int tensor n): Identifier of the aminoacid per each atom.
            pAtomResidueIds (int tensor n): Identifier of the residue per each atom.
            pNumResidues (int): Number of residues.
        """
        numBBPooling = int(pConfig['prot.numbbpoolings'])

        self.aminoInput_ = pAminoInput
        self.poolIds_ = [curPoolId for curPoolId in pPoolIds]

        # Save the first molecule object.
        self.molObjects_ = [
            Molecule(
                pAtomPos, 
                pGraph1Neighbors[0], 
                pGraph1NeighStartIds[0], 
                pBatchIds, pBatchSize,
                pGraph2Neighbors[0], 
                pGraph2NeighStartIds[0])]

        # If input not aminoacids.
        if not self.aminoInput_:
            
            # Save side chain poolings.
            curAtomAminoIds = pAtomAminoIds
            curPoolIds = pPoolIds[0]
            newPos = tf.unsorted_segment_mean(
                self.molObjects_[0].atomPos_, curPoolIds,
                tf.shape(pGraph1NeighStartIds[1])[0])
            curAtomAminoIds = tf.unsorted_segment_max(
                curAtomAminoIds, curPoolIds, 
                tf.shape(pGraph1NeighStartIds[1])[0])
            newBatchIds = tf.unsorted_segment_max(
                self.molObjects_[0].batchIds_, curPoolIds,
                tf.shape(pGraph1NeighStartIds[1])[0])
            self.molObjects_.append(
                Molecule(
                newPos, 
                pGraph1Neighbors[1], 
                pGraph1NeighStartIds[1], 
                newBatchIds, pBatchSize,
                pGraph2Neighbors[1], 
                pGraph2NeighStartIds[1]))
            self.poolType_.append('AVG')
                
            # Save aminoacid level.
            aminoBatchIds = tf.unsorted_segment_max(
                    pBatchIds, pAtomAminoIds,
                    tf.shape(pAminoPos)[0])
            self.molObjects_.append(
                Molecule(
                    pAminoPos, 
                    pGraph1Neighbors[-1], 
                    pGraph1NeighStartIds[-1], 
                    aminoBatchIds, pBatchSize,
                    pGraph2Neighbors[-1], 
                    pGraph2NeighStartIds[-1]))
            self.poolIds_.append(curAtomAminoIds)
            self.poolType_.append('AVG')

        # Compute the backbone poolings.
        for curPool in range(numBBPooling):
            selIndices, pooledNeighs, pooledStartIds = \
                compute_protein_pooling(self.molObjects_[-1].graph_)

            newPositions = compute_graph_aggregation(
                self.molObjects_[-1].graph_, 
                self.molObjects_[-1].atomPos_,
                True)
            newPositions = tf.gather(newPositions, selIndices)

            self.poolIds_.append(selIndices)
            self.poolType_.append('GRA')

            selMask = tf.scatter_nd(
                tf.reshape(selIndices, [-1, 1]),
                tf.ones_like(selIndices),
                tf.shape(self.molObjects_[-1].batchIds_))
            pooledIndices = tf.cumsum(selMask)-1

            # Create new graph2.
            newGraph2 = self.molObjects_[-1].graph2_.pool_graph_collapse_edges(
                pooledIndices, tf.shape(selIndices)[0])
            pooledNeighs2 = newGraph2.neighbors_
            pooledStartIds2 = newGraph2.nodeStartIndexs_                

            newBatchIds = tf.gather(self.molObjects_[-1].batchIds_, selIndices)
            self.molObjects_.append(Molecule(
                newPositions, pooledNeighs, pooledStartIds, 
                newBatchIds, pBatchSize,
                pooledNeighs2, pooledStartIds2))

        self.atomAminoIds_ = pAtomAminoIds
        self.atomResidueIds_ = pAtomResidueIds
        self.numResidues_ = pNumResidues
Example #15
def bucket_mean(data, bucket_ids, num_buckets):
    total = tf.unsorted_segment_mean(data, bucket_ids, num_buckets)
    count = tf.unsorted_segment_mean(tf.ones_like(data), bucket_ids,
                                     num_buckets)
    return total / count
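Note that in this last example the `count` term is the per-bucket mean of ones, i.e. 1 for every non-empty bucket (and 0 for empty ones), so the division is a no-op for non-empty buckets. A common alternative formulation (not from the original source) builds the mean from per-bucket sums instead, roughly:

import tensorflow as tf

def bucket_mean_via_sums(data, bucket_ids, num_buckets):
    # per-bucket sum divided by per-bucket element count
    total = tf.unsorted_segment_sum(data, bucket_ids, num_buckets)
    count = tf.unsorted_segment_sum(tf.ones_like(data), bucket_ids, num_buckets)
    return total / count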