def get_center_loss(features, labels, featuress, labelss, alpha, num_classes):
    len_features = features.get_shape()[1]
    # Create a Variable of shape [num_classes, len_features] that stores the class
    # centers for the whole network; trainable=False because the centers are not
    # updated by gradients.
    centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,
                              initializer=tf.constant_initializer(0), trainable=False)
    centerss = tf.get_variable('centerss', [num_classes, len_features], dtype=tf.float32,
                               initializer=tf.constant_initializer(0), trainable=False)
    # Flatten the labels to 1-D; this is a no-op if the input is already 1-D.
    labels = tf.reshape(labels, [-1])
    labelss = tf.reshape(labelss, [-1])

    ##############################################################
    centers0 = tf.unsorted_segment_mean(features, labels, num_classes)
    centers1 = tf.unsorted_segment_mean(featuress, labelss, num_classes)
    EdgeWeights = tf.ones((num_classes, num_classes)) - tf.eye(num_classes)
    margin = tf.constant(100, dtype="float32")
    norm = lambda x: tf.reduce_sum(tf.square(x), 1)
    center_pairwise_dist = tf.transpose(
        norm(tf.expand_dims(centers1, 2) - tf.transpose(centers0)))
    loss_0 = tf.reduce_sum(
        tf.multiply(
            tf.maximum(0.0, margin - tf.transpose(
                norm(tf.expand_dims(centers0, 2) - tf.transpose(centers0)))),
            EdgeWeights))
    # + tf.reduce_sum(tf.maximum(0.0, tf.pow((centers1 - centers0), 2))) \
    # + 0.01*tf.reduce_sum(tf.multiply(tf.maximum(0.0, tf.constant(200,dtype="float32")-center_pairwise_dist),EdgeWeights))

    # Gather the center of each sample in the mini-batch according to its label.
    centers_batch = tf.gather(centers, labels)
    # Difference between the mini-batch features and their corresponding centers.
    diff = centers_batch - features
    unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
    appear_times = tf.gather(unique_count, unique_idx)
    appear_times = tf.reshape(appear_times, [-1, 1])
    diff = diff / tf.cast((1 + appear_times), tf.float32)
    diff = alpha * diff

    # # Gather the center of each sample in the mini-batch according to its label.
    # centers_batch1 = tf.gather(centerss, labelss)
    # # Difference between the mini-batch features and their corresponding centers.
    # diff1 = centers_batch1 - featuress
    # unique_label1, unique_idx1, unique_count1 = tf.unique_with_counts(labelss)
    # appear_times1 = tf.gather(unique_count1, unique_idx1)
    # appear_times1 = tf.reshape(appear_times1, [-1, 1])
    # diff1 = diff1 / tf.cast((1 + appear_times1), tf.float32)
    # diff1 = alpha * diff1

    # Compute the loss.
    loss_1 = tf.nn.l2_loss(features - centers_batch)
    centers_update_op = tf.scatter_sub(centers, labels, diff)
    # centers_update_op1 = tf.scatter_sub(centerss, labelss, diff1)
    return loss_0, loss_1, centers_update_op

def grouped_pairwise_margin_loss(similarities, group_ids, labels, gamma=1.0):
    """
    Calculates the pairwise margin ranking loss between paragraphs of the same
    question in a batch. Assumes at least (and most likely also at most) 2
    paragraphs, one positive and one negative, for each question.
    """
    _, group_segments = tf.unique(group_ids)
    num_segments = tf.reduce_max(group_segments) + 1
    # We take advantage of the fact that negative segment ids get dropped.
    # Note: if label == 1, 2*label - 1 == 1, but if label == 0, 2*label - 1 == -1.
    positive_ranks = tf.unsorted_segment_mean(
        similarities, (group_segments + 1) * (2 * labels - 1) - 1,
        num_segments=num_segments)
    negative_ranks = tf.unsorted_segment_mean(
        similarities, (group_segments + 1) * (1 - 2 * labels) - 1,
        num_segments=num_segments)
    return tf.maximum(gamma - positive_ranks + negative_ranks, 0.)

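# The segment-id arithmetic above is easier to see with concrete numbers. A minimal
# sketch (values are illustrative only, not from the original code): for group segment
# g, a positive paragraph (label == 1) maps to (g + 1) * 1 - 1 == g, while a negative
# one (label == 0) maps to (g + 1) * (-1) - 1 == -g - 2, which is negative and is
# therefore dropped by tf.unsorted_segment_mean.
import tensorflow as tf

similarities = tf.constant([0.9, 0.2, 0.7, 0.4])
group_segments = tf.constant([0, 0, 1, 1])   # two questions
labels = tf.constant([1, 0, 1, 0])           # one positive / one negative paragraph each

pos_ids = (group_segments + 1) * (2 * labels - 1) - 1   # [0, -2, 1, -3]
neg_ids = (group_segments + 1) * (1 - 2 * labels) - 1   # [-2, 0, -3, 1]
pos_mean = tf.unsorted_segment_mean(similarities, pos_ids, num_segments=2)  # [0.9, 0.7]
neg_mean = tf.unsorted_segment_mean(similarities, neg_ids, num_segments=2)  # [0.2, 0.4]
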
def get_center_loss(features, labels, alpha, num_classes):
    len_features = features.get_shape()[1]
    centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,
                              initializer=tf.constant_initializer(0), trainable=False)
    labels = tf.reshape(labels, [-1])

    ##############################################################
    centers0 = tf.unsorted_segment_mean(features, labels, num_classes)
    EdgeWeights = tf.ones((num_classes, num_classes)) - tf.eye(num_classes)
    margin = tf.constant(100, dtype="float32")
    norm = lambda x: tf.reduce_sum(tf.square(x), 1)
    center_pairwise_dist = tf.transpose(
        norm(tf.expand_dims(centers0, 2) - tf.transpose(centers0)))
    loss_0 = tf.reduce_sum(
        tf.multiply(tf.maximum(0.0, margin - center_pairwise_dist), EdgeWeights))
    ###########################################################################

    # Gather the center of each sample in the mini-batch according to its label.
    centers_batch = tf.gather(centers, labels)
    # Difference between the mini-batch features and their corresponding centers.
    diff = centers_batch - features
    unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
    appear_times = tf.gather(unique_count, unique_idx)
    appear_times = tf.reshape(appear_times, [-1, 1])
    diff = diff / tf.cast((1 + appear_times), tf.float32)
    diff = alpha * diff

    # Compute the loss.
    loss_1 = tf.nn.l2_loss(features - centers_batch)
    centers_update_op = tf.scatter_sub(centers, labels, diff)
    return loss_0, loss_1, centers_update_op, centers

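# A typical way to wire the function above into a training graph (a sketch, not the
# original authors' code): the center update op has to be run alongside the optimizer
# step, e.g. via control dependencies. `features`, `labels`, `softmax_loss` and the
# loss weights below are illustrative placeholders.
loss_0, loss_1, centers_update_op, centers = get_center_loss(
    features, labels, alpha=0.5, num_classes=10)
total_loss = softmax_loss + 0.003 * loss_1 + 0.001 * loss_0  # weights are illustrative
with tf.control_dependencies([centers_update_op]):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(total_loss)
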
def get_variable_embeddings(all_sequence_embeddings):
    flat_sequence_embeddings = tf.reshape(
        all_sequence_embeddings,
        (-1, all_sequence_embeddings.get_shape()[-1]))  # B*max-len x D
    target_token_embeddings = tf.gather(
        params=flat_sequence_embeddings,
        indices=self.placeholders['variable_bound_token_ids'])
    return tf.unsorted_segment_mean(
        data=target_token_embeddings,
        segment_ids=self.placeholders['token_variable_ids'],
        num_segments=self.placeholders['num_variables']  # TODO: Do not depend in any way on the classes.
    )  # num-variables x H

def classification_task_graphb4classify(self, last_h, classification_layer):
    _input = last_h  # [v x h]
    graph_representations = tf.unsorted_segment_mean(
        data=_input,
        segment_ids=self.placeholders['graph_nodes_list'],
        num_segments=self.placeholders['num_graphs'])  # [g x h]
    output = classification_layer(graph_representations)  # [g x 2]
    self.output = output
    return output

def create_label(label, segments):
    n_node = tf.reduce_max(segments) + 1
    node_label = tf.cast(tf.unsorted_segment_mean(label, segments, n_node),
                         dtype=tf.float32)
    label_mean = tf.reduce_mean(tf.cast(label, dtype=tf.float32))
    a = tf.ones_like(node_label, dtype=tf.int32)
    b = tf.zeros_like(node_label, dtype=tf.int32)
    condition = tf.less(label_mean, node_label)
    label_t = tf.where(condition, a, b)
    label_t = tf.squeeze(label_t, axis=1)
    return label_t

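# Illustration of the thresholding above (values are illustrative only): a node group
# gets label 1 when its per-segment mean label exceeds the overall label mean.
import tensorflow as tf

label = tf.constant([[1.0], [1.0], [0.0], [0.0], [0.0], [0.0]])
segments = tf.constant([0, 0, 1, 1, 2, 2])
# per-segment means: [1.0, 0.0, 0.0]; global mean: 2/6 -> label_t == [1, 0, 0]
label_t = create_label(label, segments)
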
def classification_task_org(self, last_h, classification_layer):
    _input = last_h  # [v x h]
    _output = classification_layer(_input)  # [v x 2]
    # Average all nodes per-graph (the commented-out line is the original sum variant)
    # graph_representations = tf.unsorted_segment_sum(data=_output,
    graph_representations = tf.unsorted_segment_mean(
        data=_output,
        segment_ids=self.placeholders['graph_nodes_list'],
        num_segments=self.placeholders['num_graphs'])  # [g x 2]
    output = graph_representations
    self.output = output
    return output

def define_pooling(self, node_embeddings):
    if self.agg == 'sum':
        return tf.unsorted_segment_sum(
            data=node_embeddings,
            segment_ids=self.placeholders['node_graph_ids_list'],
            num_segments=self.placeholders['num_graphs'])
    elif self.agg == 'mean':
        return tf.unsorted_segment_mean(
            data=node_embeddings,
            segment_ids=self.placeholders['node_graph_ids_list'],
            num_segments=self.placeholders['num_graphs'])
    else:
        raise ValueError("Aggregation must be one of {'sum', 'mean'}")

def define_round(self, layer: int, time_step: int, node_embeddings):
    node_embeddings = tf.identity(node_embeddings)
    edge_weights = self.weights['edge_weights'][layer]

    src_node_ids, dst_node_ids, src_node_embeddings, dst_node_embeddings, messages = [], [], [], [], []
    for e_type, adj_list in enumerate(self.placeholders['adjacency_lists']):
        src_node_ids.append(adj_list[:, 0])
        dst_node_ids.append(adj_list[:, 1])
        src_node_embeddings.append(tf.nn.embedding_lookup(params=node_embeddings, ids=src_node_ids[-1]))
        dst_node_embeddings.append(tf.nn.embedding_lookup(params=node_embeddings, ids=dst_node_ids[-1]))
        messages.append(tf.matmul(src_node_embeddings[-1], edge_weights[e_type]))

    src_node_ids = tf.concat(src_node_ids, axis=0)
    src_node_embeddings = tf.concat(src_node_embeddings, axis=0)
    dst_node_ids = tf.concat(dst_node_ids, axis=0)
    dst_node_embeddings = tf.concat(dst_node_embeddings, axis=0)
    messages = tf.concat(messages, axis=0)

    # Now weigh the messages using attention if configured to do so
    if self.use_propagation_attention:
        messages *= tf.expand_dims(
            self.define_message_attention(layer, src_node_ids, src_node_embeddings,
                                          dst_node_ids, dst_node_embeddings, messages),
            axis=-1)

    # Accumulate all messages for a destination node
    if self.edge_msg_aggregation == 'avg':
        incoming_messages = tf.unsorted_segment_mean(
            data=messages,
            segment_ids=dst_node_ids,
            num_segments=self.placeholders['num_nodes'])
    elif self.edge_msg_aggregation == 'sum':
        incoming_messages = tf.unsorted_segment_sum(
            data=messages,
            segment_ids=dst_node_ids,
            num_segments=self.placeholders['num_nodes'])
    else:
        raise ValueError("Edge message aggregation type should be one of {'avg', 'sum'}")

    # Compute new node states i.e. states for the next round of message passing (if any)
    return self.weights['rnn_cells'][layer](incoming_messages, [node_embeddings])[0]

def update_memory(self, solver):
    # update memory bank after solver
    with tf.control_dependencies([solver]):
        with tf.name_scope('update_point_memory'):
            feature = self.feature
            seg_num, point_id = self.flags.seg_num, self.point_id
            # point_mask = point_id > -1  # filter label -1
            point_id = point_id + (self.obj_segment * seg_num)
            # point_id = tf.boolean_mask(point_id, point_mask)
            # feature = tf.boolean_mask(feature, point_mask)
            batch_size = self.batch_size
            feature = tf.unsorted_segment_mean(feature, point_id, seg_num * batch_size)
            feature = tf.nn.l2_normalize(feature, axis=1)
            feature = tf.reshape(feature, [batch_size, seg_num, -1])
            momentum = self.flags.momentum
            weight = tf.gather(self.memory, self.shape_id)
            weight = feature * momentum + weight * (1 - momentum)
            weight = tf.nn.l2_normalize(weight, axis=2)
            memory = tf.scatter_update(self.memory, self.shape_id, weight)
    return memory

def construct(self, v_dims=[2, 4, 8], e_dims=[2, 4, 8], g_dims=[4, 8, 16]):
    with self.session.graph.as_default():
        assert(len(v_dims) == len(e_dims))
        assert(len(v_dims) == len(g_dims))

        # Input graph order and size
        self.g_n = tf.placeholder(tf.int32, [], name="g_n")
        self.g_n_f = tf.cast(self.g_n, tf.float32)
        self.g_m = tf.placeholder(tf.int32, [], name="g_m")
        self.g_m_f = tf.cast(self.g_m, tf.float32)

        # Input graph edge sources and destinations
        self.edge_srcs = tf.placeholder(tf.int32, [None], name="edge_srcs")
        self.edge_dsts = tf.placeholder(tf.int32, [None], name="edge_dsts")

        # Initial values
        self.v_layer = tf.ones([self.g_n, 1])
        self.g_layer = tf.ones([1, 0])
        self.e_layer = tf.ones([self.g_m, 0])

        for i in range(len(v_dims)):
            # For every edge, the input and output vertex
            e_in_layer = tf.gather(self.v_layer, self.edge_srcs, axis=0)
            e_out_layer = tf.gather(self.v_layer, self.edge_dsts, axis=0)
            e_layer_new = self.edge_gadget(self.g_layer, self.e_layer, e_in_layer, e_out_layer, dim=e_dims[i])
            v_sum = tf.unsorted_segment_sum(e_layer_new, self.edge_dsts, self.g_n)
            v_max = tf.maximum(tf.unsorted_segment_max(e_layer_new, self.edge_dsts, self.g_n), 0.0)
            v_mean = tf.unsorted_segment_mean(e_layer_new, self.edge_dsts, self.g_n)
            v_layer_new = self.vertex_gadget(self.g_layer, self.v_layer, v_sum, v_mean, v_max, dim=v_dims[i])
            g_layer_new = self.graph_gadget(self.g_layer, self.v_layer, dim=g_dims[i])
            self.g_layer = g_layer_new
            self.v_layer = v_layer_new
            self.e_layer = e_layer_new

def main(unused_args):
    assert len(unused_args) == 1, unused_args
    setup_experiment(logging, FLAGS, "critic_model")

    if FLAGS.validation:
        mnist_ds = mnist.read_data_sets(
            FLAGS.data_dir, dtype=tf.float32, reshape=False, validation_size=0)
        val_ds = mnist_ds.test
    else:
        mnist_ds = mnist.read_data_sets(
            FLAGS.data_dir, dtype=tf.float32, reshape=False,
            validation_size=FLAGS.validation_size)
        val_ds = mnist_ds.validation
    train_ds = mnist_ds.train
    val_ds = mnist_ds.validation
    test_ds = mnist_ds.test
    num_classes = FLAGS.num_classes

    img_shape = [None, 1, 28, 28]
    X = tf.placeholder(tf.float32, shape=img_shape, name='X')
    # placeholder to avoid recomputation of adversarial images for critic
    X_hat_h = tf.placeholder(tf.float32, shape=img_shape, name='X_hat')
    y = tf.placeholder(tf.int32, shape=[None], name='y')
    y_onehot = tf.one_hot(y, num_classes)
    reduce_ind = list(range(1, X.get_shape().ndims))
    # test/validation inputs
    X_v = tf.placeholder(tf.float32, shape=img_shape, name='X_v')
    y_v = tf.placeholder(tf.int32, shape=[None], name='y_v')
    y_v_onehot = tf.one_hot(y_v, num_classes)

    # classifier model
    model = create_model(FLAGS, name=FLAGS.model_name)

    def test_model(x, **kwargs):
        return model(x, train=False, **kwargs)

    # generator
    def generator(inputs, confidence, targets=None):
        return high_confidence_attack_unrolled(
            lambda x: model(x)['logits'], inputs, targets=targets,
            confidence=confidence, max_iter=FLAGS.attack_iter,
            over_shoot=FLAGS.attack_overshoot,
            attack_random=FLAGS.attack_random,
            attack_uniform=FLAGS.attack_uniform,
            attack_label_smoothing=FLAGS.attack_label_smoothing)

    def test_generator(inputs, confidence, targets=None):
        return high_confidence_attack(
            lambda x: test_model(x)['logits'], inputs, targets=targets,
            confidence=confidence, max_iter=FLAGS.df_iter,
            over_shoot=FLAGS.df_overshoot, random=FLAGS.attack_random,
            uniform=FLAGS.attack_uniform, clip_dist=FLAGS.df_clip)

    # discriminator
    critic = create_model(FLAGS, prefix='critic_', name='critic')

    # classifier outputs
    outs_x = model(X)
    outs_x_v = test_model(X_v)
    params = tf.trainable_variables()
    model_weights = [param for param in params if "weights" in param.name]
    vars = tf.model_variables()
    target_conf_v = [None]

    if FLAGS.attack_confidence == "same":
        # set the target confidence to the confidence of the original prediction
        target_confidence = outs_x['conf']
        target_conf_v[0] = target_confidence
    elif FLAGS.attack_confidence == "class_running_mean":
        # set the target confidence to the mean confidence of the specific target
        # use running mean estimate
        class_conf_mean = tf.Variable(np.ones(num_classes, dtype=np.float32))
        batch_conf_mean = tf.unsorted_segment_mean(outs_x['conf'], outs_x['pred'],
                                                   num_classes)
        # if batch does not contain predictions for the specific target
        # (zeroes), replace zeroes with stored class mean (previous batch)
        batch_conf_mean = tf.where(tf.not_equal(batch_conf_mean, 0),
                                   batch_conf_mean, class_conf_mean)
        # update class confidence mean
        class_conf_mean = assign_moving_average(class_conf_mean, batch_conf_mean, 0.5)
        # init class confidence during pre-training
        tf.add_to_collection("PREINIT_OPS", class_conf_mean)

        def target_confidence(targets_onehot):
            targets = tf.argmax(targets_onehot, axis=1)
            check_conf = tf.Assert(
                tf.reduce_all(tf.not_equal(class_conf_mean, 0)), [class_conf_mean])
            with tf.control_dependencies([check_conf]):
                t = tf.gather(class_conf_mean, targets)
            target_conf_v[0] = t
            return tf.stop_gradient(t)
    else:
        target_confidence = float(FLAGS.attack_confidence)
        target_conf_v[0] = target_confidence

    X_hat = generator(X, target_confidence)
    outs_x_hat = model(X_hat)
    # select examples for which attack succeeded (changed the prediction)
    X_hat_filter = tf.not_equal(outs_x['pred'], outs_x_hat['pred'])
    X_hat_f = tf.boolean_mask(X_hat, X_hat_filter)
    X_f = tf.boolean_mask(X, X_hat_filter)
    outs_x_f = model(X_f)
    outs_x_hat_f = model(X_hat_f)

    X_hatd = tf.stop_gradient(X_hat)
    X_rec = generator(X_hatd, outs_x['conf'], outs_x['pred'])
    X_rec_f = tf.boolean_mask(X_rec, X_hat_filter)

    # validation/test adversarial examples
    X_v_hat = test_generator(X_v, FLAGS.val_attack_confidence)
    X_v_hatd = tf.stop_gradient(X_v_hat)
    X_v_rec = test_generator(X_v_hatd, outs_x_v['conf'], targets=outs_x_v['pred'])
    X_v_hat_df = deepfool(lambda x: test_model(x)['logits'], X_v, y_v,
                          max_iter=FLAGS.df_iter, clip_dist=FLAGS.df_clip)
    X_v_hat_df_all = deepfool(lambda x: test_model(x)['logits'], X_v,
                              max_iter=FLAGS.df_iter, clip_dist=FLAGS.df_clip)

    y_hat = outs_x['pred']
    y_adv = outs_x_hat['pred']
    y_adv_f = outs_x_hat_f['pred']
    tf.summary.histogram('y_data', y, collections=["model_summaries"])
    tf.summary.histogram('y_hat', y_hat, collections=["model_summaries"])
    tf.summary.histogram('y_adv', y_adv, collections=["model_summaries"])

    # critic outputs
    critic_outs_x = critic(X)
    critic_outs_x_hat = critic(X_hat_f)
    critic_params = list(set(tf.trainable_variables()) - set(params))
    critic_vars = list(set(tf.trainable_variables()) - set(vars))

    # binary logits for a specific target
    logits_data = critic_outs_x['logits']
    logits_data_flt = tf.reshape(logits_data, (-1, ))
    z_data = tf.gather(logits_data_flt, tf.range(tf.shape(X)[0]) * num_classes + y)
    logits_adv = critic_outs_x_hat['logits']
    logits_adv_flt = tf.reshape(logits_adv, (-1, ))
    z_adv = tf.gather(logits_adv_flt,
                      tf.range(tf.shape(X_hat_f)[0]) * num_classes + y_adv_f)

    # classifier/generator losses
    nll = tf.reduce_mean(tf.losses.softmax_cross_entropy(y_onehot, outs_x['logits']))
    nll_v = tf.reduce_mean(
        tf.losses.softmax_cross_entropy(y_v_onehot, outs_x_v['logits']))
    # gan losses
    gan = tf.losses.sigmoid_cross_entropy(tf.ones_like(z_adv), z_adv)
    rec_l1 = tf.reduce_mean(tf.reduce_sum(tf.abs(X_f - X_rec_f), axis=reduce_ind))
    rec_l2 = tf.reduce_mean(tf.reduce_sum((X_f - X_rec_f)**2, axis=reduce_ind))

    weight_decay = slim.apply_regularization(slim.l2_regularizer(1.0),
                                             model_weights[:-1])
    pretrain_loss = nll + 5e-6 * weight_decay
    loss = nll + FLAGS.lmbd * gan
    if FLAGS.lmbd_rec_l1 > 0:
        loss += FLAGS.lmbd_rec_l1 * rec_l1
    if FLAGS.lmbd_rec_l2 > 0:
        loss += FLAGS.lmbd_rec_l2 * rec_l2
    if FLAGS.weight_decay > 0:
        loss += FLAGS.weight_decay * weight_decay

    # critic loss
    critic_gan_data = tf.losses.sigmoid_cross_entropy(tf.ones_like(z_data), z_data)
    # use placeholder for X_hat to avoid recomputation of adversarial noise
    y_adv_h = model(X_hat_h)['pred']
    logits_adv_h = critic(X_hat_h)['logits']
    logits_adv_flt_h = tf.reshape(logits_adv_h, (-1, ))
    z_adv_h = tf.gather(logits_adv_flt_h,
                        tf.range(tf.shape(X_hat_h)[0]) * num_classes + y_adv_h)
    critic_gan_adv = tf.losses.sigmoid_cross_entropy(tf.zeros_like(z_adv_h), z_adv_h)
    critic_gan = critic_gan_data + critic_gan_adv

    # Gulrajani discriminator regularizer (we do not interpolate)
    critic_grad_data = tf.gradients(z_data, X)[0]
    critic_grad_adv = tf.gradients(z_adv_h, X_hat_h)[0]
    critic_grad_penalty = norm_penalty(critic_grad_adv) + norm_penalty(critic_grad_data)
    critic_loss = critic_gan + FLAGS.lmbd_grad * critic_grad_penalty

    # classifier model_metrics
    err = 1 - slim.metrics.accuracy(outs_x['pred'], y)
    conf = tf.reduce_mean(outs_x['conf'])
    err_hat = 1 - slim.metrics.accuracy(test_model(X_hat)['pred'], outs_x['pred'])
    err_hat_f = 1 - slim.metrics.accuracy(test_model(X_hat_f)['pred'],
                                          outs_x_f['pred'])
    err_rec = 1 - slim.metrics.accuracy(test_model(X_rec)['pred'], outs_x['pred'])
    conf_hat = tf.reduce_mean(test_model(X_hat)['conf'])
    conf_hat_f = tf.reduce_mean(test_model(X_hat_f)['conf'])
    conf_rec = tf.reduce_mean(test_model(X_rec)['conf'])
    err_v = 1 - slim.metrics.accuracy(outs_x_v['pred'], y_v)
    conf_v_hat = tf.reduce_mean(test_model(X_v_hat)['conf'])
    l2_hat = tf.sqrt(tf.reduce_sum((X_f - X_hat_f)**2, axis=reduce_ind))
    tf.summary.histogram('l2_hat', l2_hat, collections=["model_summaries"])

    # critic model_metrics
    critic_err_data = 1 - binary_accuracy(z_data,
                                          tf.ones(tf.shape(z_data), tf.bool), 0.0)
    critic_err_adv = 1 - binary_accuracy(z_adv,
                                         tf.zeros(tf.shape(z_adv), tf.bool), 0.0)

    # validation model_metrics
    err_df = 1 - slim.metrics.accuracy(test_model(X_v_hat_df)['pred'], y_v)
    err_df_all = 1 - slim.metrics.accuracy(test_model(X_v_hat_df_all)['pred'],
                                           outs_x_v['pred'])
    l2_v_hat = tf.sqrt(tf.reduce_sum((X_v - X_v_hat)**2, axis=reduce_ind))
    l2_v_rec = tf.sqrt(tf.reduce_sum((X_v - X_v_rec)**2, axis=reduce_ind))
    l1_v_rec = tf.reduce_sum(tf.abs(X_v - X_v_rec), axis=reduce_ind)
    l2_df = tf.sqrt(tf.reduce_sum((X_v - X_v_hat_df)**2, axis=reduce_ind))
    l2_df_norm = l2_df / tf.sqrt(tf.reduce_sum(X_v**2, axis=reduce_ind))
    l2_df_all = tf.sqrt(tf.reduce_sum((X_v - X_v_hat_df_all)**2, axis=reduce_ind))
    l2_df_norm_all = l2_df_all / tf.sqrt(tf.reduce_sum(X_v**2, axis=reduce_ind))
    tf.summary.histogram('l2_df', l2_df, collections=["adv_summaries"])
    tf.summary.histogram('l2_df_norm', l2_df_norm, collections=["adv_summaries"])

    # model_metrics
    pretrain_model_metrics = OrderedDict([('nll', nll),
                                          ('weight_decay', weight_decay),
                                          ('err', err)])
    model_metrics = OrderedDict([('loss', loss),
                                 ('nll', nll),
                                 ('l2_hat', tf.reduce_mean(l2_hat)),
                                 ('gan', gan),
                                 ('rec_l1', rec_l1),
                                 ('rec_l2', rec_l2),
                                 ('weight_decay', weight_decay),
                                 ('err', err),
                                 ('conf', conf),
                                 ('err_hat', err_hat),
                                 ('err_hat_f', err_hat_f),
                                 ('conf_t', tf.reduce_mean(target_conf_v[0])),
                                 ('conf_hat', conf_hat),
                                 ('conf_hat_f', conf_hat_f),
                                 ('err_rec', err_rec),
                                 ('conf_rec', conf_rec)])
    critic_metrics = OrderedDict([('c_loss', critic_loss),
                                  ('c_gan', critic_gan),
                                  ('c_gan_data', critic_gan_data),
                                  ('c_gan_adv', critic_gan_adv),
                                  ('c_grad_norm', critic_grad_penalty),
                                  ('c_err_adv', critic_err_adv),
                                  ('c_err_data', critic_err_data)])
    val_metrics = OrderedDict([('nll', nll_v), ('err', err_v)])
    adv_metrics = OrderedDict([('l2_df', tf.reduce_mean(l2_df)),
                               ('l2_df_norm', tf.reduce_mean(l2_df_norm)),
                               ('l2_df_all', tf.reduce_mean(l2_df_all)),
                               ('l2_df_all_norm', tf.reduce_mean(l2_df_norm_all)),
                               ('l2_hat', tf.reduce_mean(l2_v_hat)),
                               ('conf_hat', conf_v_hat),
                               ('l1_rec', tf.reduce_mean(l1_v_rec)),
                               ('l2_rec', tf.reduce_mean(l2_v_rec)),
                               ('err_df', err_df),
                               ('err_df_all', err_df_all)])

    pretrain_metric_mean, pretrain_metric_upd = register_metrics(
        pretrain_model_metrics, collections="pretrain_model_summaries")
    metric_mean, metric_upd = register_metrics(model_metrics,
                                               collections="model_summaries")
    critic_metric_mean, critic_metric_upd = register_metrics(
        critic_metrics, collections="critic_summaries")
    val_metric_mean, val_metric_upd = register_metrics(
        val_metrics, prefix="val_", collections="val_summaries")
    adv_metric_mean, adv_metric_upd = register_metrics(
        adv_metrics, collections="adv_summaries")
    metrics_reset = tf.variables_initializer(tf.local_variables())

    # training ops
    lr = tf.Variable(FLAGS.lr, trainable=False)
    critic_lr = tf.Variable(FLAGS.critic_lr, trainable=False)
    tf.summary.scalar('lr', lr, collections=["model_summaries"])
    tf.summary.scalar('critic_lr', critic_lr, collections=["critic_summaries"])
    optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
    preinit_ops = tf.get_collection("PREINIT_OPS")
    with tf.control_dependencies(preinit_ops):
        pretrain_solver = optimizer.minimize(pretrain_loss, var_list=params)
    solver = optimizer.minimize(loss, var_list=params)
    critic_solver = (tf.train.AdamOptimizer(learning_rate=critic_lr, beta1=0.5)
                     .minimize(critic_loss, var_list=critic_params))

    # train
    summary_images, summary_labels = select_balanced_subset(
        train_ds.images, train_ds.labels, num_classes, num_classes)
    summary_images = summary_images.transpose((0, 3, 1, 2))
    save_path = os.path.join(FLAGS.samples_dir, 'orig.png')
    save_images(summary_images, save_path)

    if FLAGS.gpu_memory < 1.0:
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory)
        config = tf.ConfigProto(gpu_options=gpu_options)
    else:
        config = None
    with tf.Session(config=config) as sess:
        try:
            # summaries
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
            summaries = tf.summary.merge_all("model_summaries")
            critic_summaries = tf.summary.merge_all("critic_summaries")
            val_summaries = tf.summary.merge_all("val_summaries")
            adv_summaries = tf.summary.merge_all("adv_summaries")

            # initialization
            tf.local_variables_initializer().run()
            tf.global_variables_initializer().run()

            # pretrain model
            if FLAGS.pretrain_niter > 0:
                logging.info("Model pretraining")
                for epoch in range(1, FLAGS.pretrain_niter + 1):
                    train_iterator = batch_iterator(train_ds.images, train_ds.labels,
                                                    FLAGS.batch_size, shuffle=True)
                    sess.run(metrics_reset)

                    start_time = time.time()
                    for ind, (images, labels) in enumerate(train_iterator):
                        sess.run([pretrain_solver, pretrain_metric_upd],
                                 feed_dict={X: images, y: labels})

                    str_bfr = six.StringIO()
                    str_bfr.write("Pretrain epoch [{}, {:.2f}s]:".format(
                        epoch, time.time() - start_time))
                    print_results_str(str_bfr, pretrain_model_metrics.keys(),
                                      sess.run(pretrain_metric_mean))
                    print_results_str(str_bfr, critic_metrics.keys(),
                                      sess.run(critic_metric_mean))
                    logging.info(str_bfr.getvalue()[:-1])

            # training
            for epoch in range(1, FLAGS.niter + 1):
                train_iterator = batch_iterator(train_ds.images, train_ds.labels,
                                                FLAGS.batch_size, shuffle=True)
                sess.run(metrics_reset)

                start_time = time.time()
                for ind, (images, labels) in enumerate(train_iterator):
                    batch_index = (epoch - 1) * (train_ds.images.shape[0] //
                                                 FLAGS.batch_size) + ind
                    # train critic for several steps
                    X_hat_np = sess.run(X_hat, feed_dict={X: images})
                    for _ in range(FLAGS.critic_steps - 1):
                        sess.run([critic_solver],
                                 feed_dict={X: images, y: labels, X_hat_h: X_hat_np})
                    else:
                        summary = sess.run(
                            [critic_solver, critic_metric_upd, critic_summaries],
                            feed_dict={X: images, y: labels, X_hat_h: X_hat_np})[-1]
                        summary_writer.add_summary(summary, batch_index)
                    # train model
                    summary = sess.run([solver, metric_upd, summaries],
                                       feed_dict={X: images, y: labels})[-1]
                    summary_writer.add_summary(summary, batch_index)

                str_bfr = six.StringIO()
                str_bfr.write("Train epoch [{}, {:.2f}s]:".format(
                    epoch, time.time() - start_time))
                print_results_str(str_bfr, model_metrics.keys(),
                                  sess.run(metric_mean))
                print_results_str(str_bfr, critic_metrics.keys(),
                                  sess.run(critic_metric_mean))
                logging.info(str_bfr.getvalue()[:-1])

                val_iterator = batch_iterator(val_ds.images, val_ds.labels, 100,
                                              shuffle=False)
                for images, labels in val_iterator:
                    summary = sess.run([val_metric_upd, val_summaries],
                                       feed_dict={X_v: images, y_v: labels})[-1]
                    summary_writer.add_summary(summary, epoch)

                str_bfr = six.StringIO()
                str_bfr.write("Valid epoch [{}]:".format(epoch))
                print_results_str(str_bfr, val_metrics.keys(),
                                  sess.run(val_metric_mean))
                logging.info(str_bfr.getvalue()[:-1])

                # learning rate decay
                update_lr = lr_decay(lr, epoch)
                if update_lr is not None:
                    sess.run(update_lr)
                    logging.debug("learning rate was updated to: {:.10f}".format(
                        lr.eval()))
                critic_update_lr = lr_decay(critic_lr, epoch, prefix='critic_')
                if critic_update_lr is not None:
                    sess.run(critic_update_lr)
                    logging.debug("critic learning rate was updated to: {:.10f}".format(
                        critic_lr.eval()))

                if epoch % FLAGS.summary_frequency == 0:
                    samples_hat, samples_rec, samples_df, summary = sess.run(
                        [X_v_hat, X_v_rec, X_v_hat_df, adv_summaries, adv_metric_upd],
                        feed_dict={X_v: summary_images, y_v: summary_labels})[:-1]
                    summary_writer.add_summary(summary, epoch)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch_orig-%d.png' % epoch)
                    save_images(summary_images, save_path)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch-%d.png' % epoch)
                    save_images(samples_hat, save_path)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch_rec-%d.png' % epoch)
                    save_images(samples_rec, save_path)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch_df-%d.png' % epoch)
                    save_images(samples_df, save_path)

                    str_bfr = six.StringIO()
                    str_bfr.write("Summary epoch [{}]:".format(epoch))
                    print_results_str(str_bfr, adv_metrics.keys(),
                                      sess.run(adv_metric_mean))
                    logging.info(str_bfr.getvalue()[:-1])

                if FLAGS.checkpoint_frequency != -1 and epoch % FLAGS.checkpoint_frequency == 0:
                    save_checkpoint(sess, vars, epoch=epoch)
                    save_checkpoint(sess, critic_vars, name="critic_model",
                                    epoch=epoch)
        except KeyboardInterrupt:
            logging.debug("Keyboard interrupt. Stopping training...")
        except NanError as e:
            logging.info(e)
        finally:
            sess.run(metrics_reset)
            save_checkpoint(sess, vars)
            save_checkpoint(sess, critic_vars, name="critic_model")

        # final accuracy
        test_iterator = batch_iterator(test_ds.images, test_ds.labels, 100,
                                       shuffle=False)
        for images, labels in test_iterator:
            sess.run([val_metric_upd], feed_dict={X_v: images, y_v: labels})

        str_bfr = six.StringIO()
        str_bfr.write("Final epoch [{}]:".format(epoch))
        for metric_name, metric_value in zip(val_metrics.keys(),
                                             sess.run(val_metric_mean)):
            str_bfr.write(" {}: {:.6f},".format(metric_name, metric_value))
        logging.info(str_bfr.getvalue()[:-1])

def create_prot_pooling(self, pInFeatures, pProtein, pLevel, pBNAFDO):
    """Method to create a protein pooling operation.

    Args:
        pInFeatures (float tensor nxf): Input features.
        pProtein (Protein): Protein.
        pLevel (int): Level we want to pool.
        pBNAFDO (BNAFDO): BNAFDO object.

    Returns:
        (float tensor n'xf): Pooled features to the level pLevel+1.
    """
    if pProtein.poolType_[pLevel - 1] == "GRA":
        poolFeatures = self.graphConvBuilder_.create_graph_aggregation(
            pInFeatures=pInFeatures,
            pGraph=pProtein.molObjects_[pLevel - 1].graph_,
            pNormalize=True,
            pSpectralApprox=False)
        poolFeatures = tf.gather(poolFeatures, pProtein.poolIds_[pLevel - 1])
    elif pProtein.poolType_[pLevel - 1] == "AVG":
        poolFeatures = tf.unsorted_segment_mean(
            pInFeatures, pProtein.poolIds_[pLevel - 1],
            tf.shape(pProtein.molObjects_[pLevel].batchIds_)[0])
    elif pProtein.poolType_[pLevel - 1].startswith("GRAPH_DROP"):
        maskValueBool, poolFeatures, newGraph = self.graphConvBuilder_.create_graph_node_pooling(
            "Graph_drop_pooling_" + str(pLevel),
            pProtein.molObjects_[pLevel - 1].batchIds_,
            pProtein.molObjects_[pLevel - 1].graph_,
            pInFeatures,
            pProtein.molObjects_[pLevel - 1].batchSize_,
            0.5, pBNAFDO)
        newPos = tf.boolean_mask(pProtein.molObjects_[pLevel - 1].pc_.pts_,
                                 maskValueBool)
        newBatchIds = tf.boolean_mask(pProtein.molObjects_[pLevel - 1].batchIds_,
                                      maskValueBool)
        if pProtein.molObjects_[pLevel - 1].graph2_ is None:
            newGraph2 = Graph(None, None)
        else:
            newGraph2 = pProtein.molObjects_[pLevel - 1].graph2_.pool_graph_drop_nodes(
                maskValueBool, tf.shape(newPos)[0])
        pProtein.molObjects_[pLevel] = Molecule(
            newPos, newGraph.neighbors_, newGraph.nodeStartIndexs_,
            newBatchIds, pProtein.molObjects_[pLevel - 1].batchSize_,
            newGraph2.neighbors_, newGraph2.nodeStartIndexs_)
        if pProtein.poolType_[pLevel - 1] == "GRAPH_DROP_AMINO":
            pProtein.poolIds_[pLevel] = tf.boolean_mask(pProtein.atomAminoIds_,
                                                        maskValueBool)
    elif pProtein.poolType_[pLevel - 1].startswith("GRAPH_EDGE"):
        newIndices, poolFeatures, newGraph = self.graphConvBuilder_.create_graph_edge_pooling(
            "Graph_edge_pooling_" + str(pLevel),
            pProtein.molObjects_[pLevel - 1].graph_,
            pInFeatures, pBNAFDO)
        newPos = tf.unsorted_segment_mean(
            pProtein.molObjects_[pLevel - 1].pc_.pts_, newIndices,
            tf.shape(poolFeatures)[0])
        newBatchIds = tf.unsorted_segment_max(
            pProtein.molObjects_[pLevel - 1].batchIds_, newIndices,
            tf.shape(poolFeatures)[0])
        if pProtein.molObjects_[pLevel - 1].graph2_ is None:
            newGraph2 = Graph(None, None)
        else:
            newGraph2 = pProtein.molObjects_[pLevel - 1].graph2_.pool_graph_collapse_edges(
                newIndices, tf.shape(newPos)[0])
        pProtein.molObjects_[pLevel] = Molecule(
            newPos, newGraph.neighbors_, newGraph.nodeStartIndexs_,
            newBatchIds, pProtein.molObjects_[pLevel - 1].batchSize_,
            newGraph2.neighbors_, newGraph2.nodeStartIndexs_)
        if pProtein.poolType_[pLevel - 1] == "GRAPH_EDGE_AMINO":
            pProtein.poolIds_[pLevel] = tf.unsorted_segment_max(
                pProtein.atomAminoIds_, newIndices, tf.shape(poolFeatures)[0])
    return poolFeatures

def __bio_pooling__(self, pAminoInput, pAtomPos, pBatchIds, pPoolIds,
                    pGraph1Neighbors, pGraph1NeighStartIds,
                    pGraph2Neighbors, pGraph2NeighStartIds,
                    pBatchSize, pConfig, pAminoPos=None, pAtomAminoIds=None,
                    pAtomResidueIds=None, pNumResidues=None):
    """Biologically inspired pooling.

    Args:
        pAminoInput (bool): Boolean that indicates if the protein is represented at aminoacid level.
        pAtomPos (float tensor nxd): List of atom positions.
        pBatchIds (int tensor n): List of batch ids.
        pPoolIds (list int tensor n): List of pool ids for each level.
        pGraph1Neighbors (list int tensor mx2): List of neighbor pairs.
        pGraph1NeighStartIds (list int tensor n): List of starting indices for each atom in the neighboring list.
        pGraph2Neighbors (list int tensor mx2): List of neighbor pairs.
        pGraph2NeighStartIds (list int tensor n): List of starting indices for each atom in the neighboring list.
        pBatchSize (int): Size of the batch.
        pConfig (dictionary): Dictionary with the config parameters.
        pAminoPos (float tensor n'x3): Aminoacid pos.
        pAtomAminoIds (int tensor n): Identifier of the aminoacid per each atom.
        pAtomResidueIds (int tensor n): Identifier of the residue per each atom.
        pNumResidues (int): Number of residues.
    """
    numBBPooling = int(pConfig['prot.numbbpoolings'])
    self.aminoInput_ = pAminoInput
    self.poolIds_ = [curPoolId for curPoolId in pPoolIds]

    # Save the first molecule object.
    self.molObjects_ = [Molecule(
        pAtomPos, pGraph1Neighbors[0], pGraph1NeighStartIds[0],
        pBatchIds, pBatchSize, pGraph2Neighbors[0], pGraph2NeighStartIds[0])]

    # If input not aminoacids.
    if not self.aminoInput_:
        # Save side chain poolings.
        curAtomAminoIds = pAtomAminoIds
        curPoolIds = pPoolIds[0]
        newPos = tf.unsorted_segment_mean(
            self.molObjects_[0].atomPos_, curPoolIds,
            tf.shape(pGraph1NeighStartIds[1])[0])
        curAtomAminoIds = tf.unsorted_segment_max(
            curAtomAminoIds, curPoolIds, tf.shape(pGraph1NeighStartIds[1])[0])
        newBatchIds = tf.unsorted_segment_max(
            self.molObjects_[0].batchIds_, curPoolIds,
            tf.shape(pGraph1NeighStartIds[1])[0])
        self.molObjects_.append(Molecule(
            newPos, pGraph1Neighbors[1], pGraph1NeighStartIds[1],
            newBatchIds, pBatchSize, pGraph2Neighbors[1], pGraph2NeighStartIds[1]))
        self.poolType_.append('AVG')

        # Save aminoacid level.
        aminoBatchIds = tf.unsorted_segment_max(
            pBatchIds, pAtomAminoIds, tf.shape(pAminoPos)[0])
        self.molObjects_.append(Molecule(
            pAminoPos, pGraph1Neighbors[-1], pGraph1NeighStartIds[-1],
            aminoBatchIds, pBatchSize, pGraph2Neighbors[-1], pGraph2NeighStartIds[-1]))
        self.poolIds_.append(curAtomAminoIds)
        self.poolType_.append('AVG')

    # Compute the backbone poolings.
    for curPool in range(numBBPooling):
        selIndices, pooledNeighs, pooledStartIds = \
            compute_protein_pooling(self.molObjects_[-1].graph_)
        newPositions = compute_graph_aggregation(
            self.molObjects_[-1].graph_, self.molObjects_[-1].atomPos_, True)
        newPositions = tf.gather(newPositions, selIndices)
        self.poolIds_.append(selIndices)
        self.poolType_.append('GRA')
        selMask = tf.scatter_nd(
            tf.reshape(selIndices, [-1, 1]),
            tf.ones_like(selIndices),
            tf.shape(self.molObjects_[-1].batchIds_))
        pooledIndices = tf.cumsum(selMask) - 1

        # Create new graph2.
        newGraph2 = self.molObjects_[-1].graph2_.pool_graph_collapse_edges(
            pooledIndices, tf.shape(selIndices)[0])
        pooledNeighs2 = newGraph2.neighbors_
        pooledStartIds2 = newGraph2.nodeStartIndexs_

        newBatchIds = tf.gather(self.molObjects_[-1].batchIds_, selIndices)
        self.molObjects_.append(Molecule(
            newPositions, pooledNeighs, pooledStartIds, newBatchIds,
            pBatchSize, pooledNeighs2, pooledStartIds2))

    self.atomAminoIds_ = pAtomAminoIds
    self.atomResidueIds_ = pAtomResidueIds
    self.numResidues_ = pNumResidues

def bucket_mean(data, bucket_ids, num_buckets):
    # Per-bucket mean built from sums; for non-empty buckets this is equivalent to
    # tf.unsorted_segment_mean(data, bucket_ids, num_buckets).
    total = tf.unsorted_segment_sum(data, bucket_ids, num_buckets)
    count = tf.unsorted_segment_sum(tf.ones_like(data), bucket_ids, num_buckets)
    return total / count

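# Quick sanity check for the helper above (values are illustrative only): for non-empty
# buckets the sum/count formulation agrees with tf.unsorted_segment_mean; bucket 1 below
# receives no elements, so the division yields NaN for it.
import tensorflow as tf

data = tf.constant([1.0, 2.0, 3.0, 4.0])
bucket_ids = tf.constant([0, 0, 2, 2])
means_a = bucket_mean(data, bucket_ids, num_buckets=3)                # [1.5, nan, 3.5]
means_b = tf.unsorted_segment_mean(data, bucket_ids, num_segments=3)  # per-bucket mean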