Example #1
def create_top_correlations_html(data, fitter, fits, scores, regions, n_top=None):
    if n_top is None:
        n_top = len(scores)
        
    basedir = join(results_dir(), fit_results_relative_path(data,fitter))
    ensure_dir(basedir)
    gene_dir = 'gene-subplot'
    series_dir = 'gene-region-fits'

    def key_func(score):
        g,r,pval,lst_R2 = score
        return r
    scores.sort(key=key_func)
    top_genes = [g for g,r,pval,lst_R2 in scores[:n_top]]
    top_scores = {g:r for g,r,pval,lst_R2 in scores[:n_top]}
    top_pvals = {g:pval for g,r,pval,lst_R2 in scores[:n_top]}
    
    def get_onset_time(fit):
        a,h,mu,_ = fit.theta
        age = age_scaler.unscale(mu)
        txt = 'onset = {:.3g} years'.format(age)
        cls = ''
        return txt,cls
    
    create_html(
        data, fitter, fits, basedir, gene_dir, series_dir,
        gene_names = top_genes, 
        region_names = regions,
        extra_columns = [('r',top_scores),('p-value',top_pvals)],
        extra_fields_per_fit = [get_onset_time],
        b_inline_images = True,
        b_R2_dist = False, 
        ttl = 'Fit for genes with top Spearman correlations',
        filename = 'top-gradual-maturation',
    )
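Note: every example on this page calls an ensure_dir helper before writing files. Its definition is not shown anywhere in the listing; below is a minimal sketch, assuming it simply creates the directory (and any missing parents) when it does not already exist:

import os

def ensure_dir(path):
    # Create `path` (including parent directories) if it does not exist.
    if path and not os.path.isdir(path):
        os.makedirs(path)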
Example #2
def flush():
    prints = []

    for name, vals in _since_last_flush.items():
        prints.append("{}\t{}".format(name, np.mean(list(vals.values()))))
        _since_beginning[name].update(vals)
        #print(_since_beginning[name])

        x_vals = np.sort(list(_since_beginning[name].keys()))
        y_vals = [_since_beginning[name][x] for x in x_vals]

        plt.clf()
        plt.plot(x_vals, y_vals)
        plt.xlabel('iteration')
        plt.ylabel(name)
        fpath = os.path.join('debug', name.replace(' ', '_') + '.png')
        base_dir = os.path.dirname(fpath)
        ensure_dir(base_dir)
        plt.savefig(fpath)

    print("iter {}\t{}".format(_iter[0], "\t".join(prints)))
    _since_last_flush.clear()

    with open('debug/log.pkl', 'wb') as f:
        pickle.dump(dict(_since_beginning), f, -1)  # -1 == pickle.HIGHEST_PROTOCOL
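flush() reads module-level accumulators defined elsewhere. A sketch of the assumed state, with hypothetical plot/tick helpers that would feed it (only _since_last_flush, _since_beginning, and _iter appear in flush() itself):

import collections

_since_beginning = collections.defaultdict(dict)   # name -> {iteration: value}
_since_last_flush = collections.defaultdict(dict)  # name -> {iteration: value}
_iter = [0]

def plot(name, value):
    # Record a scalar for the current iteration (hypothetical helper).
    _since_last_flush[name][_iter[0]] = value

def tick():
    # Advance the iteration counter (hypothetical helper).
    _iter[0] += 1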
Example #3
def plot_and_save_all_genes(data, fitter, fits, dirname,
                            show_change_distributions):
    ensure_dir(dirname)
    to_plot = []
    genes = set()  # use the genes from the fits and not from 'data' to support sharding (k_of_n)
    for ds_fits in fits.itervalues():
        for g, r in ds_fits.iterkeys():
            genes.add(g)
    for g in sorted(genes):
        genedir = join(dirname, g[:g.index(cfg.exon_separator)])  # does nothing if gene name does not contain the exon
        ensure_dir(genedir)
        filename = join(genedir, '{}.png'.format(g))
        if isfile(filename):
            print 'Figure already exists for gene {}. skipping...'.format(g)
            continue
        region_series_fits = _extract_gene_data(data, g, fits)
        if show_change_distributions:
            bin_centers = fits.change_distribution_params.bin_centers
        else:
            bin_centers = None
        to_plot.append((g, region_series_fits, filename, bin_centers))
    pool = Parallel(_plot_genes_job)
    pool(pool.delay(*args) for args in to_plot)
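Several examples drive plotting through a Parallel(job) object used as pool(pool.delay(*args) for args in to_plot). The class itself is not shown; here is a minimal stand-in built on joblib, assuming the interface is a thin wrapper over delayed calls:

from joblib import Parallel as JoblibParallel, delayed

class Parallel(object):
    # Hypothetical wrapper matching: pool = Parallel(job); pool(pool.delay(*args) ...)
    def __init__(self, job, n_jobs=-1):
        self.job = job
        self._pool = JoblibParallel(n_jobs=n_jobs)

    def delay(self, *args):
        # Wrap one pending call to the job function.
        return delayed(self.job)(*args)

    def __call__(self, tasks):
        # Execute all pending calls, possibly in parallel.
        return self._pool(tasks)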
Example #4
File: gan.py Project: zmonoid/defensegan
    def generate_image(self, training_iter):
        samples = self.sess.run(self.fixed_noise_samples)
        debug_dir = self.checkpoint_dir.replace('output', 'debug')
        ensure_dir(debug_dir)
        tflib.save_images.save_images(
            (samples.reshape((len(samples), 64, 64, 3)) + 1) / 2.0,
            os.path.join(debug_dir, 'samples_{}.png'.format(training_iter)))
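The (samples + 1) / 2.0 step above assumes generator output in the tanh range [-1, 1] and maps it to the [0, 1] range expected when saving images:

import numpy as np

x = np.array([-1.0, 0.0, 1.0])  # tanh-range pixel values
print((x + 1) / 2.0)            # [0.  0.5 1. ] -- the [0, 1] range used for saving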
Example #5
def save_file(filename, lines):
    lines.append('')
    txt = '\n'.join(lines)
    ensure_dir(dirname(filename))
    print 'Saving to {}'.format(filename)
    with open(filename, 'w') as f:
        f.write(txt)
Example #6
def sendToDetection(frame):
    '''
        sendToDetection(frame)
        ---
        Run face detection on the given frame and return the detected
        face IDs and rectangles.
    '''
    ensure_dir(pic_dir) # check if dir exists
    tmp_pic = pic_dir + 'pic_0.jpg'

    detected = []
    detected_faces = []
    
    cv2.imwrite(tmp_pic, frame) # Saving frame for a while

    try:
        # time.sleep(wait_time)
        detected = CF.face.detect(tmp_pic)
    except Exception as e:
        print(colored(e, color='grey'))
    print('#######################################################')
    print(colored(' Detection...' + spc(31) + '[   ' , color='white') + colored('OK', color='green') + colored('   ]' , color='white'))
        
    for detected_face in detected:
        detected_faces.append({'faceId' : detected_face['faceId'], 'faceRectangle': detected_face['faceRectangle']}) # Remembering face IDs and rectangles 
        print(colored(' Face detected:   ', color='white') + colored(detected_face['faceId'], color='yellow'))
        print('#######################################################') 
    return detected_faces
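The loop above consumes only two keys per detection. An illustrative shape of what CF.face.detect is assumed to return (values hypothetical):

detected = [{
    'faceId': 'c5c24a82-6845-4031-9d5d-978df9175426',  # hypothetical ID
    'faceRectangle': {'top': 54, 'left': 90, 'width': 110, 'height': 110},
}]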
Example #7
def _consolidate(dct_res, base_filename, k_of_n, found_keys_not_in_main_file):
    filename = _cache_filename(base_filename, k_of_n)

    # write the updated main file    
    if found_keys_not_in_main_file:
        if cfg.verbosity > 0:
            print 'Writing back consolidated fit file...'
        ensure_dir(dirname(filename))
        with open(filename,'w') as f:
            pickle.dump(dct_res,f)
    
    if cfg.verbosity > 0:
        print 'Deleting any partial fit files...'
    if k_of_n is None:
        # it's the main file - delete all k_of_n files and the batches dir
        batchdir = _batch_dir(base_filename)
        if isdir(batchdir):
            shutil.rmtree(batchdir)
        partial_files = set(glob(filename + '*')) - {filename}
        for filename in partial_files:
            os.remove(filename)
    else:
        # it's a shard - delete just the batches for that file
        base = _batch_base_filename(base_filename,k_of_n)
        batch_filenames = glob(base + '*')
        for filename in batch_filenames:
            os.remove(filename)
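_cache_filename and _batch_base_filename are not shown. The cleanup above only requires that shard files extend the main filename as a prefix; one hypothetical naming scheme consistent with the glob(filename + '*') call:

def _cache_filename(base_filename, k_of_n):
    # Hypothetical: shard files extend the main cache filename as a prefix,
    # so glob(base_filename + '*') finds them all.
    if k_of_n is None:
        return base_filename
    k, n = k_of_n  # assumed (shard index, shard count) pair
    return '{}.{}-of-{}'.format(base_filename, k, n)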
Example #8
def save_file(filename, lines):
    lines.append("")
    txt = "\n".join(lines)
    ensure_dir(dirname(filename))
    print "Saving to {}".format(filename)
    with open(filename, "w") as f:
        f.write(txt)
Example #9
def _consolidate(dct_res, base_filename, k_of_n, found_keys_not_in_main_file):
    filename = _cache_filename(base_filename, k_of_n)

    # write the updated main file
    if found_keys_not_in_main_file:
        if cfg.verbosity > 0:
            print 'Writing back consolidated fit file...'
        ensure_dir(dirname(filename))
        with open(filename, 'w') as f:
            pickle.dump(dct_res, f)

    if cfg.verbosity > 0:
        print 'Deleting any partial fit files...'
    if k_of_n is None:
        # it's the main file - delete all k_of_n files and the batches dir
        batchdir = _batch_dir(base_filename)
        if isdir(batchdir):
            shutil.rmtree(batchdir)
        partial_files = set(glob(filename + '*')) - {filename}
        for filename in partial_files:
            os.remove(filename)
    else:
        # it's a shard - delete just the batches for that file
        base = _batch_base_filename(base_filename, k_of_n)
        batch_filenames = glob(base + '*')
        for filename in batch_filenames:
            os.remove(filename)
Example #10
    def test_batch(self):
        """Tests the image batch generator."""
        output_dir = os.path.join(self.debug_dir, 'test_batch')
        ensure_dir(output_dir)

        img, target = self.train_data_gen().next()
        img = img.reshape([self.batch_size] + self.image_dim)
        save_images_files(img / 255.0, output_dir=output_dir, labels=target)
Example #11
def create_top_genes_html(data, fitter, fits, scores, regions, n_top=None, filename_suffix=''):
    if n_top is None:
        n_top = len(scores)
        
    basedir = join(results_dir(), fit_results_relative_path(data,fitter))
    ensure_dir(basedir)
    gene_dir = 'gene-subplot'
    series_dir = 'gene-region-fits'

    def key_func(score):
        g,pval,qval = score
        return pval
    scores.sort(key=key_func)
    top_genes = [g for g,pval,qval in scores[:n_top]]
    top_pvals = {g:pval for g,pval,qval in scores[:n_top]}
    top_qvals = {g:qval for g,pval,qval in scores[:n_top]}
    
    n = len(scores)
    n05 = len([g for g,pval,qval in scores if qval < 0.05])
    n01 = len([g for g,pval,qval in scores if qval < 0.01])
    top_text = """\
<pre>
one-sided t-test: {regions[0]} < {regions[1]}
{n05}/{n} q-values < 0.05
{n01}/{n} q-values < 0.01
</pre>
""".format(**locals())
    
    def get_onset_time(fit):
        a,h,mu,_ = fit.theta
        age = age_scaler.unscale(mu)
        return 'onset = {:.3g} years'.format(age)
        
    def get_onset_dist(fit):
        mu_vals = fit.theta_samples[2,:]
        mu = mu_vals.mean()
        vLow,vHigh = np.percentile(mu_vals, (20,80))
        mu = age_scaler.unscale(mu)
        vLow = age_scaler.unscale(vLow)
        vHigh = age_scaler.unscale(vHigh)
        txt = 'onset reestimate (mean [20%, 80%]) = {:.3g} [{:.3g},{:.3g}]'.format(mu,vLow,vHigh)
        cls = ''
        return txt,cls
    
    create_html(
        data, fitter, fits, basedir, gene_dir, series_dir,
        gene_names = top_genes, 
        region_names = regions,
        extra_columns = [('p-value',top_pvals), ('q-value',top_qvals)],
        extra_fields_per_fit = [get_onset_time, get_onset_dist],
        b_inline_images = True,
        inline_image_size = '30%',
        b_R2_dist = False, 
        ttl = 'Fit for genes with top t-test scores',
        top_text = top_text,
        filename = 'gradual-maturation-t-test' + filename_suffix,
    )
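The score tuples above already carry q-values; their derivation is not shown in the listing. A minimal Benjamini-Hochberg sketch, as one standard way to obtain q-values from per-gene p-values (not necessarily this project's procedure):

import numpy as np

def bh_qvalues(pvals):
    # Benjamini-Hochberg adjusted p-values ("q-values").
    p = np.asarray(pvals, dtype=float)
    n = len(p)
    order = np.argsort(p)
    ranked = p[order] * n / np.arange(1, n + 1)
    # Enforce monotonicity: q_(i) = min over j >= i of p_(j) * n / j
    q = np.minimum.accumulate(ranked[::-1])[::-1]
    out = np.empty(n)
    out[order] = np.clip(q, 0.0, 1.0)
    return out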
Example #12
def do_gene_fits(data, gene, fitter, filename, b_show):
    fig = plot_gene(data,gene)
    if filename is None:
        ensure_dir(results_dir())
        filename = join(results_dir(), 'fit.png')
    print 'Saving figure to {}'.format(filename)
    save_figure(fig, filename)
    if b_show:
        plt.show(block=True)
Example #13
def _save_batch(dct_updates, base_filename, k_of_n, i):
    filename = _batch_base_filename(base_filename, k_of_n) + str(i)
    ensure_dir(dirname(filename))    

    # if there's already a file by that name, merge its contents
    if isfile(filename):
        dct_existing = _read_one_cache_file(filename, st_keys=None, is_batch=True)
        dct_updates.update(dct_existing)

    with open(filename,'w') as f:
        pickle.dump(dct_updates,f)
Example #14
def main(cfg, argv=None):
    FLAGS = tf.app.flags.FLAGS
    GAN = dataset_gan_dict[FLAGS.dataset_name]

    gan = GAN(cfg=cfg, test_mode=True)
    gan.load_generator()
    # Setting test time reconstruction hyper parameters.
    [tr_rr, tr_lr, tr_iters] = [FLAGS.rec_rr, FLAGS.rec_lr, FLAGS.rec_iters]
    if FLAGS.defense_type.lower() != 'none':
        if FLAGS.rec_path and FLAGS.defense_type == 'defense_gan':

            # Extract hyperparameters from reconstruction path.
            if FLAGS.rec_path:
                train_param_re = re.compile('recs_rr(.*)_lr(.*)_iters(.*)')
                [tr_rr, tr_lr, tr_iters] = \
                    train_param_re.findall(FLAGS.rec_path)[0]
                gan.rec_rr = int(tr_rr)
                gan.rec_lr = float(tr_lr)
                gan.rec_iters = int(tr_iters)
        elif FLAGS.defense_type == 'defense_gan':
            assert FLAGS.online_training or not FLAGS.train_on_recs

    if FLAGS.override:
        gan.rec_rr = int(tr_rr)
        gan.rec_lr = float(tr_lr)
        gan.rec_iters = int(tr_iters)

    # Setting the results directory.
    results_dir, result_file_name = _get_results_dir_filename(gan)

    # Result file name. The counter ensures we are not overwriting the
    # results.
    counter = 0
    temp_fp = str(counter) + '_' + result_file_name
    results_dir = os.path.join(results_dir, FLAGS.results_dir)
    temp_final_fp = os.path.join(results_dir, temp_fp)
    while os.path.exists(temp_final_fp):
        counter += 1
        temp_fp = str(counter) + '_' + result_file_name
        temp_final_fp = os.path.join(results_dir, temp_fp)
    result_file_name = temp_fp
    sub_result_path = os.path.join(results_dir, result_file_name)

    accuracies = measure_gan(gan,
                             rec_data_path=FLAGS.rec_path,
                             probe_size=FLAGS.probe_size,
                             calc_real_data_is=FLAGS.calc_real_data_is)

    ensure_dir(results_dir)

    with open(sub_result_path, 'a') as f:
        f.writelines([str(acc) + ' ' for acc in accuracies])
        f.write('\n')
        print('[*] saved accuracy in {}'.format(sub_result_path))
Example #15
def do_one_fit(series, fitter, loo, filename, b_show):
    if fitter is not None:
        theta, sigma, LOO_predictions,_ = fitter.fit(series.ages, series.single_expression, loo=loo)
        fig = plot_one_series(series, fitter.shape, theta, LOO_predictions)
    else:
        fig = plot_one_series(series)
    if filename is None:
        ensure_dir(results_dir())
        filename = join(results_dir(), 'fit.png')
    save_figure(fig, filename, print_filename=True)
    if b_show:
        plt.show(block=True)
Example #16
    def save(self, prefixes=None, global_step=None, checkpoint_dir=None):
        if global_step is None:
            global_step = self.global_step
        if checkpoint_dir is None:
            checkpoint_dir = self._set_checkpoint_dir

        ensure_dir(checkpoint_dir)
        self._initialize_saver(prefixes)
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, self.model_save_name),
                        global_step=global_step)
        print('Saved at iter {} to {}'.format(self.sess.run(global_step),
                                              checkpoint_dir))
Example #17
def _save_batch(dct_updates, base_filename, k_of_n, i):
    filename = _batch_base_filename(base_filename, k_of_n) + str(i)
    ensure_dir(dirname(filename))

    # if there's already a file by that name, merge its contents
    if isfile(filename):
        dct_existing = _read_one_cache_file(filename,
                                            st_keys=None,
                                            is_batch=True)
        dct_updates.update(dct_existing)

    with open(filename, 'w') as f:
        pickle.dump(dct_updates, f)
Example #18
def fit_serveral_genes(series, fitter, loo, filename, b_show):
    if fitter is not None:
        theta, L, LOO_predictions,_ = fitter.fit(series.ages, series.expression, loo=loo)
        print 'L = {}'.format(L)
        fig = plot_series(series, fitter.shape, theta, LOO_predictions)
    else:
        fig = plot_series(series)
    if filename is None:
        ensure_dir(results_dir())
        filename = join(results_dir(), 'fits.png')
    print 'Saving figure to {}'.format(filename)
    save_figure(fig, filename)
    if b_show:
        plt.show(block=True)
Example #19
def save_figure(fig, filename, b_close=False, b_square=True, show_frame=False, under_results=False, print_filename=False):
    if under_results:
        dirname = results_dir()
        filename = join(dirname,filename)
        ensure_dir(os.path.dirname(filename))
    if cfg.verbosity >= 1 or print_filename:
        print 'Saving figure to {}'.format(filename)
    figure_size_x = cfg.default_figure_size_x_square if b_square else cfg.default_figure_size_x
    fig.set_size_inches(figure_size_x, cfg.default_figure_size_y)
    if show_frame:
        facecolor = cfg.default_figure_facecolor
    else:
        facecolor = 'white'
    fig.savefig(filename, facecolor=facecolor, dpi=cfg.default_figure_dpi)
    if b_close:
        plt.close(fig)
Example #20
def fit_serveral_genes(series, fitter, loo, filename, b_show):
    if fitter is not None:
        theta, L, LOO_predictions, _ = fitter.fit(series.ages,
                                                  series.expression,
                                                  loo=loo)
        print 'L = {}'.format(L)
        fig = plot_series(series, fitter.shape, theta, LOO_predictions)
    else:
        fig = plot_series(series)
    if filename is None:
        ensure_dir(results_dir())
        filename = join(results_dir(), 'fits.png')
    print 'Saving figure to {}'.format(filename)
    save_figure(fig, filename)
    if b_show:
        plt.show(block=True)
Example #21
def create_top_correlations_html(data,
                                 fitter,
                                 fits,
                                 scores,
                                 regions,
                                 n_top=None):
    if n_top is None:
        n_top = len(scores)

    basedir = join(results_dir(), fit_results_relative_path(data, fitter))
    ensure_dir(basedir)
    gene_dir = 'gene-subplot'
    series_dir = 'gene-region-fits'

    def key_func(score):
        g, r, pval, lst_R2 = score
        return r

    scores.sort(key=key_func)
    top_genes = [g for g, r, pval, lst_R2 in scores[:n_top]]
    top_scores = {g: r for g, r, pval, lst_R2 in scores[:n_top]}
    top_pvals = {g: pval for g, r, pval, lst_R2 in scores[:n_top]}

    def get_onset_time(fit):
        a, h, mu, _ = fit.theta
        age = age_scaler.unscale(mu)
        txt = 'onset = {:.3g} years'.format(age)
        cls = ''
        return txt, cls

    create_html(
        data,
        fitter,
        fits,
        basedir,
        gene_dir,
        series_dir,
        gene_names=top_genes,
        region_names=regions,
        extra_columns=[('r', top_scores), ('p-value', top_pvals)],
        extra_fields_per_fit=[get_onset_time],
        b_inline_images=True,
        b_R2_dist=False,
        ttl='Fit for genes with top Spearman correlations',
        filename='top-gradual-maturation',
    )
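get_onset_time converts the sigmoid's mu parameter back to years via age_scaler.unscale. The scaler is not shown; a sketch of the assumed scale/unscale interface, using a shifted log transform purely as an illustration:

import numpy as np

class LogShiftScaler(object):
    # Hypothetical scaler exposing the unscale() interface used by get_onset_time.
    def __init__(self, shift=1.0):
        self.shift = shift

    def scale(self, age):
        return np.log(age + self.shift)   # years -> scaled units

    def unscale(self, x):
        return np.exp(x) - self.shift     # scaled units -> years

age_scaler = LogShiftScaler()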
Example #22
def plot_and_save_all_exons_from_series(fits, exons_dir, series_dir):
    to_plot = []
    ensure_dir(exons_dir)
    keys = set()
    for ds_fits in fits.itervalues():
        for g, r in ds_fits.iterkeys():
            keys.add((g[:g.index(cfg.exon_separator)], r))
    for g, r in keys:
        gene_dir = join(exons_dir, g)
        ensure_dir(gene_dir)
        filename = join(gene_dir, '{}-{}.png'.format(g, r))
        if isfile(filename):
            print 'Figure already exists for gene {} in region {}. skipping...'.format(
                g, r)
            continue
        to_plot.append((g, r, filename, series_dir))
    pool = Parallel(_plot_exons_from_series_job)
    pool(pool.delay(*args) for args in to_plot)
Example #23
def build_shards(src_dir,
                 save_dir,
                 src_file,
                 tgt_file,
                 vocab,
                 shard_size,
                 feat_ext,
                 mode='train',
                 feats=None):
    src_shards = split_corpus(src_file, shard_size)
    tgt_shards = split_corpus(tgt_file, shard_size)

    ensure_dir(save_dir)

    shard_index = 0
    for src_shard, tgt_shard in zip(src_shards, tgt_shards):
        logger.info('Building %s shard %d' % (mode, shard_index))
        audio_paths = [os.path.join(src_dir, p.strip()) for p in src_shard]
        assert all([os.path.exists(p) for p in audio_paths]), \
            "following audio files not found: %s" % \
            ' '.join([p.strip() for p in audio_paths if not os.path.exists(p)])
        targets = [t.strip() for t in tgt_shard]

        src_tgt_pairs = list(
            zip(audio_paths, targets, cycle([feat_ext]), cycle([vocab])))

        with Pool(50) as p:
            result = list(
                tqdm(p.imap(_worker, src_tgt_pairs), total=len(src_tgt_pairs)))
            result = [r for r in result if r is not None]
            audio_feats, transcriptions, indices = zip(*result)

        shard = {
            'src': np.asarray(audio_feats),
            'tgt': np.asarray(transcriptions),
            'indices':
            np.asarray([np.asarray(x).reshape(-1, 1) for x in indices]),
            'feats': feats
        }

        shard_path = os.path.join(save_dir, '%s.%05d.pt' % (mode, shard_index))
        logger.info('Saving shard %d to %s' % (shard_index, shard_path))
        torch.save(shard, shard_path)
        shard_index += 1
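A shard written by build_shards can be read back with torch.load; the keys below mirror the dict constructed above (the path is illustrative):

import torch

shard = torch.load('save_dir/train.00000.pt')  # illustrative path
audio_feats = shard['src']        # per-utterance audio feature arrays
transcriptions = shard['tgt']     # target transcriptions
indices = shard['indices']        # per-utterance index arrays reshaped to (-1, 1)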
Example #24
    def save_ds(self):
        """Reconstructs the images of the config's dataset with the
        generator."""
        if self.dataset_name == 'cifar':
            splits = ['train', 'dev']
        else:
            splits = ['train', 'dev', 'test']
        for split in splits:
            output_dir = os.path.join('data', 'cache',
                                      '{}_pkl'.format(self.dataset_name),
                                      split)
            if self.debug:
                output_dir += '_debug'

            ensure_dir(output_dir)
            orig_imgs_pkl_path = os.path.join(output_dir,
                                              'feats.pkl'.format(split))

            if os.path.exists(orig_imgs_pkl_path) and not self.test_again:
                with open(orig_imgs_pkl_path) as f:
                    all_recs = cPickle.load(f)
                    could_load = True
                    print('[#] Dataset is already saved.')
                    return

            gen_func = getattr(self, '{}_gen_test'.format(split))
            all_targets = []
            orig_imgs = []
            ctr = 0
            for images, targets in gen_func():
                ctr += 1
                transformed_images = self.sess.run(self.real_data_test,
                                                   feed_dict={
                                                       self.real_data_test_pl:
                                                           images})
                orig_imgs.append(transformed_images)
                all_targets.append(targets)
            orig_imgs = np.concatenate(orig_imgs).reshape(
                [-1] + self.image_dim)
            all_targets = np.concatenate(all_targets)
            with open(orig_imgs_pkl_path, 'w') as f:
                cPickle.dump(orig_imgs, f, cPickle.HIGHEST_PROTOCOL)
                cPickle.dump(all_targets, f, cPickle.HIGHEST_PROTOCOL)
Example #25
def save_fits_and_create_html(data, fitter, fits=None, basedir=None, 
                              do_genes=True, do_series=True, do_hist=True, do_html=True, only_main_html=False,
                              k_of_n=None, 
                              use_correlations=False, correlations=None,
                              show_change_distributions=False,
                              html_kw=None,
                              figure_kw=None):
    if fits is None:
        fits = get_all_fits(data,fitter,k_of_n)
    if basedir is None:
        basedir = join(results_dir(), fit_results_relative_path(data,fitter))
        if use_correlations:
            basedir = join(basedir,'with-correlations')
    if html_kw is None:
        html_kw = {}
    if figure_kw is None:
        figure_kw = {}
    print 'Writing HTML under {}'.format(basedir)
    ensure_dir(basedir)
    gene_dir = 'gene-subplot'
    series_dir = 'gene-region-fits'
    correlations_dir = 'gene-correlations'
    scores_dir = 'score_distributions'
    if do_genes and not only_main_html: # relies on the sharding of the fits respecting gene boundaries
        plot_and_save_all_genes(data, fitter, fits, join(basedir,gene_dir), show_change_distributions)
    if do_series and not only_main_html:
        plot_and_save_all_series(data, fitter, fits, join(basedir,series_dir), use_correlations, show_change_distributions, figure_kw)
    if do_hist and k_of_n is None and not only_main_html:
        create_score_distribution_html(fits, use_correlations, join(basedir,scores_dir))
    if do_html and k_of_n is None:
        link_to_correlation_plots = use_correlations and correlations is not None
        if link_to_correlation_plots and not only_main_html:
            plot_and_save_all_gene_correlations(data, correlations, join(basedir,correlations_dir))
        dct_pathways = load_17_pathways_breakdown()
        pathway_genes = set.union(*dct_pathways.values())
        data_genes = set(data.gene_names)
        missing = pathway_genes - data_genes
        b_pathways = len(missing) < len(pathway_genes)/2 # simple heuristic to create pathways only if we have most of the genes (currently 61 genes are missing)
        create_html(
            data, fitter, fits, basedir, gene_dir, series_dir, scores_dir, correlations_dir=correlations_dir,
            use_correlations=use_correlations, link_to_correlation_plots=link_to_correlation_plots, 
            b_pathways=b_pathways, **html_kw
        )
Example #26
def plot_and_save_all_genes(data, fitter, fits, dirname, show_change_distributions):
    ensure_dir(dirname)
    to_plot = []
    genes = set() # use the genes from the fits and not from 'data' to support sharding (k_of_n)
    for ds_fits in fits.itervalues():
        for g,r in ds_fits.iterkeys():
            genes.add(g)
    for g in sorted(genes):
        filename = join(dirname, '{}.png'.format(g))
        if isfile(filename):
            print 'Figure already exists for gene {}. skipping...'.format(g)
            continue
        region_series_fits = _extract_gene_data(data,g,fits)
        if show_change_distributions:
            bin_centers = fits.change_distribution_params.bin_centers
        else:
            bin_centers = None
        to_plot.append((g,region_series_fits,filename, bin_centers))
    pool = Parallel(_plot_genes_job)
    pool(pool.delay(*args) for args in to_plot)
Example #27
def get_onset_times(data, fitter, R2_threshold, b_force=False):
    filename = join(cache_dir(),fit_results_relative_path(data,fitter) + '.pkl')
    if isfile(filename):
        print 'Loading onset distribution from {}'.format(filename)
        with open(filename) as f:
            bin_edges, change_vals = pickle.load(f)
    else:
        print 'Computing...'
        fits = get_all_fits(data, fitter)        
        thetas = [fit.theta for fit in iterate_fits(fits, R2_threshold=R2_threshold)]
        stages = [stage.scaled(age_scaler) for stage in dev_stages]
        low = min(stage.from_age for stage in stages)
        high = max(stage.to_age for stage in stages) 
        bin_edges, change_vals = compute_change_distribution(fitter.shape, thetas, low, high, n_bins=50)    

        print 'Saving result to {}'.format(filename)
        ensure_dir(dirname(filename))   
        with open(filename,'w') as f:
            pickle.dump((bin_edges,change_vals),f)
    return bin_edges, change_vals
Example #28
def create_score_distribution_html(fits, use_correlations, dirname):
    ensure_dir(dirname)
    with interactive(False):
        hist_filename = 'R2-hist.png'
        fig = plot_score_distribution(fits,use_correlations)
        save_figure(fig, join(dirname,hist_filename), b_close=True)
        
        if use_correlations:
            scatter_filename = 'comparison-scatter.png'
            fig = plot_score_comparison_scatter_for_correlations(fits)
            save_figure(fig, join(dirname,scatter_filename), b_close=True)
            
            delta_hist_filename = 'R2-delta-hist.png'
            fig = plot_score_improvement_histogram_for_correlations(fits)
            save_figure(fig, join(dirname,delta_hist_filename), b_close=True)
                        
    image_size = "50%"
    from jinja2 import Template
    import shutil
    html = Template("""
<html>
<head>
    <link rel="stylesheet" type="text/css" href="score-distribution.css">
</head>
<body>
<center><H1>R2 Score Distribution</H1></center>

<a href="{{hist_filename}}"> <img src="{{hist_filename}}" height="{{image_size}}"></a>

{% if use_correlations %}
    <a href="{{delta_hist_filename}}"> <img src="{{delta_hist_filename}}" height="{{image_size}}"></a>
    <a href="{{scatter_filename}}"> <img src="{{scatter_filename}}" height="{{image_size}}"></a>
{% endif %}

</body>
</html>    
""").render(**locals())
    filename = join(dirname,'scores.html')
    with open(filename, 'w') as f:
        f.write(html)
    shutil.copy(join(resources_dir(),'score-distribution.css'), dirname)
Example #29
def plot_and_save_all_series(data,
                             fitter,
                             fits,
                             dirname,
                             use_correlations,
                             show_change_distributions,
                             exons_layout=False,
                             figure_kw=None):
    ensure_dir(dirname)
    to_plot = []
    for dsfits in fits.itervalues():
        for (g, r), fit in dsfits.iteritems():

            genedir = join(
                dirname,
                g[:g.index(cfg.exon_separator)] if exons_layout else g)
            ensure_dir(genedir)
            filename = join(genedir, 'fit-{}-{}.png'.format(g, r))
            if isfile(filename):
                print 'Figure already exists for {}@{}. skipping...'.format(
                    g, r)
                continue
            series = data.get_one_series(g, r)
            if show_change_distributions and hasattr(
                    fit, 'change_distribution_weights'):
                change_distribution = Bunch(
                    centers=fits.change_distribution_params.bin_centers,
                    weights=fit.change_distribution_weights,
                )
            else:
                change_distribution = None
            to_plot.append((series, fit, filename, use_correlations,
                            change_distribution, figure_kw))
    if cfg.parallel_run_locally:
        for args in to_plot:
            _plot_series_job(*args)
    else:
        pool = Parallel(_plot_series_job)
        pool(pool.delay(*args) for args in to_plot)
Example #30
def save_figure(fig,
                filename,
                b_close=False,
                b_square=True,
                show_frame=False,
                under_results=False,
                print_filename=False):
    if under_results:
        dirname = results_dir()
        filename = join(dirname, filename)
        ensure_dir(os.path.dirname(filename))
    if cfg.verbosity >= 1 or print_filename:
        print 'Saving figure to {}'.format(filename)
    figure_size_x = cfg.default_figure_size_x_square if b_square else cfg.default_figure_size_x
    fig.set_size_inches(figure_size_x, cfg.default_figure_size_y)
    if show_frame:
        facecolor = cfg.default_figure_facecolor
    else:
        facecolor = 'white'
    fig.savefig(filename, facecolor=facecolor, dpi=cfg.default_figure_dpi)
    if b_close:
        plt.close(fig)
Example #31
def save_ds(gan_model):

    splits = ['train', 'dev', 'test']

    for split in splits:
        output_dir = os.path.join('data', 'cache',
                                  '{}_pkl'.format(gan_model.dataset_name),
                                  split)
        if gan_model.debug:
            output_dir += '_debug'

        ensure_dir(output_dir)
        orig_imgs_pkl_path = os.path.join(output_dir,
                                          'feats.pkl'.format(split))

        if os.path.exists(orig_imgs_pkl_path) and not gan_model.test_again:
            with open(orig_imgs_pkl_path) as f:
                all_recs = cPickle.load(f)
                could_load = True
                print('[#] Dataset is already saved.')
                return

        gen_func = getattr(gan_model, '{}_gen_test'.format(split))
        all_targets = []
        orig_imgs = []
        ctr = 0
        for images, targets in gen_func():
            ctr += 1
            transformed_images = gan_model.sess.run(
                gan_model.real_data_test,
                feed_dict={gan_model.real_data_test_pl: images})
            orig_imgs.append(transformed_images)
            all_targets.append(targets)
        orig_imgs = np.concatenate(orig_imgs).reshape([-1] +
                                                      gan_model.image_dim)
        all_targets = np.concatenate(all_targets)
        with open(orig_imgs_pkl_path, 'w') as f:
            cPickle.dump(orig_imgs, f, cPickle.HIGHEST_PROTOCOL)
            cPickle.dump(all_targets, f, cPickle.HIGHEST_PROTOCOL)
Example #32
def plot_and_save_all_exons(data, fitter, fits, dirname):
    ensure_dir(dirname)
    to_plot = []
    genes, regions = set(), set()
    for ds_fits in fits.itervalues():
        for g, r in ds_fits.iterkeys():
            genes.add(g[:g.index(cfg.exon_separator)])
            regions.add(r)
    for g in sorted(genes):
        for r in sorted(regions):
            gene_dir = join(dirname, g)
            ensure_dir(gene_dir)
            filename = join(gene_dir, '{}-{}.png'.format(g, r))
            if isfile(filename):
                print 'Figure already exists for gene {} in region {}. skipping...'.format(
                    g, r)
                continue
            exons_series_fits = _extract_exons_data(data, g, r, fits)
            if not np.count_nonzero(exons_series_fits):
                continue
            to_plot.append((g, r, exons_series_fits, filename))
    pool = Parallel(_plot_exons_job)
    pool(pool.delay(*args) for args in to_plot)
Example #33
    def _set_checkpoint_dir(self):
        """Sets the directory containing snapshots of the model."""

        self.cfg_file = self.cfg['cfg_path']
        if 'cfg.yml' in self.cfg_file:
            ckpt_dir = os.path.dirname(self.cfg_file)

        else:
            ckpt_dir = os.path.join(
                self.output_dir,
                self.cfg_file.replace('experiments/cfgs/',
                                      '').replace('cfg.yml',
                                                  '').replace('.yml', ''))
            if not self.test_mode:
                postfix = ''
                ignore_list = ['dataset', 'cfg_file', 'batch_size']
                if hasattr(self, 'cfg'):
                    if self.cfg is not None:
                        for prop in self.default_properties:
                            if prop in ignore_list:
                                continue

                            if prop.upper() in self.cfg.keys():
                                self_val = getattr(self, prop)
                                if self_val is not None:
                                    if getattr(self,
                                               prop) != self.cfg[prop.upper()]:
                                        postfix += '-{}={}'.format(
                                            prop, self_val).replace('.', '_')

                ckpt_dir += postfix
            ensure_dir(ckpt_dir)

        self.checkpoint_dir = ckpt_dir
        self.debug_dir = self.checkpoint_dir.replace('output', 'debug')
        ensure_dir(self.debug_dir)
Example #34
def plot_and_save_all_series(data, fitter, fits, dirname, use_correlations, show_change_distributions, figure_kw=None):
    ensure_dir(dirname)
    to_plot = []
    for dsfits in fits.itervalues():
        for (g,r),fit in dsfits.iteritems():
            filename = join(dirname, 'fit-{}-{}.png'.format(g,r))
            if isfile(filename):
                print 'Figure already exists for {}@{}. skipping...'.format(g,r)
                continue
            series = data.get_one_series(g,r)
            if show_change_distributions and hasattr(fit, 'change_distribution_weights'):
                change_distribution = Bunch(
                    centers = fits.change_distribution_params.bin_centers,
                    weights = fit.change_distribution_weights,
                )
            else:
                change_distribution = None
            to_plot.append((series,fit,filename,use_correlations, change_distribution, figure_kw))
    if cfg.parallel_run_locally:
        for args in to_plot:
            _plot_series_job(*args)
    else:
        pool = Parallel(_plot_series_job)
        pool(pool.delay(*args) for args in to_plot)
Example #35
def create_score_distribution_html(fits, use_correlations, dirname):
    ensure_dir(dirname)
    with interactive(False):
        hist_filename = 'R2-hist.png'
        fig = plot_score_distribution(fits, use_correlations)
        save_figure(fig, join(dirname, hist_filename), b_close=True)

        if use_correlations:
            scatter_filename = 'comparison-scatter.png'
            fig = plot_score_comparison_scatter_for_correlations(fits)
            save_figure(fig, join(dirname, scatter_filename), b_close=True)

            delta_hist_filename = 'R2-delta-hist.png'
            fig = plot_score_improvement_histogram_for_correlations(fits)
            save_figure(fig, join(dirname, delta_hist_filename), b_close=True)

    image_size = "50%"

    import shutil
    html = get_jinja_env().get_template('R2.jinja').render(**locals())
    filename = join(dirname, 'scores.html')
    with open(filename, 'w') as f:
        f.write(html)
    shutil.copy(join(resources_dir(), 'score-distribution.css'), dirname)
Example #36
def main(cfg, argv=None):
    FLAGS = tf.app.flags.FLAGS
    GAN = dataset_gan_dict[FLAGS.dataset_name]

    gan = GAN(cfg=cfg, test_mode=True)
    gan.load_generator()
    # Setting test time reconstruction hyper parameters.
    [tr_rr, tr_lr, tr_iters] = [FLAGS.rec_rr, FLAGS.rec_lr, FLAGS.rec_iters]
    if FLAGS.defense_type.lower() != 'none':
        if FLAGS.rec_path and FLAGS.defense_type == 'defense_gan':

            # Extract hyperparameters from the reconstruction path.
            if FLAGS.rec_path:
                train_param_re = re.compile('recs_rr(.*)_lr(.*)_iters(.*)')
                [tr_rr, tr_lr, tr_iters] = \
                    train_param_re.findall(FLAGS.rec_path)[0]
                gan.rec_rr = int(tr_rr)
                gan.rec_lr = float(tr_lr)
                gan.rec_iters = int(tr_iters)
        elif FLAGS.defense_type == 'defense_gan':
            assert FLAGS.online_training or not FLAGS.train_on_recs

    if FLAGS.override:
        gan.rec_rr = int(tr_rr)
        gan.rec_lr = float(tr_lr)
        gan.rec_iters = int(tr_iters)

    # Setting the results directory.
    results_dir, result_file_name = _get_results_dir_filename(gan)

    # Result file name. The counter makes sure we are not overwriting the
    # results.
    counter = 0
    temp_fp = str(counter) + '_' + result_file_name
    results_dir = os.path.join(results_dir, FLAGS.results_dir)
    temp_final_fp = os.path.join(results_dir, temp_fp)
    while os.path.exists(temp_final_fp):
        counter += 1
        temp_fp = str(counter) + '_' + result_file_name
        temp_final_fp = os.path.join(results_dir, temp_fp)
    result_file_name = temp_fp
    sub_result_path = os.path.join(results_dir, result_file_name)

    accuracies = blackbox(gan,
                          rec_data_path=FLAGS.rec_path,
                          batch_size=FLAGS.batch_size,
                          learning_rate=FLAGS.learning_rate,
                          nb_epochs=FLAGS.nb_epochs,
                          holdout=FLAGS.holdout,
                          data_aug=FLAGS.data_aug,
                          nb_epochs_s=FLAGS.nb_epochs_s,
                          lmbda=FLAGS.lmbda,
                          online_training=FLAGS.online_training,
                          train_on_recs=FLAGS.train_on_recs,
                          defense_type=FLAGS.defense_type)

    ensure_dir(results_dir)

    with open(sub_result_path, 'a') as f:
        f.writelines([
            str(accuracies[x]) + ' '
            for x in ['bbox', 'sub', 'bbox_on_sub_adv_ex']
        ])
        f.write('\n')
        print('[*] saved accuracy in {}'.format(sub_result_path))

    if 'roc_info' in accuracies.keys():  # For attack detection.
        pkl_result_path = sub_result_path.replace('.txt', '_roc.pkl')
        with open(pkl_result_path, 'w') as f:
            cPickle.dump(accuracies['roc_info'], f, cPickle.HIGHEST_PROTOCOL)
            print('[*] saved roc_info in {}'.format(sub_result_path))
Example #37
def save_matfile(mdict, filename):
    ensure_dir(dirname(filename))
    print 'Saving to {}'.format(filename)
    savemat(filename, mdict, oned_as='column')
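The matching read-back uses scipy.io.loadmat (illustrative path); with oned_as='column', 1-D arrays come back as column vectors:

from scipy.io import loadmat

mdict = loadmat('results.mat')  # illustrative path
# 1-D arrays saved above return with shape (n, 1) because of oned_as='column'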
Example #38
def main(cfg, *args):
    FLAGS = tf.app.flags.FLAGS

    rng = np.random.RandomState([11, 24, 1990])
    tf.set_random_seed(11241990)

    gan = gan_from_config(cfg, True)

    results_dir = 'results/clean/{}'.format(gan.dataset_name)
    ensure_dir(results_dir)

    sess = gan.sess
    gan.load_model()

    # use test split
    train_images, train_labels, test_images, test_labels = get_cached_gan_data(
        gan, test_on_dev=False, orig_data_flag=True)

    x_shape = [None] + list(train_images.shape[1:])
    images_pl = tf.placeholder(tf.float32,
                               shape=[BATCH_SIZE] +
                               list(train_images.shape[1:]))
    labels_pl = tf.placeholder(tf.float32,
                               shape=[BATCH_SIZE] + [train_labels.shape[1]])

    if FLAGS.num_tests > 0:
        test_images = test_images[:FLAGS.num_tests]
        test_labels = test_labels[:FLAGS.num_tests]

    if FLAGS.num_train > 0:
        train_images = train_images[:FLAGS.num_train]
        train_labels = train_labels[:FLAGS.num_train]

    train_params = {
        'nb_epochs': 10,
        'batch_size': BATCH_SIZE,
        'learning_rate': 0.001
    }

    eval_params = {'batch_size': BATCH_SIZE}

    # train classifier for mnist, fmnist
    if gan.dataset_name in ['mnist', 'f-mnist']:
        model = model_a(input_shape=x_shape, nb_classes=train_labels.shape[1])
        preds_train = model.get_logits(images_pl, dropout=True)

        model_train(sess,
                    images_pl,
                    labels_pl,
                    preds_train,
                    train_images,
                    train_labels,
                    args=train_params,
                    rng=rng,
                    init_all=False)

    elif gan.dataset_name == 'cifar-10':
        pre_model = Model('classifiers/model/',
                          tiny=False,
                          mode='eval',
                          sess=sess)
        model = DefenseWrapper(pre_model, 'logits')

    elif gan.dataset_name == 'celeba':
        # TODO
        raise NotImplementedError

    model.add_rec_model(gan, batch_size=BATCH_SIZE)
    preds_eval = model.get_logits(images_pl)

    # calculate norms
    num_dims = len(images_pl.get_shape())
    avg_inds = list(range(1, num_dims))
    reconstruct = gan.reconstruct(images_pl, batch_size=BATCH_SIZE)

    # We use L2 loss for GD steps
    diff_op = tf.reduce_mean(tf.square(reconstruct - images_pl), axis=avg_inds)

    acc, mse, roc_info = model_eval_gan(sess,
                                        images_pl,
                                        labels_pl,
                                        preds_eval,
                                        None,
                                        test_images=test_images,
                                        test_labels=test_labels,
                                        args=eval_params,
                                        diff_op=diff_op)
    # Logging
    logfile = open(os.path.join(results_dir, 'acc.txt'), 'a+')
    msg = 'lr_{}_iters_{}, {}\n'.format(gan.rec_lr, gan.rec_iters, acc)
    logfile.writelines(msg)
    logfile.close()

    logfile = open(os.path.join(results_dir, 'mse.txt'), 'a+')
    msg = 'lr_{}_iters_{}, {}\n'.format(gan.rec_lr, gan.rec_iters, mse)
    logfile.writelines(msg)
    logfile.close()

    pickle_filename = os.path.join(
        results_dir, 'roc_lr_{}_iters_{}.pkl'.format(gan.rec_lr,
                                                     gan.rec_iters))
    with open(pickle_filename, 'w') as f:
        cPickle.dump(roc_info, f, cPickle.HIGHEST_PROTOCOL)
        print('[*] saved roc_info in {}'.format(pickle_filename))

    return [acc, mse]
Example #39
def save_matfile(mdict, filename):
    ensure_dir(dirname(filename))
    print 'Saving to {}'.format(filename)
    savemat(filename, mdict, oned_as='column')
Example #40
def plot_and_save_all_gene_correlations(data, correlations, dirname):
    ensure_dir(dirname)
    for region in data.region_names:
        fig = plot_gene_correlations_single_region(correlations[region], region, data.gene_names)
        save_figure(fig, join(dirname,'{}.png'.format(region)), b_close=True)
Example #41
def main(cfg, argv=None):
    FLAGS = tf.app.flags.FLAGS

    gan = DefenseGANBase(cfg=cfg, test_mode=True)
    # Setting test time reconstruction hyper parameters.
    [tr_rr, tr_lr, tr_iters] = [FLAGS.rec_rr, FLAGS.rec_lr, FLAGS.rec_iters]
    if FLAGS.defense_type.lower() != 'none':
        if FLAGS.defense_type == 'defense_gan':

            if 'GENERATOR_INIT_PATH' in cfg:
                gan = DefenseGANv2(get_generator_fn(cfg['DATASET_NAME']),
                                   cfg=cfg,
                                   test_mode=True)
            else:
                gan = DefenseGANBase(cfg=cfg, test_mode=True)

            gan.load_model()

            # Extract hyperparameters from reconstruction path.
            if FLAGS.rec_path is not None:
                train_param_re = re.compile('recs_rr(.*)_lr(.*)_iters(.*)')
                [tr_rr, tr_lr, tr_iters] = \
                    train_param_re.findall(FLAGS.rec_path)[0]
                gan.rec_rr = int(tr_rr)
                gan.rec_lr = float(tr_lr)
                gan.rec_iters = int(tr_iters)
            else:
                assert FLAGS.online_training or not FLAGS.train_on_recs

    if FLAGS.override:
        gan.rec_rr = int(tr_rr)
        gan.rec_lr = float(tr_lr)
        gan.rec_iters = int(tr_iters)

    # Setting the results directory.
    results_dir, result_file_name = _get_results_dir_filename(gan)

    # Result file name. The counter ensures we are not overwriting the
    # results.
    counter = 0
    temp_fp = str(counter) + '_' + result_file_name
    results_dir = os.path.join(results_dir, FLAGS.results_dir)
    temp_final_fp = os.path.join(results_dir, temp_fp)
    while os.path.exists(temp_final_fp):
        counter += 1
        temp_fp = str(counter) + '_' + result_file_name
        temp_final_fp = os.path.join(results_dir, temp_fp)
    result_file_name = temp_fp
    sub_result_path = os.path.join(results_dir, result_file_name)

    accuracies = whitebox(
        gan,
        rec_data_path=FLAGS.rec_path,
        batch_size=FLAGS.batch_size,
        learning_rate=FLAGS.learning_rate,
        nb_epochs=FLAGS.nb_epochs,
        eps=FLAGS.fgsm_eps,
        online_training=FLAGS.online_training,
        defense_type=FLAGS.defense_type,
        num_tests=FLAGS.num_tests,
        attack_type=FLAGS.attack_type,
        num_train=FLAGS.num_train,
    )

    ensure_dir(results_dir)

    with open(sub_result_path, 'a') as f:
        f.writelines([str(accuracies[i]) + ' ' for i in range(2)])
        f.write('\n')
        print('[*] saved accuracy in {}'.format(sub_result_path))

    if accuracies[2]:  # For attack detection.
        pkl_result_path = sub_result_path.replace('.txt', '_roc.pkl')
        with open(pkl_result_path, 'w') as f:
            cPickle.dump(accuracies[2], f, cPickle.HIGHEST_PROTOCOL)
            print('[*] saved roc_info in {}'.format(pkl_result_path))
Example #42
def blackbox(gan,
             rec_data_path=None,
             batch_size=128,
             learning_rate=0.001,
             nb_epochs=10,
             holdout=150,
             data_aug=6,
             nb_epochs_s=10,
             lmbda=0.1,
             online_training=False,
             train_on_recs=False,
             test_on_dev=True,
             defense_type='none'):
    """MNIST tutorial for the black-box attack from arxiv.org/abs/1602.02697
    
    Args:
        train_start: index of first training set example
        train_end: index of last training set example
        test_start: index of first test set example
        test_end: index of last test set example
        defense_type: Type of defense against blackbox attacks
    
    Returns:
        a dictionary with:
             * black-box model accuracy on test set
             * substitute model accuracy on test set
             * black-box model accuracy on adversarial examples transferred
               from the substitute model
    """
    FLAGS = flags.FLAGS

    # Set logging level (WARNING suppresses debug output).
    set_log_level(logging.WARNING)

    # Dictionary used to keep track and return key accuracies.
    accuracies = {}

    # Create TF session.
    adv_training = False
    if defense_type:
        if defense_type == 'defense_gan' and gan:
            sess = gan.sess
            gan_defense_flag = True
        else:
            gan_defense_flag = False
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)
        if 'adv_tr' in defense_type:
            adv_training = True
    else:
        gan_defense_flag = False
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

    train_images, train_labels, test_images, test_labels = \
        get_cached_gan_data(gan, test_on_dev, orig_data_flag=True)

    x_shape, classes = list(train_images.shape[1:]), train_labels.shape[1]
    nb_classes = classes

    type_to_models = {
        'A': model_a,
        'B': model_b,
        'C': model_c,
        'D': model_d,
        'E': model_e,
        'F': model_f,
        'Q': model_q,
        'Z': model_z
    }

    bb_model = type_to_models[FLAGS.bb_model](
        input_shape=[None] + x_shape,
        nb_classes=train_labels.shape[1],
    )
    sub_model = type_to_models[FLAGS.sub_model](
        input_shape=[None] + x_shape,
        nb_classes=train_labels.shape[1],
    )

    if FLAGS.debug:
        train_images = train_images[:20 * batch_size]
        train_labels = train_labels[:20 * batch_size]
        debug_dir = os.path.join('debug', 'blackbox', FLAGS.debug_dir)
        ensure_dir(debug_dir)
        x_debug_test = test_images[:batch_size]

    # Initialize substitute training set reserved for adversary
    images_sub = test_images[:holdout]
    labels_sub = np.argmax(test_labels[:holdout], axis=1)

    # Redefine test set as remaining samples unavailable to adversaries
    if FLAGS.num_tests > 0:
        test_images = test_images[:FLAGS.num_tests]
        test_labels = test_labels[:FLAGS.num_tests]

    test_images = test_images[holdout:]
    test_labels = test_labels[holdout:]

    # Define input and output TF placeholders

    if FLAGS.image_dim[0] == 3:
        FLAGS.image_dim = [
            FLAGS.image_dim[1], FLAGS.image_dim[2], FLAGS.image_dim[0]
        ]

    images_tensor = tf.placeholder(tf.float32, shape=[None] + x_shape)
    labels_tensor = tf.placeholder(tf.float32, shape=(None, classes))

    rng = np.random.RandomState([11, 24, 1990])
    tf.set_random_seed(11241990)

    train_images_bb, train_labels_bb, test_images_bb, test_labels_bb = \
        train_images, train_labels, test_images, \
        test_labels

    cur_gan = None

    if defense_type:
        if 'gan' in defense_type:
            # Load cached dataset reconstructions.
            if online_training and not train_on_recs:
                cur_gan = gan
            elif not online_training and rec_data_path:
                train_images_bb, train_labels_bb, test_images_bb, \
                test_labels_bb = get_cached_gan_data(
                    gan, test_on_dev, orig_data_flag=False)
            else:
                assert not train_on_recs

        if FLAGS.debug:
            train_images_bb = train_images_bb[:20 * batch_size]
            train_labels_bb = train_labels_bb[:20 * batch_size]

        # Prepare the black_box model.
        prep_bbox_out = prep_bbox(sess,
                                  images_tensor,
                                  labels_tensor,
                                  train_images_bb,
                                  train_labels_bb,
                                  test_images_bb,
                                  test_labels_bb,
                                  nb_epochs,
                                  batch_size,
                                  learning_rate,
                                  rng=rng,
                                  gan=cur_gan,
                                  adv_training=adv_training,
                                  cnn_arch=bb_model)
    else:
        prep_bbox_out = prep_bbox(sess,
                                  images_tensor,
                                  labels_tensor,
                                  train_images_bb,
                                  train_labels_bb,
                                  test_images_bb,
                                  test_labels_bb,
                                  nb_epochs,
                                  batch_size,
                                  learning_rate,
                                  rng=rng,
                                  gan=cur_gan,
                                  adv_training=adv_training,
                                  cnn_arch=bb_model)

    model, bbox_preds, accuracies['bbox'] = prep_bbox_out

    # Train substitute using method from https://arxiv.org/abs/1602.02697
    print("Training the substitute model.")
    reconstructed_tensors = tf.stop_gradient(
        gan.reconstruct(images_tensor,
                        batch_size=batch_size,
                        reconstructor_id=1))
    model_sub, preds_sub = train_sub(
        sess,
        images_tensor,
        labels_tensor,
        model(reconstructed_tensors),
        images_sub,
        labels_sub,
        nb_classes,
        nb_epochs_s,
        batch_size,
        learning_rate,
        data_aug,
        lmbda,
        rng=rng,
        substitute_model=sub_model,
    )

    accuracies['sub'] = 0
    # Initialize the Fast Gradient Sign Method (FGSM) attack object.
    fgsm_par = {
        'eps': FLAGS.fgsm_eps,
        'ord': np.inf,
        'clip_min': 0.,
        'clip_max': 1.
    }
    if gan:
        if gan.dataset_name == 'celeba':
            fgsm_par['clip_min'] = -1.0

    fgsm = FastGradientMethod(model_sub, sess=sess)

    # Craft adversarial examples using the substitute.
    eval_params = {'batch_size': batch_size}
    x_adv_sub = fgsm.generate(images_tensor, **fgsm_par)

    if FLAGS.debug and gan is not None:  # To see some qualitative results.
        reconstructed_tensors = gan.reconstruct(x_adv_sub,
                                                batch_size=batch_size,
                                                reconstructor_id=2)

        x_rec_orig = gan.reconstruct(images_tensor,
                                     batch_size=batch_size,
                                     reconstructor_id=3)
        x_adv_sub_val = sess.run(x_adv_sub,
                                 feed_dict={
                                     images_tensor: x_debug_test,
                                     K.learning_phase(): 0
                                 })
        sess.run(tf.local_variables_initializer())
        x_rec_debug_val, x_rec_orig_val = sess.run(
            [reconstructed_tensors, x_rec_orig],
            feed_dict={
                images_tensor: x_debug_test,
                K.learning_phase(): 0
            })

        save_images_files(x_adv_sub_val, output_dir=debug_dir, postfix='adv')

        postfix = 'gen_rec'
        save_images_files(x_rec_debug_val,
                          output_dir=debug_dir,
                          postfix=postfix)
        save_images_files(x_debug_test, output_dir=debug_dir, postfix='orig')
        save_images_files(x_rec_orig_val,
                          output_dir=debug_dir,
                          postfix='orig_rec')
        return

    if gan_defense_flag:
        reconstructed_tensors = gan.reconstruct(
            x_adv_sub,
            batch_size=batch_size,
            reconstructor_id=4,
        )

        num_dims = len(images_tensor.get_shape())
        avg_inds = list(range(1, num_dims))
        diff_op = tf.reduce_mean(tf.square(x_adv_sub - reconstructed_tensors),
                                 axis=avg_inds)

        outs = model_eval_gan(sess,
                              images_tensor,
                              labels_tensor,
                              predictions=model(reconstructed_tensors),
                              test_images=test_images,
                              test_labels=test_labels,
                              args=eval_params,
                              diff_op=diff_op,
                              feed={K.learning_phase(): 0})

        accuracies['bbox_on_sub_adv_ex'] = outs[0]
        accuracies['roc_info'] = outs[1]
        print('Test accuracy of oracle on adversarial examples generated '
              'using the substitute: ' + str(outs[0]))
    else:
        accuracy = model_eval(sess,
                              images_tensor,
                              labels_tensor,
                              model(x_adv_sub),
                              test_images,
                              test_labels,
                              args=eval_params,
                              feed={K.learning_phase(): 0})
        print('Test accuracy of oracle on adversarial examples generated '
              'using the substitute: ' + str(accuracy))
        accuracies['bbox_on_sub_adv_ex'] = accuracy

    return accuracies
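The detection statistic built in the gan_defense_flag branch above is just the per-example mean squared error between each adversarial input and its GAN reconstruction. A minimal NumPy sketch of the same measurement outside any TensorFlow graph (the toy batch and the 0.01 threshold below are invented for illustration):

import numpy as np

def reconstruction_mse(x, x_rec):
    # Average the squared difference over every non-batch dimension,
    # mirroring tf.reduce_mean(..., axis=avg_inds) above.
    axes = tuple(range(1, x.ndim))
    return np.mean(np.square(x - x_rec), axis=axes)

# Toy batch: four 64x64 RGB images with slightly noisy "reconstructions".
x = np.random.rand(4, 64, 64, 3)
x_rec = np.clip(x + 0.05 * np.random.randn(*x.shape), 0.0, 1.0)
scores = reconstruction_mse(x, x_rec)   # shape (4,)
print(scores, scores > 0.01)            # larger errors are more suspicious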
Example #43
if __name__ == '__main__':
    disable_all_warnings()
    parser = get_common_parser()
    parser.add_argument('--shape2', required=True, help='The shape to compare against', choices=allowed_shape_names())
    parser.add_argument('--scaling2', help='The scaling used when fitting shape2. Default: none', choices=allowed_scaler_names())
    parser.add_argument('--sigma_prior2', help='Prior to use for 1/sigma when fitting shape2. Default: None', choices=get_allowed_priors(is_sigma=True))
    parser.add_argument('--priors2', help='The priors used for theta when fitting shape2. Default: None', choices=get_allowed_priors())
    parser.add_argument('--filename', help='Where to save the figure. Default: results/shape_comparison.png')
    parser.add_argument('--show', help='Show figure and wait before exiting', action='store_true')
    parser.add_argument('--ndiffs', type=int, default=5, help='Number of top diffs to show. Default: 5.')
    args = parser.parse_args()
    data1, fitter1 = process_common_inputs(args)
    data2 = get_data_from_args(args.dataset, args.pathway, args.from_age, args.scaling2, args.shuffle)
    fitter2 = get_fitter_from_args(args.shape2, args.priors2, args.sigma_prior2)

    fits1 = get_all_fits(data1,fitter1)
    fits2 = get_all_fits(data2,fitter2)

    print_diff_points(data1,fitter1,fits1, data2,fitter2,fits2, args.ndiffs)

    fig = plot_comparison_scatter(data1,fitter1,fits1, data2,fitter2,fits2)

    filename = args.filename
    if filename is None:
        ensure_dir(results_dir())
        filename = join(results_dir(), 'shape_comparison.png')
    save_figure(fig, filename)

    if args.show:
        plt.show(block=True)
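For reference, a hypothetical invocation of this script (the file name compare_shapes.py and the shape value sigmoid are placeholders; the flags are the ones defined above, plus whatever get_common_parser() adds):

# python compare_shapes.py --shape2 sigmoid --ndiffs 10 \
#        --filename results/shape_comparison.png --show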
Example #44
File: gan.py Project: zmonoid/defensegan
    def reconstruct_dataset(self, ckpt_path=None, max_num=-1, max_num_load=-1):
        """Reconstructs the images of the config's dataset with the generator.
        """

        if not self.initialized:
            self.load_generator(ckpt_path=ckpt_path)

        splits = ['train', 'dev', 'test']

        rec = self.reconstruct(self.real_data_test)

        self.sess.run(tf.local_variables_initializer())
        rets = {}

        for split in splits:
            if max_num > 0:
                output_dir = os.path.join(
                    self.checkpoint_dir, 'recs_rr{:d}_lr{:.5f}_'
                    'iters{:d}_num{:d}'.format(self.rec_rr, self.rec_lr,
                                               self.rec_iters, max_num), split)
            else:
                output_dir = os.path.join(
                    self.checkpoint_dir, 'recs_rr{:d}_lr{:.5f}_'
                    'iters{:d}'.format(self.rec_rr, self.rec_lr,
                                       self.rec_iters), split)

            if self.debug:
                output_dir += '_debug'

            ensure_dir(output_dir)
            feats_path = os.path.join(output_dir, 'feats.pkl')
            could_load = False
            try:
                if os.path.exists(feats_path) and not self.test_again:
                    with open(feats_path, 'rb') as f:
                        all_recs = cPickle.load(f)
                        could_load = True
                        print('[#] Successfully loaded features.')
                else:
                    all_recs = []
            except Exception as e:
                all_recs = []
                print('[#] Exception while loading features: {}'.format(e))

            gen_func = getattr(self, '{}_gen_test'.format(split))
            all_targets = []
            orig_imgs = []
            ctr = 0
            sti = time.time()

            # Pickle files per reconstructed image.
            pickle_out_dir = os.path.join(output_dir, 'pickles')
            ensure_dir(pickle_out_dir)
            single_feat_path_template = os.path.join(pickle_out_dir,
                                                     'rec_{:07d}_l{}.pkl')

            for images, targets in gen_func():
                batch_size = len(images)
                im_paths = [
                    single_feat_path_template.format(ctr * batch_size + i,
                                                     targets[i])
                    for i in range(batch_size)
                ]

                # Stop once enough images have been processed, or after a few
                # batches in debug mode.
                mn = max(max_num, max_num_load)
                if (mn > -1 and ctr * batch_size > mn) or \
                        (self.debug and ctr > 2):
                    break

                batch_could_load = not self.test_again
                batch_rec_list = []

                for imp in im_paths:  # Load per-image cached files.
                    try:
                        with open(imp, 'rb') as f:
                            loaded_rec = cPickle.load(f)
                            batch_rec_list.append(loaded_rec)
                    except Exception:
                        batch_could_load = False
                        break

                if batch_could_load and not could_load:
                    recs = np.stack(batch_rec_list)
                    all_recs.append(recs)

                if not (could_load or batch_could_load):
                    self.sess.run(tf.local_variables_initializer())
                    recs = self.sess.run(
                        rec,
                        feed_dict={self.real_data_test_pl: images},
                    )
                    print('[#] t:{:.2f} batch: {:d} '.format(
                        time.time() - sti, ctr))
                    all_recs.append(recs)
                else:
                    print('[*] could load batch: {:d}'.format(ctr))

                if not batch_could_load and not could_load:
                    # Cache each reconstruction in its own pickle file.
                    for i in range(len(recs)):
                        pkl_path = im_paths[i]
                        with open(pkl_path, 'wb') as f:
                            cPickle.dump(recs[i],
                                         f,
                                         protocol=cPickle.HIGHEST_PROTOCOL)

                all_targets.append(targets)

                orig_transformed = self.sess.run(
                    self.real_data_test,
                    feed_dict={self.real_data_test_pl: images})

                orig_imgs.append(orig_transformed)
                ctr += 1
            if not could_load:
                all_recs = np.concatenate(all_recs)
                all_recs = all_recs.reshape([-1] + self.image_dim)

            orig_imgs = np.concatenate(orig_imgs).reshape([-1] +
                                                          self.image_dim)
            all_targets = np.concatenate(all_targets)

            if self.debug:
                save_images_files(all_recs,
                                  output_dir=output_dir,
                                  labels=all_targets)
                # Shift and scale the originals into [0, 1] before saving.
                orig_min = min(0, orig_imgs.min())
                save_images_files(
                    (orig_imgs - orig_min) / (orig_imgs.max() - orig_min),
                    output_dir=output_dir,
                    labels=all_targets,
                    postfix='_orig')

            rets[split] = [all_recs, all_targets, orig_imgs]

        return rets
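The per-image pickles above implement a plain load-if-present, otherwise compute-and-dump cache. A standalone sketch of that pattern (Python 3 syntax with the stdlib pickle module; the path and the toy computation are invented for illustration):

import os
import pickle

def cached(path, compute):
    # Return the pickled value at `path`, computing and caching it if absent.
    if os.path.exists(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
    value = compute()
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)
    return value

# Toy usage: the expensive computation runs only on the first call.
value = cached('pickles/rec_0000001_l3.pkl', lambda: sum(range(10 ** 6)))
print(value)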