Example 1
    def fit(self, X, y=None):
        '''Contrastive Divergence training procedure'''

        self._initialize_weights(X)

        self.training_errors = []
        self.training_reconstructions = []
        for _ in self.progressbar(range(self.n_iterations)):
            batch_errors = []
            for batch in batch_iterator(X, batch_size=self.batch_size):
                # Positive phase
                positive_hidden = sigmoid(batch.dot(self.W) + self.h0)
                hidden_states = self._sample(positive_hidden)
                positive_associations = batch.T.dot(positive_hidden)

                # Negative phase
                negative_visible = sigmoid(hidden_states.dot(self.W.T) + self.v0)
                negative_visible = self._sample(negative_visible)
                negative_hidden = sigmoid(negative_visible.dot(self.W) + self.h0)
                negative_associations = negative_visible.T.dot(negative_hidden)

                self.W  += self.lr * (positive_associations - negative_associations)
                self.h0 += self.lr * (positive_hidden.sum(axis=0) - negative_hidden.sum(axis=0))
                self.v0 += self.lr * (batch.sum(axis=0) - negative_visible.sum(axis=0))

                batch_errors.append(np.mean((batch - negative_visible) ** 2))

            self.training_errors.append(np.mean(batch_errors))
            # Reconstruct a batch of images from the training set
            idx = np.random.choice(range(X.shape[0]), self.batch_size)
            self.training_reconstructions.append(self.reconstruct(X[idx]))
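All of the examples on this page build their training or parsing loops around a helper named batch_iterator, whose definition is not shown. As a rough point of reference, a minimal array-based version might look like the sketch below; the real helpers behind these snippets differ in details (some accept shuffle or stochastic flags, others consume file handles or an apply_on_element callable), so the signature here is an assumption rather than the actual implementation.

import numpy as np


def batch_iterator(X, y=None, batch_size=64, shuffle=False):
    """Sketch: yield mini-batches of X (and y, when given) of at most batch_size rows."""
    n_samples = X.shape[0]
    indices = np.arange(n_samples)
    if shuffle:
        np.random.shuffle(indices)
    for start in range(0, n_samples, batch_size):
        batch_idx = indices[start:start + batch_size]
        if y is None:
            yield X[batch_idx]
        else:
            yield X[batch_idx], y[batch_idx]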
Example 2
    def fit(self,
            x_train: np.ndarray,
            y_train: np.ndarray,
            x_test: np.ndarray,
            y_test: np.ndarray,
            epochs: int,
            batch_size=16,
            _print=False,
            history=False):
        train_history = []
        test_history = []
        loss_history = []
        for i in range(epochs):
            if _print or history:
                acc_train = self.binary_accuracy(x_train, y_train)
                acc_test = self.binary_accuracy(x_test, y_test)
                train_loss = self.loss(x_train, y_train)
            if _print:
                print(f'Before epoch {i + 1}:')
                print(f'Accuracy on train set: {round(acc_train * 100, 2)} %')
                print(f'Accuracy on test set: {round(acc_test * 100, 2)} %')
                print(f'Loss: {round(train_loss, 2)}')
            if history:
                train_history.append(acc_train)
                test_history.append(acc_test)
                loss_history.append(train_loss)

            for x_batch, y_batch in batch_iterator(x_train,
                                                   y_train,
                                                   batch_size,
                                                   stochastic=True):
                self.layers = self.optimizer.optimize(self, x_batch, y_batch)
        return train_history, test_history, loss_history
Example 3
    def fit(self, X, y, n_epochs, batch_size):
        """ Trains the model for a fixed number of epochs """
        for _ in self.progressbar(range(n_epochs)):

            batch_error = []
            for X_batch, y_batch in batch_iterator(X, y,
                                                   batch_size=batch_size):
                loss, _ = self.train_on_batch(X_batch, y_batch)
                batch_error.append(loss)

            self.errors["training"].append(np.mean(batch_error))

            if self.val_set is not None:
                val_loss, _ = self.test_on_batch(self.val_set["X"],
                                                 self.val_set["y"])
                self.errors["validation"].append(val_loss)

        return self.errors["training"], self.errors["validation"]
Example 4
    def create_domains_corpus(self, file_in_name, file_out_name,
                              batch_num_lines):
        """
		Create domain corpus from protein domains tabular file

		Parameters
		----------
		file_in_name : str
			input file name
		file_out_name : str
			output file name
		batch_num_lines : int
			number of lines to be processed per batch

		Returns
		-------
		None
		"""
        total_out_lines = 0
        with open(os.path.join(self.data_path, file_in_name),
                  'r') as file_in, open(
                      os.path.join(self.data_path, file_out_name),
                      'a') as file_out:
            for i, batch in enumerate(batch_iterator(file_in,
                                                     batch_num_lines)):
                for line in batch:
                    line_tabs = line.split("\t")
                    assert len(
                        line_tabs
                    ) == 3, "AssertionError: line should have exactly three tab-separated fields."
                    protein_domains = line_tabs[1]
                    if protein_domains.strip() != "interpro_ids":
                        file_out.write(protein_domains + "\n")
                        total_out_lines = total_out_lines + 1
        print("Successfully written {} proteins in domains representation.".
              format(total_out_lines))
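This example, like Examples 7, 8, 9, 11 and 12 further down, uses a different flavour of batch_iterator: it receives an open file handle plus a line count and yields lists of consecutive lines. A minimal sketch of how such a grouping could work, again only an assumption about the helper's internals:

from itertools import islice


def batch_iterator(line_iterable, batch_num_lines):
    """Sketch: yield lists of up to batch_num_lines consecutive items from line_iterable."""
    iterator = iter(line_iterable)
    while True:
        batch = list(islice(iterator, batch_num_lines))
        if not batch:
            return
        yield batch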
Example 5
def main(unused_args):
    assert len(unused_args) == 1, unused_args
    setup_experiment(logging, FLAGS, "critic_model")

    if FLAGS.validation:
        mnist_ds = mnist.read_data_sets(FLAGS.data_dir,
                                        dtype=tf.float32,
                                        reshape=False,
                                        validation_size=0)
        val_ds = mnist_ds.test
    else:
        mnist_ds = mnist.read_data_sets(FLAGS.data_dir,
                                        dtype=tf.float32,
                                        reshape=False,
                                        validation_size=FLAGS.validation_size)
        val_ds = mnist_ds.validation
    train_ds = mnist_ds.train
    test_ds = mnist_ds.test
    num_classes = FLAGS.num_classes

    img_shape = [None, 1, 28, 28]
    X = tf.placeholder(tf.float32, shape=img_shape, name='X')
    # placeholder to avoid recomputation of adversarial images for critic
    X_hat_h = tf.placeholder(tf.float32, shape=img_shape, name='X_hat')
    y = tf.placeholder(tf.int32, shape=[None], name='y')
    y_onehot = tf.one_hot(y, num_classes)
    reduce_ind = list(range(1, X.get_shape().ndims))
    # test/validation inputs
    X_v = tf.placeholder(tf.float32, shape=img_shape, name='X_v')
    y_v = tf.placeholder(tf.int32, shape=[None], name='y_v')
    y_v_onehot = tf.one_hot(y_v, num_classes)

    # classifier model
    model = create_model(FLAGS, name=FLAGS.model_name)

    def test_model(x, **kwargs):
        return model(x, train=False, **kwargs)

    # generator
    def generator(inputs, confidence, targets=None):
        return high_confidence_attack_unrolled(
            lambda x: model(x)['logits'],
            inputs,
            targets=targets,
            confidence=confidence,
            max_iter=FLAGS.attack_iter,
            over_shoot=FLAGS.attack_overshoot,
            attack_random=FLAGS.attack_random,
            attack_uniform=FLAGS.attack_uniform,
            attack_label_smoothing=FLAGS.attack_label_smoothing)

    def test_generator(inputs, confidence, targets=None):
        return high_confidence_attack(lambda x: test_model(x)['logits'],
                                      inputs,
                                      targets=targets,
                                      confidence=confidence,
                                      max_iter=FLAGS.df_iter,
                                      over_shoot=FLAGS.df_overshoot,
                                      random=FLAGS.attack_random,
                                      uniform=FLAGS.attack_uniform,
                                      clip_dist=FLAGS.df_clip)

    # discriminator
    critic = create_model(FLAGS, prefix='critic_', name='critic')

    # classifier outputs
    outs_x = model(X)
    outs_x_v = test_model(X_v)
    params = tf.trainable_variables()
    model_weights = [param for param in params if "weights" in param.name]
    vars = tf.model_variables()
    target_conf_v = [None]

    if FLAGS.attack_confidence == "same":
        # set the target confidence to the confidence of the original prediction
        target_confidence = outs_x['conf']
        target_conf_v[0] = target_confidence
    elif FLAGS.attack_confidence == "class_running_mean":
        # set the target confidence to the mean confidence of the specific target
        # use running mean estimate
        class_conf_mean = tf.Variable(np.ones(num_classes, dtype=np.float32))
        batch_conf_mean = tf.unsorted_segment_mean(outs_x['conf'],
                                                   outs_x['pred'], num_classes)
        # if batch does not contain predictions for the specific target
        # (zeroes), replace zeroes with stored class mean (previous batch)
        batch_conf_mean = tf.where(tf.not_equal(batch_conf_mean, 0),
                                   batch_conf_mean, class_conf_mean)
        # update class confidence mean
        class_conf_mean = assign_moving_average(class_conf_mean,
                                                batch_conf_mean, 0.5)
        # init class confidence during pre-training
        tf.add_to_collection("PREINIT_OPS", class_conf_mean)

        def target_confidence(targets_onehot):
            targets = tf.argmax(targets_onehot, axis=1)
            check_conf = tf.Assert(
                tf.reduce_all(tf.not_equal(class_conf_mean, 0)),
                [class_conf_mean])
            with tf.control_dependencies([check_conf]):
                t = tf.gather(class_conf_mean, targets)
            target_conf_v[0] = t
            return tf.stop_gradient(t)
    else:
        target_confidence = float(FLAGS.attack_confidence)
        target_conf_v[0] = target_confidence

    X_hat = generator(X, target_confidence)
    outs_x_hat = model(X_hat)
    # select examples for which attack succeeded (changed the prediction)
    X_hat_filter = tf.not_equal(outs_x['pred'], outs_x_hat['pred'])
    X_hat_f = tf.boolean_mask(X_hat, X_hat_filter)
    X_f = tf.boolean_mask(X, X_hat_filter)

    outs_x_f = model(X_f)
    outs_x_hat_f = model(X_hat_f)
    X_hatd = tf.stop_gradient(X_hat)
    X_rec = generator(X_hatd, outs_x['conf'], outs_x['pred'])
    X_rec_f = tf.boolean_mask(X_rec, X_hat_filter)

    # validation/test adversarial examples
    X_v_hat = test_generator(X_v, FLAGS.val_attack_confidence)
    X_v_hatd = tf.stop_gradient(X_v_hat)
    X_v_rec = test_generator(X_v_hatd,
                             outs_x_v['conf'],
                             targets=outs_x_v['pred'])
    X_v_hat_df = deepfool(lambda x: test_model(x)['logits'],
                          X_v,
                          y_v,
                          max_iter=FLAGS.df_iter,
                          clip_dist=FLAGS.df_clip)
    X_v_hat_df_all = deepfool(lambda x: test_model(x)['logits'],
                              X_v,
                              max_iter=FLAGS.df_iter,
                              clip_dist=FLAGS.df_clip)

    y_hat = outs_x['pred']
    y_adv = outs_x_hat['pred']
    y_adv_f = outs_x_hat_f['pred']
    tf.summary.histogram('y_data', y, collections=["model_summaries"])
    tf.summary.histogram('y_hat', y_hat, collections=["model_summaries"])
    tf.summary.histogram('y_adv', y_adv, collections=["model_summaries"])

    # critic outputs
    critic_outs_x = critic(X)
    critic_outs_x_hat = critic(X_hat_f)
    critic_params = list(set(tf.trainable_variables()) - set(params))
    critic_vars = list(set(tf.trainable_variables()) - set(vars))

    # binary logits for a specific target
    logits_data = critic_outs_x['logits']
    logits_data_flt = tf.reshape(logits_data, (-1, ))
    z_data = tf.gather(logits_data_flt,
                       tf.range(tf.shape(X)[0]) * num_classes + y)
    logits_adv = critic_outs_x_hat['logits']
    logits_adv_flt = tf.reshape(logits_adv, (-1, ))
    z_adv = tf.gather(logits_adv_flt,
                      tf.range(tf.shape(X_hat_f)[0]) * num_classes + y_adv_f)

    # classifier/generator losses
    nll = tf.reduce_mean(
        tf.losses.softmax_cross_entropy(y_onehot, outs_x['logits']))
    nll_v = tf.reduce_mean(
        tf.losses.softmax_cross_entropy(y_v_onehot, outs_x_v['logits']))
    # gan losses
    gan = tf.losses.sigmoid_cross_entropy(tf.ones_like(z_adv), z_adv)
    rec_l1 = tf.reduce_mean(
        tf.reduce_sum(tf.abs(X_f - X_rec_f), axis=reduce_ind))
    rec_l2 = tf.reduce_mean(tf.reduce_sum((X_f - X_rec_f)**2, axis=reduce_ind))

    weight_decay = slim.apply_regularization(slim.l2_regularizer(1.0),
                                             model_weights[:-1])
    pretrain_loss = nll + 5e-6 * weight_decay
    loss = nll + FLAGS.lmbd * gan
    if FLAGS.lmbd_rec_l1 > 0:
        loss += FLAGS.lmbd_rec_l1 * rec_l1
    if FLAGS.lmbd_rec_l2 > 0:
        loss += FLAGS.lmbd_rec_l2 * rec_l2
    if FLAGS.weight_decay > 0:
        loss += FLAGS.weight_decay * weight_decay

    # critic loss
    critic_gan_data = tf.losses.sigmoid_cross_entropy(tf.ones_like(z_data),
                                                      z_data)
    # use placeholder for X_hat to avoid recomputation of adversarial noise
    y_adv_h = model(X_hat_h)['pred']
    logits_adv_h = critic(X_hat_h)['logits']
    logits_adv_flt_h = tf.reshape(logits_adv_h, (-1, ))
    z_adv_h = tf.gather(logits_adv_flt_h,
                        tf.range(tf.shape(X_hat_h)[0]) * num_classes + y_adv_h)
    critic_gan_adv = tf.losses.sigmoid_cross_entropy(tf.zeros_like(z_adv_h),
                                                     z_adv_h)
    critic_gan = critic_gan_data + critic_gan_adv

    # Gulrajani discriminator regularizer (we do not interpolate)
    critic_grad_data = tf.gradients(z_data, X)[0]
    critic_grad_adv = tf.gradients(z_adv_h, X_hat_h)[0]
    critic_grad_penalty = norm_penalty(critic_grad_adv) + norm_penalty(
        critic_grad_data)
    critic_loss = critic_gan + FLAGS.lmbd_grad * critic_grad_penalty

    # classifier model_metrics
    err = 1 - slim.metrics.accuracy(outs_x['pred'], y)
    conf = tf.reduce_mean(outs_x['conf'])
    err_hat = 1 - slim.metrics.accuracy(
        test_model(X_hat)['pred'], outs_x['pred'])
    err_hat_f = 1 - slim.metrics.accuracy(
        test_model(X_hat_f)['pred'], outs_x_f['pred'])
    err_rec = 1 - slim.metrics.accuracy(
        test_model(X_rec)['pred'], outs_x['pred'])
    conf_hat = tf.reduce_mean(test_model(X_hat)['conf'])
    conf_hat_f = tf.reduce_mean(test_model(X_hat_f)['conf'])
    conf_rec = tf.reduce_mean(test_model(X_rec)['conf'])
    err_v = 1 - slim.metrics.accuracy(outs_x_v['pred'], y_v)
    conf_v_hat = tf.reduce_mean(test_model(X_v_hat)['conf'])
    l2_hat = tf.sqrt(tf.reduce_sum((X_f - X_hat_f)**2, axis=reduce_ind))
    tf.summary.histogram('l2_hat', l2_hat, collections=["model_summaries"])

    # critic model_metrics
    critic_err_data = 1 - binary_accuracy(
        z_data, tf.ones(tf.shape(z_data), tf.bool), 0.0)
    critic_err_adv = 1 - binary_accuracy(
        z_adv, tf.zeros(tf.shape(z_adv), tf.bool), 0.0)

    # validation model_metrics
    err_df = 1 - slim.metrics.accuracy(test_model(X_v_hat_df)['pred'], y_v)
    err_df_all = 1 - slim.metrics.accuracy(
        test_model(X_v_hat_df_all)['pred'], outs_x_v['pred'])
    l2_v_hat = tf.sqrt(tf.reduce_sum((X_v - X_v_hat)**2, axis=reduce_ind))
    l2_v_rec = tf.sqrt(tf.reduce_sum((X_v - X_v_rec)**2, axis=reduce_ind))
    l1_v_rec = tf.reduce_sum(tf.abs(X_v - X_v_rec), axis=reduce_ind)
    l2_df = tf.sqrt(tf.reduce_sum((X_v - X_v_hat_df)**2, axis=reduce_ind))
    l2_df_norm = l2_df / tf.sqrt(tf.reduce_sum(X_v**2, axis=reduce_ind))
    l2_df_all = tf.sqrt(
        tf.reduce_sum((X_v - X_v_hat_df_all)**2, axis=reduce_ind))
    l2_df_norm_all = l2_df_all / tf.sqrt(tf.reduce_sum(X_v**2,
                                                       axis=reduce_ind))
    tf.summary.histogram('l2_df', l2_df, collections=["adv_summaries"])
    tf.summary.histogram('l2_df_norm',
                         l2_df_norm,
                         collections=["adv_summaries"])

    # model_metrics
    pretrain_model_metrics = OrderedDict([('nll', nll),
                                          ('weight_decay', weight_decay),
                                          ('err', err)])
    model_metrics = OrderedDict([('loss', loss), ('nll', nll),
                                 ('l2_hat', tf.reduce_mean(l2_hat)),
                                 ('gan', gan), ('rec_l1', rec_l1),
                                 ('rec_l2', rec_l2),
                                 ('weight_decay', weight_decay), ('err', err),
                                 ('conf', conf), ('err_hat', err_hat),
                                 ('err_hat_f', err_hat_f),
                                 ('conf_t', tf.reduce_mean(target_conf_v[0])),
                                 ('conf_hat', conf_hat),
                                 ('conf_hat_f', conf_hat_f),
                                 ('err_rec', err_rec), ('conf_rec', conf_rec)])
    critic_metrics = OrderedDict([('c_loss', critic_loss),
                                  ('c_gan', critic_gan),
                                  ('c_gan_data', critic_gan_data),
                                  ('c_gan_adv', critic_gan_adv),
                                  ('c_grad_norm', critic_grad_penalty),
                                  ('c_err_adv', critic_err_adv),
                                  ('c_err_data', critic_err_data)])
    val_metrics = OrderedDict([('nll', nll_v), ('err', err_v)])
    adv_metrics = OrderedDict([('l2_df', tf.reduce_mean(l2_df)),
                               ('l2_df_norm', tf.reduce_mean(l2_df_norm)),
                               ('l2_df_all', tf.reduce_mean(l2_df_all)),
                               ('l2_df_all_norm',
                                tf.reduce_mean(l2_df_norm_all)),
                               ('l2_hat', tf.reduce_mean(l2_v_hat)),
                               ('conf_hat', conf_v_hat),
                               ('l1_rec', tf.reduce_mean(l1_v_rec)),
                               ('l2_rec', tf.reduce_mean(l2_v_rec)),
                               ('err_df', err_df), ('err_df_all', err_df_all)])

    pretrain_metric_mean, pretrain_metric_upd = register_metrics(
        pretrain_model_metrics, collections="pretrain_model_summaries")
    metric_mean, metric_upd = register_metrics(model_metrics,
                                               collections="model_summaries")
    critic_metric_mean, critic_metric_upd = register_metrics(
        critic_metrics, collections="critic_summaries")
    val_metric_mean, val_metric_upd = register_metrics(
        val_metrics, prefix="val_", collections="val_summaries")
    adv_metric_mean, adv_metric_upd = register_metrics(
        adv_metrics, collections="adv_summaries")
    metrics_reset = tf.variables_initializer(tf.local_variables())

    # training ops
    lr = tf.Variable(FLAGS.lr, trainable=False)
    critic_lr = tf.Variable(FLAGS.critic_lr, trainable=False)
    tf.summary.scalar('lr', lr, collections=["model_summaries"])
    tf.summary.scalar('critic_lr', critic_lr, collections=["critic_summaries"])

    optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)

    preinit_ops = tf.get_collection("PREINIT_OPS")
    with tf.control_dependencies(preinit_ops):
        pretrain_solver = optimizer.minimize(pretrain_loss, var_list=params)
    solver = optimizer.minimize(loss, var_list=params)
    critic_solver = (tf.train.AdamOptimizer(
        learning_rate=critic_lr, beta1=0.5).minimize(critic_loss,
                                                     var_list=critic_params))

    # train
    summary_images, summary_labels = select_balanced_subset(
        train_ds.images, train_ds.labels, num_classes, num_classes)
    summary_images = summary_images.transpose((0, 3, 1, 2))
    save_path = os.path.join(FLAGS.samples_dir, 'orig.png')
    save_images(summary_images, save_path)

    if FLAGS.gpu_memory < 1.0:
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory)
        config = tf.ConfigProto(gpu_options=gpu_options)
    else:
        config = None
    with tf.Session(config=config) as sess:
        try:
            # summaries
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
            summaries = tf.summary.merge_all("model_summaries")
            critic_summaries = tf.summary.merge_all("critic_summaries")
            val_summaries = tf.summary.merge_all("val_summaries")
            adv_summaries = tf.summary.merge_all("adv_summaries")

            # initialization
            tf.local_variables_initializer().run()
            tf.global_variables_initializer().run()

            # pretrain model
            if FLAGS.pretrain_niter > 0:
                logging.info("Model pretraining")
                for epoch in range(1, FLAGS.pretrain_niter + 1):
                    train_iterator = batch_iterator(train_ds.images,
                                                    train_ds.labels,
                                                    FLAGS.batch_size,
                                                    shuffle=True)
                    sess.run(metrics_reset)

                    start_time = time.time()
                    for ind, (images, labels) in enumerate(train_iterator):
                        sess.run([pretrain_solver, pretrain_metric_upd],
                                 feed_dict={
                                     X: images,
                                     y: labels
                                 })

                    str_bfr = six.StringIO()
                    str_bfr.write("Pretrain epoch [{}, {:.2f}s]:".format(
                        epoch,
                        time.time() - start_time))
                    print_results_str(str_bfr, pretrain_model_metrics.keys(),
                                      sess.run(pretrain_metric_mean))
                    print_results_str(str_bfr, critic_metrics.keys(),
                                      sess.run(critic_metric_mean))
                    logging.info(str_bfr.getvalue()[:-1])

            # training
            for epoch in range(1, FLAGS.niter + 1):
                train_iterator = batch_iterator(train_ds.images,
                                                train_ds.labels,
                                                FLAGS.batch_size,
                                                shuffle=True)
                sess.run(metrics_reset)

                start_time = time.time()
                for ind, (images, labels) in enumerate(train_iterator):
                    batch_index = (epoch - 1) * (train_ds.images.shape[0] //
                                                 FLAGS.batch_size) + ind
                    # train critic for several steps
                    X_hat_np = sess.run(X_hat, feed_dict={X: images})
                    for _ in range(FLAGS.critic_steps - 1):
                        sess.run([critic_solver],
                                 feed_dict={
                                     X: images,
                                     y: labels,
                                     X_hat_h: X_hat_np
                                 })
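                    # note: this for/else has no break, so the else branch below
                    # always runs once the loop finishes, performing the final
                    # critic step together with the metric and summary updates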
                    else:
                        summary = sess.run([
                            critic_solver, critic_metric_upd, critic_summaries
                        ],
                                           feed_dict={
                                               X: images,
                                               y: labels,
                                               X_hat_h: X_hat_np
                                           })[-1]
                        summary_writer.add_summary(summary, batch_index)
                    # train model
                    summary = sess.run([solver, metric_upd, summaries],
                                       feed_dict={
                                           X: images,
                                           y: labels
                                       })[-1]
                    summary_writer.add_summary(summary, batch_index)

                str_bfr = six.StringIO()
                str_bfr.write("Train epoch [{}, {:.2f}s]:".format(
                    epoch,
                    time.time() - start_time))
                print_results_str(str_bfr, model_metrics.keys(),
                                  sess.run(metric_mean))
                print_results_str(str_bfr, critic_metrics.keys(),
                                  sess.run(critic_metric_mean))
                logging.info(str_bfr.getvalue()[:-1])

                val_iterator = batch_iterator(val_ds.images,
                                              val_ds.labels,
                                              100,
                                              shuffle=False)
                for images, labels in val_iterator:
                    summary = sess.run([val_metric_upd, val_summaries],
                                       feed_dict={
                                           X_v: images,
                                           y_v: labels
                                       })[-1]
                    summary_writer.add_summary(summary, epoch)
                str_bfr = six.StringIO()
                str_bfr.write("Valid epoch [{}]:".format(epoch))
                print_results_str(str_bfr, val_metrics.keys(),
                                  sess.run(val_metric_mean))
                logging.info(str_bfr.getvalue()[:-1])

                # learning rate decay
                update_lr = lr_decay(lr, epoch)
                if update_lr is not None:
                    sess.run(update_lr)
                    logging.debug(
                        "learning rate was updated to: {:.10f}".format(
                            lr.eval()))
                critic_update_lr = lr_decay(critic_lr, epoch, prefix='critic_')
                if critic_update_lr is not None:
                    sess.run(critic_update_lr)
                    logging.debug(
                        "critic learning rate was updated to: {:.10f}".format(
                            critic_lr.eval()))

                if epoch % FLAGS.summary_frequency == 0:
                    samples_hat, samples_rec, samples_df, summary = sess.run(
                        [
                            X_v_hat, X_v_rec, X_v_hat_df, adv_summaries,
                            adv_metric_upd
                        ],
                        feed_dict={
                            X_v: summary_images,
                            y_v: summary_labels
                        })[:-1]
                    summary_writer.add_summary(summary, epoch)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch_orig-%d.png' % epoch)
                    save_images(summary_images, save_path)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch-%d.png' % epoch)
                    save_images(samples_hat, save_path)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch_rec-%d.png' % epoch)
                    save_images(samples_rec, save_path)
                    save_path = os.path.join(FLAGS.samples_dir,
                                             'epoch_df-%d.png' % epoch)
                    save_images(samples_df, save_path)

                    str_bfr = six.StringIO()
                    str_bfr.write("Summary epoch [{}]:".format(epoch))
                    print_results_str(str_bfr, adv_metrics.keys(),
                                      sess.run(adv_metric_mean))
                    logging.info(str_bfr.getvalue()[:-1])

                if FLAGS.checkpoint_frequency != -1 and epoch % FLAGS.checkpoint_frequency == 0:
                    save_checkpoint(sess, vars, epoch=epoch)
                    save_checkpoint(sess,
                                    critic_vars,
                                    name="critic_model",
                                    epoch=epoch)
        except KeyboardInterrupt:
            logging.debug("Keyboard interrupt. Stopping training...")
        except NanError as e:
            logging.info(e)
        finally:
            sess.run(metrics_reset)
            save_checkpoint(sess, vars)
            save_checkpoint(sess, critic_vars, name="critic_model")

        # final accuracy
        test_iterator = batch_iterator(test_ds.images,
                                       test_ds.labels,
                                       100,
                                       shuffle=False)
        for images, labels in test_iterator:
            sess.run([val_metric_upd], feed_dict={X_v: images, y_v: labels})
        str_bfr = six.StringIO()
        str_bfr.write("Final epoch [{}]:".format(epoch))
        for metric_name, metric_value in zip(val_metrics.keys(),
                                             sess.run(val_metric_mean)):
            str_bfr.write(" {}: {:.6f},".format(metric_name, metric_value))
        logging.info(str_bfr.getvalue()[:-1])
Example 6
def save_embeddings_to_file(model_manager, encoding_batch_size=20):

    state = model_manager.load_current_state()

    utt_features = state['qdim_encoder']
    dia_features = state['sdim']

    encoder = model_manager.load_currently_selected_model()

    database = get_database(model_manager)

    dset = database[BINARIZED_SET_NAME]

    logging.debug('loading binarized dialogues into memory and storing index')
    binarized = [(d_idx, indices) for d_idx, indices in enumerate(dset)]

    logging.debug(
        '...creating and storing mapping from (dialogue id and utterance id) to (embedding storage location)'
    )

    del_if_exists(database, EMBEDDINGS_COORDINATES_SET_NAME)
    coords_set = database.create_dataset(EMBEDDINGS_COORDINATES_SET_NAME,
                                         (len(binarized), 2),
                                         dtype='i4')

    enlargened_idx = 0

    progress = 0
    start_time = time()

    for d_idx, indices in binarized:
        progress += 1

        conv_length = num_turns(indices, encoder.eos_sym)
        coords = (enlargened_idx, conv_length)

        coords_set[d_idx] = coords

        enlargened_idx += conv_length
        if progress % 100 == 0:
            print_progress_bar(
                progress,
                len(binarized),
                additional_text='coords for %i dialogues created' % progress,
                start_time=start_time)

    num_embeddings = enlargened_idx

    logging.debug('...sorting based on dialogue length for quicker encoding')
    dialogue_ids, binarized = zip(
        *sorted(binarized, key=lambda tuple: len(tuple[1]), reverse=True))

    collect()

    logging.debug('creating files that hold embeddings')
    utt_file = FileArray(model_manager.folders['embeddings'] +
                         'utterance.embeddings.bin',
                         shape=(num_embeddings, utt_features),
                         dtype='f4')

    dia_file = FileArray(model_manager.folders['embeddings'] +
                         'dialogue.embeddings.bin',
                         shape=(num_embeddings, dia_features),
                         dtype='f4')

    utt_file.open()
    dia_file.open()

    progress = 0
    batches_to_process = len(binarized) / encoding_batch_size
    start_time = time()

    total_embeddings_encoded = 0

    for d_indices, batch in zip(
            batch_iterator(dialogue_ids, batch_size=encoding_batch_size),
            batch_iterator(
                binarized,
                batch_size=encoding_batch_size,
                apply_on_element=lambda indices: np.append(
                    [encoder.eos_sym], np.append(indices, encoder.eos_sym)))):
        progress += 1

        embeddings = encode_batch_to_embeddings(encoder, batch)

        coords = [coords_set[d_idx] for d_idx in d_indices]

        for embs, coord in zip(embeddings, coords):
            utt_embs = embs[0]
            dia_embs = embs[1]

            assert len(utt_embs) == coord[1]
            assert len(dia_embs) == coord[1]

            assert len(utt_embs[0]) == utt_features
            assert len(dia_embs[0]) == dia_features

            #for local_idx, global_idx in enumerate(xrange(coord[0], coord[0] + coord[1])):
            #    utt_file.write(global_idx, utt_embs[local_idx])
            #    dia_file.write(global_idx,  dia_embs[local_idx])
            #    total_embeddings_encoded += 1
            utt_file.write_chunk(coord[0], utt_embs)
            dia_file.write_chunk(coord[0], dia_embs)
            total_embeddings_encoded += coord[1]

        if progress % 1000 == 0:
            collect()

        # print '%i of %i batches encoded (%.3f%%)'%(progress, batches_to_process, 100*(float(progress)/float(batches_to_process)))
        print_progress_bar(
            progress,
            batches_to_process,
            additional_text=
            '%i batches of dialogues processed (total of %i dialogues) (total of %i embeddings) (%i conv length)'
            % (progress, ((progress - 1) * encoding_batch_size) + len(batch),
               total_embeddings_encoded, len(batch[0])),
            start_time=start_time)

    utt_file.close()
    dia_file.close()
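The apply_on_element callable above wraps each binarized dialogue with the encoder's end-of-sentence symbol on both sides before it is encoded. A tiny standalone illustration of that transform, with a hypothetical eos_sym value of 1:

import numpy as np

eos_sym = 1  # hypothetical end-of-sentence token id
indices = np.array([5, 7, 9])
wrapped = np.append([eos_sym], np.append(indices, eos_sym))
print(wrapped)  # [1 5 7 9 1]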
Example 7
	def fasta2csv(self, is_local_interpro):
		"""
		Convert fasta file to csv

		Parameters
		----------
		is_local_interpro : bool
			the input fasta file is created by running local Interproscan (True), otherwise (False)

		Returns
		-------
		None
		"""
		print("Creating row for each protein with domain, please wait..")
		dataset_name = "toxin_dataset.csv"
		num_all_proteins = 0
		num_proteins_with_domains = 0
		num_remain_proteins = 0
		csv_already_exists = True
		if not isfile(join(self.output_path, dataset_name)):  # if the csv does not exist yet, write the header first
			csv_already_exists = False
		for fasta_file in listdir(self.fasta_dir_path):
			short_label = splitext(basename(fasta_file))[0].split(".")[0]
			with open(join(self.fasta_dir_path, fasta_file), 'r') as fasta_data, open(self.domains_path,
			                                                                          'r') as domains_data, open(
					join(self.output_path, dataset_name), 'a') as dataset_csv, open(
					join(self.output_path, "targetp_remaining_seq" + "_" + short_label + ".fasta"),
					'a') as remain_seqs_file:
				proteins_dict = SeqIO.to_dict(SeqIO.parse(fasta_data, "fasta"))
				num_all_proteins += len(proteins_dict)
				uniprot2prot = self.extract_uniprot4protein_keys(proteins_dict)
				writer = csv.writer(dataset_csv, delimiter=',')
				if not csv_already_exists:  # the csv does not exist yet, so write the header first
					proteins_domains_header = ["uniprot_id", "toxin", "seq", "seq_len", "interpro_domains",
					                           "evidence_db_domains"]
					writer.writerow(proteins_domains_header)
					csv_already_exists = True
				batch_num_lines = 10000

				for i, batch in enumerate(batch_iterator(domains_data, batch_num_lines)):
					for line in batch:
						line_split = line.strip().split("\t")
						assert len(line_split) == 3, "AssertionError: {} does not have 3 tabs.".format(line)
						uniprot_id = line_split[0]
						if uniprot_id == "uniprot_id":
							print("Skipping first line")
							continue
						if is_local_interpro:
							uniprot_id = uniprot_id.split("|")[1]
						if uniprot_id in uniprot2prot:
							interpro_ids = line_split[1]
							evidence_db_ids = line_split[2]
							label = self.get_labels(fasta_file)
							# make the row of the current protein
							protein_row = [uniprot_id, label, str(uniprot2prot[uniprot_id].seq),
							               len(str(uniprot2prot[uniprot_id].seq)), interpro_ids, evidence_db_ids]
							writer.writerow(protein_row)
							num_proteins_with_domains += 1
							# remove found protein from the dictionary, to keep track of the remaining proteins
							uniprot2prot.pop(uniprot_id)

				num_remain_proteins += len(uniprot2prot)  # update num of remain proteins
				SeqIO.write(uniprot2prot.values(), remain_seqs_file, "fasta")  # append remaining proteins to fasta
				print("num of remaining proteins for {} label: {} saved on remaining fasta".format(
					self.get_labels(fasta_file), len(uniprot2prot)))
		assert num_all_proteins == num_proteins_with_domains + num_remain_proteins, "AssertionError: total num of proteins should be equal to proteins with domains + proteins without domains."
		print("num of Toxin proteins: {}".format(num_all_proteins))
		print("num of Toxin proteins with found domains: {}".format(num_proteins_with_domains))
		print("num of remaining proteins with not found domains: {}".format(num_remain_proteins))
Example 8
    def fasta2csv(self, value2remove):
        """
		Convert fasta file to csv

		Parameters
		----------
		self : object
			DeepLocExperiment object setup for this analysis
		value2remove: str
			if "U" remove proteins with unknown membrane label assignment

		Returns
		-------
		str
			full path of the created csv
		"""
        print("Creating row for each protein with domains, please wait..")
        dataset_name = "deeploc_dataset_" + self.label_name + ".csv"
        with open(self.fasta_path, 'r') as fasta_data, open(
                self.domains_path, 'r') as domains_data, open(
                    os.path.join(self.output_path, dataset_name),
                    'w') as dataset_csv, open(
                        os.path.join(self.output_path,
                                     "deeploc_remaining_seq.fasta"),
                        'w') as remain_seqs_file:
            proteins_dict = SeqIO.to_dict(SeqIO.parse(fasta_data, "fasta"))
            num_all_proteins = len(proteins_dict)
            proteins_domains_header = [
                "uniprot_id", "train_test", "cellular_location",
                "membrane_soluble", "seq", "seq_len", "interpro_domains",
                "evidence_db_domains"
            ]
            writer = csv.writer(dataset_csv, delimiter=',')
            writer.writerow(proteins_domains_header)
            batch_num_lines = 10000
            num_proteins_with_domains = 0
            for i, batch in enumerate(
                    batch_iterator(domains_data, batch_num_lines)):
                for line in batch:
                    line_split = line.strip().split("\t")
                    assert len(
                        line_split
                    ) == 3, "AssertionError: {} does not have 3 tabs.".format(
                        line)
                    uniprot_id = line_split[0]
                    if uniprot_id in proteins_dict:
                        print("Writing row for {}".format(uniprot_id))
                        interpro_ids = line_split[1]
                        evidence_db_ids = line_split[2]
                        labels = self.get_labels(
                            proteins_dict[uniprot_id].description)
                        # make the row of current protein
                        protein_row = [
                            uniprot_id, labels.train, labels.loc, labels.sol,
                            str(proteins_dict[uniprot_id].seq),
                            len(str(proteins_dict[uniprot_id].seq)),
                            interpro_ids, evidence_db_ids
                        ]
                        if value2remove != "":
                            if labels.sol == value2remove:
                                print(
                                    "Skipping protein {} having membrane_soluble as {}"
                                    .format(uniprot_id, labels.sol))
                            else:
                                writer.writerow(protein_row)
                        else:
                            writer.writerow(protein_row)
                        num_proteins_with_domains = num_proteins_with_domains + 1
                        proteins_dict.pop(
                            uniprot_id
                        )  # remove found protein from the dictionary, to keep track of the remaining proteins

            SeqIO.write(proteins_dict.values(), remain_seqs_file, "fasta")
        print("num of DeepLoc proteins: {}".format(num_all_proteins))
        print("num of DeepLoc proteins with found domains: {}".format(
            num_proteins_with_domains))
        print("num of remaining proteins with not found domains: {}".format(
            len(proteins_dict)))
        return os.path.join(self.output_path, dataset_name)
Example 9
    def fasta2csv(self):
        """
		Convert a directory of fasta files to data csv

		Parameters
		----------

		Returns
		-------
		None
		"""
        print("Creating row for each protein with domain, please wait..")
        dataset_name = "targetp_dataset.csv"
        num_all_proteins = 0
        num_proteins_with_domains = 0
        num_remain_proteins = 0
        csv_already_exists = True
        if not isfile(join(
                self.output_path,
                dataset_name)):  # if the csv does not exist yet, write the header first
            csv_already_exists = False
        for fasta_file in listdir(self.fasta_dir_path):
            short_label = splitext(basename(fasta_file))[0].split(".")[0]
            with open(join(self.fasta_dir_path, fasta_file),
                      'r') as fasta_data, open(
                          self.domains_path, 'r') as domains_data, open(
                              join(self.output_path, dataset_name),
                              'a') as dataset_csv, open(
                                  join(
                                      self.output_path, short_label + "." +
                                      "targetp_remaining_seq.fasta"),
                                  'a') as remain_seqs_file:
                proteins_dict = SeqIO.to_dict(SeqIO.parse(fasta_data, "fasta"))
                num_all_proteins += len(proteins_dict)
                writer = csv.writer(dataset_csv, delimiter=',')
                if not csv_already_exists:  # the csv does not exist yet, so write the header first
                    proteins_domains_header = [
                        "uniprot_id", "cellular_location", "seq", "seq_len",
                        "interpro_domains", "evidence_db_domains"
                    ]
                    writer.writerow(proteins_domains_header)
                    csv_already_exists = True

                batch_num_lines = 10000
                for i, batch in enumerate(
                        batch_iterator(domains_data, batch_num_lines)):
                    for line in batch:
                        line_split = line.strip().split("\t")
                        assert len(
                            line_split
                        ) == 3, "AssertionError: {} does not have 3 tabs.".format(
                            line)
                        uniprot_id = line_split[0]
                        if uniprot_id in proteins_dict:
                            interpro_ids = line_split[1]
                            evidence_db_ids = line_split[2]
                            label = self.get_labels(fasta_file)
                            # make the row of the current protein
                            protein_row = [
                                uniprot_id, label,
                                str(proteins_dict[uniprot_id].seq),
                                len(str(proteins_dict[uniprot_id].seq)),
                                interpro_ids, evidence_db_ids
                            ]
                            writer.writerow(protein_row)
                            num_proteins_with_domains += 1
                            # remove found protein from the dictionary, to keep track of the remaining proteins
                            proteins_dict.pop(uniprot_id)

                num_remain_proteins += len(
                    proteins_dict)  # update num of remain proteins
                SeqIO.write(proteins_dict.values(), remain_seqs_file,
                            "fasta")  # append remaining proteins to fasta
                print(
                    "num of remaining proteins for {} label: {} saved on remaining fasta"
                    .format(self.get_labels(fasta_file), len(proteins_dict)))

        ### processed proteins stats ###
        assert num_all_proteins == num_proteins_with_domains + num_remain_proteins, "AssertionError: total num of proteins should be equal to proteins with domains + proteins without domains."
        print("num of TargetP proteins: {}".format(num_all_proteins))
        print("num of TargetP proteins with found domains: {}".format(
            num_proteins_with_domains))
        print("num of remaining proteins with not found domains: {}".format(
            num_remain_proteins))
Example 10
def main(unused_args):
    assert len(unused_args) == 1, unused_args
    setup_experiment()

    mnist_ds = mnist.read_data_sets(FLAGS.data_dir,
                                    dtype=tf.float32,
                                    reshape=False,
                                    validation_size=FLAGS.validation_size)
    test_ds = getattr(mnist_ds, FLAGS.dataset)

    test_images, test_labels = test_ds.images, test_ds.labels
    if FLAGS.sort_labels:
        ys_indices = np.argsort(test_labels)
        test_images = test_images[ys_indices]
        test_labels = test_labels[ys_indices]

    img_shape = [None, 1, 28, 28]
    X = tf.placeholder(tf.float32, shape=img_shape, name='X')
    y = tf.placeholder(tf.int32, shape=[None])
    y_onehot = tf.one_hot(y, FLAGS.num_classes)

    # model
    model = create_model(FLAGS, name=FLAGS.model_name)

    def test_model(x, **kwargs):
        return model(x, train=False, **kwargs)

    out_x = test_model(X)
    attack_clip = FLAGS.attack_clip if FLAGS.attack_clip > 0 else None
    if FLAGS.attack_box_clip:
        boxmin, boxmax = 0.0, 1.0
    else:
        boxmin, boxmax = None, None
    X_df = deepfool(lambda x: test_model(x)['logits'],
                    X,
                    labels=y,
                    max_iter=FLAGS.attack_iter,
                    clip_dist=attack_clip,
                    over_shoot=FLAGS.attack_overshoot,
                    boxmin=boxmin,
                    boxmax=boxmax)
    X_df_all = deepfool(lambda x: test_model(x)['logits'],
                        X,
                        max_iter=FLAGS.attack_iter,
                        clip_dist=attack_clip,
                        over_shoot=FLAGS.attack_overshoot,
                        boxmin=boxmin,
                        boxmax=boxmax)
    if FLAGS.hc_confidence == 'same':
        confidence = out_x['conf']
    else:
        confidence = float(FLAGS.hc_confidence)
    X_hc = high_confidence_attack(lambda x: test_model(x)['logits'],
                                  X,
                                  labels=y,
                                  random=FLAGS.hc_random,
                                  max_iter=FLAGS.attack_iter,
                                  clip_dist=attack_clip,
                                  confidence=confidence,
                                  boxmin=boxmin,
                                  boxmax=boxmax)
    X_hcd = tf.stop_gradient(X_hc)
    X_rec = high_confidence_attack(lambda x: model(x)['logits'],
                                   X_hcd,
                                   targets=out_x['pred'],
                                   attack_topk=None,
                                   max_iter=FLAGS.attack_iter,
                                   clip_dist=attack_clip,
                                   confidence=out_x['conf'],
                                   boxmin=boxmin,
                                   boxmax=boxmax)

    out_x_df = test_model(X_df)
    out_x_hc = test_model(X_hc)

    reduce_ind = (1, 2, 3)
    X_norm = tf.sqrt(tf.reduce_sum(X**2, axis=reduce_ind))
    l2_df = tf.sqrt(tf.reduce_sum((X_df - X)**2, axis=reduce_ind))
    l2_df_norm = l2_df / X_norm
    smoothness_df = tf.reduce_mean(tf.image.total_variation(X_df))
    l2_df_all = tf.sqrt(tf.reduce_sum((X_df_all - X)**2, axis=reduce_ind))
    l2_df_all_norm = l2_df_all / X_norm
    l2_hc = tf.sqrt(tf.reduce_sum((X_hc - X)**2, axis=reduce_ind))
    l2_hc_norm = l2_hc / X_norm
    smoothness_hc = tf.reduce_mean(tf.image.total_variation(X_hc))
    l1_rec = tf.reduce_sum(tf.abs(X - X_rec), axis=reduce_ind)
    l2_rec = tf.sqrt(tf.reduce_sum((X - X_rec)**2, axis=reduce_ind))
    # image noise statistics
    psnr = tf.py_func(batch_compute_psnr, [X, X_df], tf.float32)
    ssim = tf.py_func(batch_compute_ssim, [X, X_df], tf.float32)

    nll = tf.reduce_mean(
        tf.losses.softmax_cross_entropy(y_onehot, out_x['logits']))
    err = 1 - slim.metrics.accuracy(out_x['pred'], y)
    conf = tf.reduce_mean(out_x['conf'])
    err_df = 1 - slim.metrics.accuracy(out_x_df['pred'], y)
    conf_df = tf.reduce_mean(out_x_df['conf'])
    err_hc = 1 - slim.metrics.accuracy(out_x_hc['pred'], y)
    conf_hc = tf.reduce_mean(out_x_hc['conf'])

    metrics = OrderedDict([('nll', nll), ('err', err), ('conf', conf),
                           ('err_df', err_df), ('err_hc', err_hc),
                           ('l2_df', tf.reduce_mean(l2_df)),
                           ('l2_df_norm', tf.reduce_mean(l2_df_norm)),
                           ('l2_df_all', tf.reduce_mean(l2_df_all)),
                           ('l2_df_all_norm', tf.reduce_mean(l2_df_all_norm)),
                           ('conf_df', conf_df),
                           ('smoothness_df', smoothness_df),
                           ('l2_hc', tf.reduce_mean(l2_hc)),
                           ('l2_hc_norm', tf.reduce_mean(l2_hc_norm)),
                           ('conf_hc', conf_hc),
                           ('smoothness_hc', smoothness_hc),
                           ('l1_rec', tf.reduce_mean(l1_rec)),
                           ('l2_rec', tf.reduce_mean(l2_rec)),
                           ('psnr', tf.reduce_mean(psnr)),
                           ('ssim', tf.reduce_mean(ssim))])
    metrics_mean, metrics_upd = register_metrics(metrics)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        tf.local_variables_initializer().run()
        model_loader = tf.train.Saver(tf.model_variables())
        model_filename = ('model' if FLAGS.restore_epoch_index is None else
                          'model-%d' % FLAGS.restore_epoch_index)
        model_path = os.path.join(FLAGS.load_dir, 'chks', model_filename)
        model_loader.restore(sess, model_path)

        summary_writer = tf.summary.FileWriter(FLAGS.working_dir, sess.graph)
        summaries = tf.summary.merge_all()

        test_iterator = batch_iterator(test_images,
                                       test_labels,
                                       FLAGS.batch_size,
                                       shuffle=False)
        start_time = time.time()
        for batch_index, (images, labels) in enumerate(test_iterator, 1):
            if batch_index % FLAGS.summary_frequency == 0:
                hc_images, df_images, rec_images, summary = sess.run(
                    [X_hc, X_df, X_rec, summaries, metrics_upd],
                    feed_dict={
                        X: images,
                        y: labels
                    })[:-1]
                save_path = os.path.join(FLAGS.samples_dir,
                                         'epoch_orig-%d.png' % batch_index)
                save_images(images, save_path)
                save_path = os.path.join(FLAGS.samples_dir,
                                         'epoch_hc-%d.png' % batch_index)
                save_images(hc_images, save_path)
                save_path = os.path.join(FLAGS.samples_dir,
                                         'epoch_df-%d.png' % batch_index)
                save_images(df_images, save_path)
                save_path = os.path.join(FLAGS.samples_dir,
                                         'epoch_rec-%d.png' % batch_index)
                save_images(rec_images, save_path)
            else:
                summary = sess.run([metrics_upd, summaries],
                                   feed_dict={
                                       X: images,
                                       y: labels
                                   })[-1]
            summary_writer.add_summary(summary, batch_index)
        str_bfr = six.StringIO()
        str_bfr.write("Test results [{:.2f}s]:".format(time.time() -
                                                       start_time))
        print_results_str(str_bfr,
                          metrics.keys(),
                          sess.run(metrics_mean),
                          throw_on_nan=False)
        logging.info(str_bfr.getvalue()[:-1])
Example 11
    def parse_prot2in(self, file_in_name, batch_num_lines, batch_num_prot):
        """
		Parse protein domain hits to create tabular formatted file relating each protein to its domains

		Parameters
		----------
		file_in_name : str
			input file name
		batch_num_lines : int
			number of lines to be parsed per batch
		batch_num_prot : int
			number of proteins to be processed per batch

		Returns
		-------
		None
		"""
        file_out_name = self.create_file_out_name()
        total_out_prot = 0
        if self.prot_len_file_name != "":
            prot_file = open(
                os.path.join(self.data_path, self.prot_len_file_name), 'r')
        else:
            prot_file = ""

        # check if output tabular file already exists, if yes then don't add header
        output_exists_already = False
        if os.path.isfile(os.path.join(self.data_path, file_out_name)):
            output_exists_already = True

        with gzip.open(os.path.join(self.data_path, file_in_name),
                       'rt') as file_in, open(
                           os.path.join(self.data_path, file_out_name),
                           'a') as file_out:
            if not output_exists_already:
                # write the header of the output file
                file_out.write("uniprot_id\tinterpro_ids\tevidence_db_ids\n")
            line_count = 0
            for i, batch in enumerate(batch_iterator(file_in,
                                                     batch_num_lines)):
                for hit_line in batch:
                    hit_line = hit_line.strip()
                    hit_tabs = hit_line.split("\t")
                    if self.interpro_local_format:
                        assert len(
                            hit_tabs
                        ) >= 11, "AssertionError: " + hit_line + " has fewer than 11 tab-separated fields."
                    else:
                        assert len(
                            hit_tabs
                        ) == 6, "AssertionError: " + hit_line + " does not have exactly 6 tab-separated fields."
                    if self.last_protein.uniprot_id == "":
                        # initialize protein list
                        protein = Protein(self.with_overlap,
                                          self.with_redundant, self.with_gap,
                                          hit_line, prot_file,
                                          self.interpro_local_format)
                        self.last_protein = protein
                        self.proteins.append(protein)
                    else:
                        if Protein.get_prot_id(
                                hit_line) == self.last_protein.uniprot_id:
                            # update last created protein
                            self.last_protein.add_domain(hit_line)
                        else:
                            # write to file complete proteins
                            if len(self.proteins) == batch_num_prot:
                                self.update_output(file_out)
                                total_out_prot = total_out_prot + len(
                                    self.proteins)
                                self.update_no_intepro()
                                del self.proteins[:]
                            # create new protein and append it to proteins
                            protein = Protein(self.with_overlap,
                                              self.with_redundant,
                                              self.with_gap, hit_line,
                                              prot_file,
                                              self.interpro_local_format)
                            self.last_protein = protein
                            self.proteins.append(protein)
                    line_count = line_count + 1
                # save last proteins
                self.update_output(file_out)
                total_out_prot = total_out_prot + len(self.proteins)
                self.update_no_intepro()
                del self.proteins[:]
        if self.prot_len_file_name != "":
            prot_file.close()
        print("Successfully parsed {} lines.".format(line_count))
        print("Successfully created {} proteins.".format(total_out_prot))
        print("Number of proteins without any interpro annotation: {}.".format(
            self.num_prot_with_no_interpro))
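The snippet above relies on a `batch_iterator` helper that groups the lines of the (gzipped) hits file into chunks of `batch_num_lines`. A minimal sketch of such a helper is shown below; the actual implementation in the source repository may differ.

# Minimal sketch of the assumed line-batching helper (illustrative, not taken
# verbatim from the source): lazily groups any iterator of lines into lists of
# at most batch_num_lines items.
def batch_iterator(line_iterator, batch_num_lines):
    batch = []
    for line in line_iterator:
        batch.append(line)
        if len(batch) == batch_num_lines:
            yield batch
            batch = []
    if batch:
        # emit the final, possibly smaller, batch
        yield batch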
Example No. 12
0
    def fasta2csv(self, fasta_name):
        """
		Convert a FASTA file to a dataset CSV

		Parameters
		----------
		fasta_name : str
			fasta file name

		Returns
		-------
		None
		"""
        print("Creating row for each protein with domain, please wait..")
        dataset_name = "new_dataset.csv"
        num_all_proteins = 0
        num_proteins_with_domains = 0
        num_remain_proteins = 0
        # if the csv does not exist yet, the header is written below
        csv_already_exists = isfile(join(self.output_path, dataset_name))
        with open(join(self.input_path, fasta_name), 'r') as fasta_data, open(self.domains_path, 'r') as domains_data, \
          open(join(self.output_path, dataset_name), 'a') as dataset_csv, \
          open(join(self.output_path, "new_remaining_seq.fasta"), 'w') as remaining_seq_file:
            proteins_dict = SeqIO.to_dict(SeqIO.parse(fasta_data, "fasta"))
            num_all_proteins = len(proteins_dict)
            writer = csv.writer(dataset_csv, delimiter=',')
            proteins_domains_header = [
                "id", "ec", "seq", "seq_len", "interpro_domains",
                "evidence_db_domains"
            ]
            if not csv_already_exists:
                writer.writerow(proteins_domains_header)
                csv_already_exists = True
            batch_num_lines = 10000

            for i, batch in enumerate(
                    batch_iterator(domains_data, batch_num_lines)):
                for line in batch:
                    line_split = line.strip().split("\t")
                    assert len(line_split) == 3, \
                        "AssertionError: {} does not have 3 columns.".format(line)
                    prot_id = line_split[0]
                    if prot_id == "uniprot_id":
                        print("Skipping first line")
                        continue
                    else:
                        if prot_id in proteins_dict:
                            # print("Writing row for prot id {}".format(prot_id))
                            interpro_ids = line_split[1]
                            evidence_db_ids = line_split[2]
                            label = self.get_label(
                                proteins_dict[prot_id].description)
                            # make the row of current protein
                            protein_row = [
                                prot_id, label,
                                str(proteins_dict[prot_id].seq),
                                len(str(proteins_dict[prot_id].seq)),
                                interpro_ids, evidence_db_ids
                            ]
                            writer.writerow(protein_row)
                            num_proteins_with_domains += 1
                            proteins_dict.pop(
                                prot_id
                            )  # remove found protein from whole proteins dictionary
            num_remain_proteins = len(proteins_dict)
            assert num_all_proteins == num_proteins_with_domains + num_remain_proteins, "AssertionError: total num of proteins should be equal to proteins with domains + proteins without domains."
            SeqIO.write(proteins_dict.values(), remaining_seq_file, "fasta")
            print("num of NEW proteins: {}".format(num_all_proteins))
            print("num of NEW proteins with found domains: {}".format(
                num_proteins_with_domains))
            print("num of remaining proteins with no domains found: {}".format(
                num_remain_proteins))
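For context, a hypothetical call site for this method might look like the following; the class name and constructor arguments are illustrative assumptions, only the fasta2csv call itself appears in the snippet.

# Hypothetical usage sketch (class name and constructor arguments are
# assumptions for illustration, not taken from the snippet):
converter = Fasta2CsvConverter(input_path="data/fasta",
                               output_path="data/csv",
                               domains_path="data/csv/id_domains.tab")
converter.fasta2csv("proteins.fasta")
# Appends rows of (id, ec, seq, seq_len, interpro_domains, evidence_db_domains)
# to data/csv/new_dataset.csv and writes sequences without matched domains to
# data/csv/new_remaining_seq.fasta.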
Example No. 13
0
def main(unused_args):
    assert len(unused_args) == 1, unused_args
    setup_experiment()

    mnist_ds = mnist.read_data_sets(
        FLAGS.data_dir, dtype=tf.float32, reshape=False)
    test_ds = getattr(mnist_ds, FLAGS.dataset)

    images = test_ds.images
    labels = test_ds.labels
    if FLAGS.sort_labels:
        ys_indices = np.argsort(labels)
        images = images[ys_indices]
        labels = labels[ys_indices]

    # input shape (NCHW) and number of classes expected by the loaded model
    img_shape = [None, 1, 28, 28]
    num_classes = FLAGS.num_classes

    X = tf.placeholder(tf.float32, shape=img_shape, name='X')
    y = tf.placeholder(tf.int32, shape=[None], name='y')
    y_onehot = tf.one_hot(y, num_classes)

    # model
    model = create_model(FLAGS, name=FLAGS.model_name)

    def test_model(x, **kwargs):
        return model(x, train=False, **kwargs)

    # wrap model for carlini method
    def carlini_predict(x):
        # the Carlini attack works on inputs in [-0.5, 0.5], but the network
        # was trained on [0, 1] inputs, so shift the range back to [0, 1]
        x = (2 * x + 1) / 2
        # convert NHWC to the NCHW layout the model expects
        x = tf.transpose(x, [0, 3, 1, 2])
        return test_model(x)['logits']
    carlini_model = AttributeDict({'num_channels': 1,
                                   'image_size': 28,
                                   'num_labels': num_classes,
                                   'predict': carlini_predict})

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.25)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # carlini l2 attack
        carlini_l2 = CarliniL2(sess, carlini_model,
                               batch_size=FLAGS.carlini_batch_size,
                               max_iterations=FLAGS.carlini_max_iter,
                               confidence=FLAGS.carlini_confidence,
                               binary_search_steps=FLAGS.carlini_binary_steps,
                               targeted=False)

        def generate_carlini_l2(images, onehot_labels):
            return from_carlini_images(
                carlini_l2.attack(
                    to_carlini_images(images), onehot_labels))
        X_ca_l2 = tf.py_func(generate_carlini_l2, [X, y_onehot], tf.float32)
        X_ca_l2 = tf.reshape(X_ca_l2, tf.shape(X))
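        # Note: to_carlini_images / from_carlini_images are assumed helpers
        # that convert between the [0, 1] tensors used here and the
        # [-0.5, 0.5] arrays the CarliniL2 implementation expects; wrapping
        # the numpy-based attack in tf.py_func lets it run as an op inside
        # the TensorFlow graph.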

        filter_index_l2 = tf.py_func(non_converged_indices, [X_ca_l2], tf.int32)
        filter_index_l2.set_shape([FLAGS.batch_size])
        X_f_l2 = tf.gather(X, filter_index_l2)
        X_ca_f_l2 = tf.gather(X_ca_l2, filter_index_l2)

        # outputs
        outs_x = test_model(X)
        outs_x_ca_l2 = test_model(X_ca_l2)

        # l2 carlini results
        l2_ca = tf.sqrt(tf.reduce_sum((X_ca_l2 - X)**2, axis=(1, 2, 3)))
        l2_ca_norm = l2_ca / tf.sqrt(tf.reduce_sum(X**2, axis=(1, 2, 3)))
        conf_ca = tf.reduce_mean(tf.reduce_max(outs_x_ca_l2['prob'], axis=1))
        l2_ca_f = tf.sqrt(tf.reduce_sum((X_ca_f_l2 - X_f_l2)**2, axis=(1, 2, 3)))
        l2_ca_f_norm = l2_ca_f / tf.sqrt(tf.reduce_sum(X_f_l2**2, axis=(1, 2, 3)))
        smoothness_ca_f = tf.reduce_mean(tf.image.total_variation(X_ca_f_l2))

        nll = tf.reduce_mean(tf.losses.softmax_cross_entropy(y_onehot, outs_x['logits']))
        err = 1 - slim.metrics.accuracy(outs_x['pred'], y)
        err_ca_l2 = 1 - slim.metrics.accuracy(outs_x_ca_l2['pred'], y)
        total_processed_l2 = tf.shape(X_f_l2)[0]

        metrics = OrderedDict([('nll', nll),
                               ('err', err),
                               ('err_ca_l2', err_ca_l2),
                               ('l2_ca', tf.reduce_mean(l2_ca)),
                               ('l2_ca_norm', tf.reduce_mean(l2_ca_norm)),
                               ('conf_ca', conf_ca),
                               ('l2_ca_f', tf.reduce_mean(l2_ca_f)),
                               ('l2_ca_f_norm', tf.reduce_mean(l2_ca_f_norm)),
                               ('smoothness_ca', smoothness_ca_f),
                               ('total_processed_l2', total_processed_l2)])
        metrics_mean, metrics_upd = register_metrics(metrics)
        tf.summary.histogram('y_data', y)
        tf.summary.histogram('y_hat', outs_x['pred'])
        tf.summary.histogram('y_adv', outs_x_ca_l2['pred'])

        # start
        tf.local_variables_initializer().run()
        model_loader = tf.train.Saver(tf.model_variables())
        model_filename = ('model' if FLAGS.restore_epoch_index is None else
                          'model-%d' % FLAGS.restore_epoch_index)
        model_path = os.path.join(FLAGS.load_dir, 'chks', model_filename)
        model_loader.restore(sess, model_path)

        summary_writer = tf.summary.FileWriter(FLAGS.working_dir, sess.graph)
        summaries = tf.summary.merge_all()

        if FLAGS.generate_summary:
            logging.info("Generating samples...")
            summary_images, summary_labels = select_balanced_subset(
                images, labels, num_classes, num_classes)
            summary_images = summary_images.transpose((0, 3, 1, 2))
            err_l2, summary_ca_l2_imgs = (
                sess.run([err_ca_l2, X_ca_l2],
                         {X: summary_images, y: summary_labels}))
            if not np.allclose(err_l2, 1):
                logging.warn("Generated samples are not all mistakes: %f", err_l2)
            save_path = os.path.join(FLAGS.samples_dir, 'orig.png')
            save_images(summary_images, save_path)
            save_path = os.path.join(FLAGS.samples_dir, 'carlini_l2.png')
            save_images(summary_ca_l2_imgs, save_path)
        else:
            logging.debug("Skipping summary...")

        logging.info("Starting...")
        # The Carlini attack is slow, so evaluate on a random subset of examples
        if FLAGS.num_examples > 0 and FLAGS.num_examples < images.shape[0]:
            indices = np.arange(images.shape[0])
            np.random.shuffle(indices)
            images = images[indices[:FLAGS.num_examples]]
            labels = labels[indices[:FLAGS.num_examples]]

        X_hat_np = []
        test_iterator = batch_iterator(images, labels, FLAGS.batch_size, shuffle=False)
        start_time = time.time()
        for batch_index, (batch_images, batch_labels) in enumerate(test_iterator, 1):
            # run the attack and metric updates; keep only the adversarial
            # images and the merged summaries from the fetched values
            ca_l2_imgs, summary = sess.run(
                [X_ca_l2, summaries, metrics_upd],
                {X: batch_images, y: batch_labels})[:2]
            X_hat_np.extend(ca_l2_imgs)
            summary_writer.add_summary(summary, batch_index)

            save_path = os.path.join(FLAGS.samples_dir, 'b%d-ca_l2.png' % batch_index)
            save_images(ca_l2_imgs, save_path)
            save_path = os.path.join(FLAGS.samples_dir, 'b%d-orig.png' % batch_index)
            save_images(batch_images, save_path)

            if batch_index % FLAGS.print_frequency == 0:
                str_bfr = six.StringIO()
                str_bfr.write("Batch {} [{:.2f}s]:".format(batch_index, time.time() - start_time))
                print_results_str(str_bfr, metrics.keys(), sess.run(metrics_mean))
                logging.info(str_bfr.getvalue()[:-1])

        X_hat_np = np.asarray(X_hat_np)
        save_path = os.path.join(FLAGS.adv_data_dir, 'mnist_%s.npz' % FLAGS.dataset)
        np.savez(save_path, X_hat_np)
        logging.info("Saved adv_data to %s", save_path)
        str_bfr = six.StringIO()
        str_bfr.write("Test results [{:.2f}s]:".format(time.time() - start_time))
        print_results_str(str_bfr, metrics.keys(), sess.run(metrics_mean))
        logging.info(str_bfr.getvalue()[:-1])
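This example assumes yet another `batch_iterator` variant that walks over paired arrays of images and labels. A minimal sketch of what such a helper could look like is given below; the signature and shuffle behaviour are assumptions, not taken from the original source.

# Minimal sketch of the assumed (X, y) batch iterator (illustrative only):
import numpy as np

def batch_iterator(X, y, batch_size, shuffle=True):
    # yield (X_batch, y_batch) pairs, optionally in a random order
    indices = np.arange(X.shape[0])
    if shuffle:
        np.random.shuffle(indices)
    for start in range(0, X.shape[0], batch_size):
        batch_idx = indices[start:start + batch_size]
        yield X[batch_idx], y[batch_idx]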