def main(args):

    # Set reproducible random seed
    tf.set_random_seed(1234)

    # Directories
    # Get name
    split = FLAGS.load_dir.split('/')
    if split[-1]:
        name = split[-1]
    else:
        name = split[-2]

    # Get parent directory
    split = FLAGS.load_dir.split("/" + name)
    parent_dir = split[0]

    test_dir = '{}/{}/test'.format(parent_dir, name)
    test_summary_dir = test_dir + '/summary'

    # Clear the test log directory
    if (FLAGS.reset is True) and os.path.exists(test_dir):
        shutil.rmtree(test_dir)
    if not os.path.exists(test_summary_dir):
        os.makedirs(test_summary_dir)

    # Logger
    conf.setup_logger(logger_dir=test_dir, name="logger_test.txt")
    logger.info("name: " + name)
    logger.info("parent_dir: " + parent_dir)
    logger.info("test_dir: " + test_dir)
    if FLAGS.patch_path:
        logger.info("patch_path: " + FLAGS.patch_path)

    # Load hyperparameters from train run
    conf.load_or_save_hyperparams()

    # Get dataset hyperparameters
    logger.info('Using dataset: {}'.format(FLAGS.dataset))

    # Dataset
    dataset_size_test = conf.get_dataset_size_test(
        FLAGS.dataset
    ) if FLAGS.partition == "test" else conf.get_dataset_size_train(
        FLAGS.dataset)
    num_classes = conf.get_num_classes(FLAGS.dataset)
    create_inputs_test = conf.get_create_inputs(FLAGS.dataset,
                                                mode=FLAGS.partition)

    #----------------------------------------------------------------------------
    # GRAPH - TEST
    #----------------------------------------------------------------------------
    logger.info('BUILD TEST GRAPH')
    g_test = tf.Graph()
    with g_test.as_default():
        # Get global_step
        global_step = tf.train.get_or_create_global_step()

        num_batches_test = int(dataset_size_test / FLAGS.batch_size)

        # Get data
        input_dict = create_inputs_test()
        batch_x = input_dict['image']
        batch_labels = input_dict['label']

        # AG 10/12/2018: Split batch for multi gpu implementation
        # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
        # See: https://github.com/naturomics/CapsNet-
        # Tensorflow/blob/master/dist_version/distributed_train.py
        splits_x = tf.split(axis=0,
                            num_or_size_splits=FLAGS.num_gpus,
                            value=batch_x)
        splits_labels = tf.split(axis=0,
                                 num_or_size_splits=FLAGS.num_gpus,
                                 value=batch_labels)

        # Build architecture
        build_arch = conf.get_dataset_architecture(FLAGS.dataset)
        # for baseline
        #build_arch = conf.get_dataset_architecture('baseline')

        #--------------------------------------------------------------------------
        # MULTI GPU - TEST
        #--------------------------------------------------------------------------
        # Calculate the logits for each model tower
        tower_logits = []
        tower_recon_losses = []
        reuse_variables = None
        with tf.device("/cpu:0"):
            scale_min_feed = tf.placeholder(tf.float32,
                                            shape=[],
                                            name="scale_min_feed")
            scale_max_feed = tf.placeholder(tf.float32,
                                            shape=[],
                                            name="scale_max_feed")
        patch_feed = None
        if FLAGS.patch_path:
            patch_feed = tf.placeholder(
                tf.float32,
                shape=batch_x.get_shape().as_list()[-3:],
                name="patch_feed")
        for i in range(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i) as scope:
                    with slim.arg_scope([slim.variable], device='/cpu:0'):
                        logits, recon_losses, patch_node = tower_fn(
                            build_arch,
                            splits_x[i],
                            scale_min_feed,
                            scale_max_feed,
                            patch_feed,
                            scope,
                            num_classes,
                            reuse_variables=reuse_variables,
                            is_train=False)

                    # Don't reuse variable for first GPU, but do reuse for others
                    reuse_variables = True
                    # Keep track of logits and reconstruction losses for each tower
                    tower_logits.append(logits)
                    tower_recon_losses.append(recon_losses)
        # Combine logits from all towers
        test_metrics = {}
        if not FLAGS.save_patch:
            test_logits = tf.concat(tower_logits, axis=0)
            test_preds = tf.argmax(test_logits, axis=-1)
            test_recon_losses = tf.concat(tower_recon_losses, axis=0)
            test_metrics = {
                'preds': test_preds,
                'labels': batch_labels,
                'recon_losses': test_recon_losses
            }
        if FLAGS.adv_patch:
            test_metrics['patch'] = patch_node

        # Reset and read operations for streaming metrics go here
        test_reset = {}
        test_read = {}

        # Saver
        saver = tf.train.Saver(max_to_keep=None)

        # Set summary op

        #--------------------------------------------------------------------------
        # SESSION - TEST
        #--------------------------------------------------------------------------
        #sess_test = tf.Session(
        #    config=tf.ConfigProto(allow_soft_placement=True,
        #                          log_device_placement=False),
        #    graph=g_test)
        # Perry: added in for RTX 2070 incompatibility workaround
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        config.gpu_options.allow_growth = True
        sess_test = tf.Session(config=config, graph=g_test)

        #sess_test.run(tf.local_variables_initializer())
        #sess_test.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(test_summary_dir,
                                               graph=sess_test.graph)

        ckpts_to_test = []
        load_dir_chechpoint = os.path.join(FLAGS.load_dir, "train",
                                           "checkpoint")

        # Evaluate the latest ckpt in dir
        if FLAGS.ckpt_name is None:
            latest_ckpt = tf.train.latest_checkpoint(load_dir_chechpoint)
            ckpts_to_test.append(latest_ckpt)

        # Evaluate all ckpts in dir
        elif FLAGS.ckpt_name == "all":
            # Get list of files in directory and sort by date created
            filenames = os.listdir(load_dir_chechpoint)
            regex = re.compile(r'.*\.index')
            filenames = filter(regex.search, filenames)
            data_ckpts = (os.path.join(load_dir_chechpoint, fn)
                          for fn in filenames)
            data_ckpts = ((os.stat(path), path) for path in data_ckpts)

            # regular files, insert creation date
            data_ckpts = ((stat[ST_CTIME], path) for stat, path in data_ckpts
                          if S_ISREG(stat[ST_MODE]))
            data_ckpts = sorted(data_ckpts)
            # remove ".index"
            ckpts_to_test = [path[:-6] for ctime, path in data_ckpts]

        # Evaluate ckpt specified by name
        else:
            ckpt_name = os.path.join(load_dir_chechpoint, FLAGS.ckpt_name)
            ckpts_to_test.append(ckpt_name)

        #--------------------------------------------------------------------------
        # MAIN LOOP
        #--------------------------------------------------------------------------
        # Run testing on checkpoints
        for ckpt in ckpts_to_test:
            saver.restore(sess_test, ckpt)

            if FLAGS.save_patch:
                out = sess_test.run(test_metrics['patch'])
                patch = out
                if patch.shape[-1] == 1:
                    patch = np.squeeze(patch, axis=-1)
                formatted = (patch * 255).astype('uint8')
                img = Image.fromarray(formatted)
                save_dir = os.path.join(FLAGS.storage, 'logs/', FLAGS.dataset,
                                        FLAGS.logdir)
                img.save(
                    os.path.join(FLAGS.load_dir, "test", "saved_patch.png"))
                return

            # Reset accumulators
            sess_test.run(test_reset)
            test_preds_vals = []
            test_labels_vals = []
            test_recon_losses_vals = []
            test_scales = []

            interval = 0.1 if FLAGS.adv_patch else 1
            for scale in np.arange(0, 1, interval):
                for i in range(num_batches_test):
                    feed_dict = {scale_min_feed: scale, scale_max_feed: scale}
                    if FLAGS.patch_path:
                        patch_dims = patch_feed.get_shape()
                        patch = np.asarray(Image.open(FLAGS.patch_path),
                                           dtype=np.float32)
                        if len(patch.shape) < 3:
                            patch = np.expand_dims(patch, axis=-1)
                        if patch_dims[-1] == 1:
                            patch = np.mean(patch, axis=-1, keepdims=True)
                        patch = patch / 255
                        feed_dict[patch_feed] = patch
                    out = sess_test.run([test_metrics], feed_dict=feed_dict)
                    test_metrics_v = out[0]
                    #ckpt_num = re.split('-', ckpt)[-1]
                    #logger.info('TEST ckpt-{}'.format(ckpt_num)
                    #    + ' bch-{:d}'.format(i)
                    #    )
                    test_preds_vals.append(test_metrics_v['preds'])
                    test_labels_vals.append(test_metrics_v['labels'])
                    test_recon_losses_vals.append(
                        test_metrics_v['recon_losses'])
                    test_scales.append(
                        np.full(test_metrics_v['preds'].shape,
                                fill_value=scale))

        logger.info('writing to csv')
        test_preds_vals = np.concatenate(test_preds_vals)
        test_labels_vals = np.concatenate(test_labels_vals)
        test_recon_losses_vals = np.concatenate(test_recon_losses_vals)
        test_scales = np.concatenate(test_scales)

        data = {
            'predictions': test_preds_vals,
            'labels': test_labels_vals,
            'reconstruction_losses': test_recon_losses_vals,
            'scales': test_scales
        }
        filename = "recon_losses.csv"
        if FLAGS.patch_path:
            filename = re.sub(
                r'[^\w\-_]', '_',
                FLAGS.patch_path) + "_" + FLAGS.partition + ".csv"
        csv_save_path = os.path.join(FLAGS.load_dir, FLAGS.partition, filename)
        pd.DataFrame(data).to_csv(csv_save_path, index=False)
        logger.info('csv saved at ' + csv_save_path)
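
A minimal follow-up sketch, assuming pandas and the CSV written above, showing how per-scale accuracy and mean reconstruction loss could be summarized from the saved columns:

import pandas as pd

df = pd.read_csv(csv_save_path)
df['correct'] = (df['predictions'] == df['labels']).astype(float)
# Accuracy and mean reconstruction loss for each patch scale evaluated above
summary = df.groupby('scales').agg(
    accuracy=('correct', 'mean'),
    mean_recon_loss=('reconstruction_losses', 'mean'))
print(summary)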
Example #2
def main():
    try:
        config.setup()
        parser = argparse.ArgumentParser()

        # This is gross and dirty and may break b/c it's gross and dirty
        parser._positionals.title = "Available commands"

        parser.add_argument('--dry-run', action='store_true',
                            help="show what would happen, "
                                 "but don't change anything")
        parser.add_argument('-L', '--id-list', action='store_true',
                            help='if possible, only output a list of matching '
                                 'Object Ids - Pipe this into other commands. '
                                 '(overrides --full and any other output)')
        parser.add_argument('-v', '--verbose', action='count', default=0,
                            help="amount of detail to display add vs (-vvvv) "
                                 "for maximum amount")
        parser.add_argument('--version', action='version',
                            version='%(prog)s ' + VERSION)

        subparsers = parser.add_subparsers(dest='command')

        # "config" cmd parser
        sp_cfg = subparsers.add_parser('config',
                                       help='manage configuration options')
        # add mutually exclusive group?
        sp_cfg.add_argument('-v', '--view', action='store_true',
                            help='view the current config data')
        sp_cfg.add_argument('-d', '--discover', action='store_true',
                            help='discover Tablos')

        # "library" cmd parser
        sp_lib = subparsers.add_parser('library',
                                       help='manage the local library '
                                            'of recordings')
        # add mutually exclusive group?
        sp_lib.add_argument('-b', '--build', action='store_true',
                            help='build library')
        sp_lib.add_argument('-v', '--view', action='store_true',
                            help='view library')
        sp_lib.add_argument('--stats', action='store_true',
                            help='search library')
        sp_lib.add_argument('--full', action='store_true',
                            help='dump/display full record details')
        sp_lib.add_argument('--incomplete', nargs="?", type=int, default=-2,
                            const=-1,
                            help='show what may be incomplete recordings. Add '
                                 'a number to limit to less than that percent '
                                 'of a full show. Similar to --dupes, but '
                                 "tries to show the dupes that can't be "
                                 'combined into a possibly useful single '
                                 'recording.')
        sp_lib.add_argument('--dupes', action='store_true',
                            help='show what may be duplicate recordings. '
                                 "There's a good chance these are pieces of a "
                                 "partial recording, so you probably want to "
                                 "use --incomplete for cleanup")

        # search cmd parser
        sp_search = subparsers.add_parser('search',
                                          help='ways to search your library')

        sp_search.add_argument('-t', '--term',
                               help='search title/description for this')
        sp_search.add_argument('-a', '--after',
                               type=valid_date,
                               help='only recordings after this date')
        sp_search.add_argument('-b', '--before',
                               type=valid_date,
                               help='only recordings before this date')
        sp_search.add_argument('--limit', type=int,
                               help='limit the number of results returned')
        sp_search.add_argument('--season', type=int,
                               help='episodes with this season')
        sp_search.add_argument('--episode', type=int,
                               help='episodes with this episode number')
        sp_search.add_argument('--state', action="append",
                               choices=['finished', 'failed', 'recording'],
                               help='only recordings in this state')
        sp_search.add_argument('--type', action="append",
                               choices=['episode', 'movie',
                                        'sport', 'programs'],
                               help='only include these recording types')
        sp_search.add_argument('--duration',
                               type=valid_duration,
                               help="recordings less than this length "
                                    "(28m, 10s, 1h, etc.) - useful for culling "
                                    "bad recordings")
        sp_search.add_argument('--watched', action='store_true',
                               help='only include watched recordings')
        sp_search.add_argument('--full', action='store_true',
                               help='dump/display full record details')
        sp_search.add_argument('--tms-id',
                               help='select by TMS Id (probably unique)')
        sp_search.add_argument('--id', type=int,
                               help='select by Tablo Object Id '
                                    '(definitely unique)')

        # "copy" cmd parser
        sp_copy = subparsers.add_parser('copy',
                                        help='copy recordings somewhere')
        sp_copy.description = \
            'Pipe the -L output (object_id list) from a ' \
            '"search" or "library" command into this. ' \
            'Otherwise, use --infile.'

        sp_copy.add_argument('--infile', nargs='?',
                             type=argparse.FileType('r'),
                             help="file with list of ids to use, something "
                                  "like [867, 5309]",
                             default=sys.stdin)

        sp_copy.add_argument('--clobber', action='store_true',
                             default=False,
                             help='should we overwrite existing files?')

        # "delete" cmd parser
        sp_delete = subparsers.add_parser('delete',
                                          help='delete recordings from the '
                                          'Tablo device')
        sp_delete.description = \
            'Pipe the -L output (object_id list) from a ' \
            '"search" or "library" command into this. ' \
            'Otherwise, use --infile.'
        sp_delete.add_argument('--infile', nargs='?',
                               type=argparse.FileType('r'),
                               help="file with list of ids to use, something "
                                    "like [867, 5309]",
                               default=sys.stdin)

        sp_delete.add_argument('--yes', '--yyaaassss', action='store_true',
                               default=False,
                               help='This must be set to actually delete '
                                    'stuff')

        # args = parser.parse_args()
        args, unknown = parser.parse_known_args()

        if len(sys.argv) == 1:
            parser.print_help(sys.stderr)
            sys.exit(EXIT_CODE_ERROR)

        if args.verbose >= 3:
            log_level = logging.DEBUG
        elif args.verbose >= 2:
            log_level = logging.INFO
        elif args.verbose >= 1:
            log_level = logging.WARNING
        else:
            log_level = logging.CRITICAL

        config.built_ins['log_level'] = log_level
        config.setup_logger(log_level)

        config.built_ins['dry_run'] = args.dry_run

        if args.command == 'config':
            if args.view:
                config.view()

            elif args.discover:
                config.discover()

            else:
                sp_cfg.print_help(sys.stderr)

            return EXIT_CODE_OK

        if args.command == 'library':
            if args.build:
                library.build()
            elif args.view:
                library.view(args)
            elif args.stats:
                library.print_stats()
            elif args.dupes:
                library.print_dupes()
            elif args.incomplete and args.incomplete != -2:
                # TODO: all of what I've done here can't be the right way.
                if args.incomplete == -1:
                    args.incomplete = 100
                library.print_incomplete(args)
            else:
                sp_lib.print_help(sys.stderr)

            return EXIT_CODE_OK

        if args.command == 'search':
            if not (args.after or args.before or args.full
                    or args.state or args.term or args.type
                    or args.limit or args.watched
                    or args.episode or args.season
                    or args.tms_id or args.id or args.id_list
                    or args.duration
                    or search_unknown(unknown)
                    ):
                sp_search.print_help(sys.stderr)
                return EXIT_CODE_OK
            else:
                # nothing that looked like an arg and not blank. try it.
                if search_unknown(unknown):
                    args.term = " ".join(unknown)
                    search.search(args)
                else:
                    search.search(args)

            return EXIT_CODE_OK

        if args.command == 'copy':
            print("Gathering data, [ENTER] to quit")
            data = args.infile.readline()
            try:
                id_list = check_input(data)
            except ValueError:
                sp_copy.print_help(sys.stderr)
                return EXIT_CODE_ERROR
            export.copy(id_list, args)

            return EXIT_CODE_OK

        if args.command == 'delete':
            print("Gathering data, [ENTER] to quit")
            data = args.infile.readline().rstrip()
            try:
                id_list = check_input(data)
            except ValueError:
                print()
                sp_delete.print_help(sys.stderr)
                return EXIT_CODE_ERROR
            library.delete(id_list, args)

            return EXIT_CODE_OK

    except KeyboardInterrupt:
        return EXIT_CODE_ERROR  # pragma: no cover
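
check_input is referenced by the copy and delete commands but not shown here; a minimal sketch, assuming its job is to turn piped text such as "[867, 5309]" into a list of integer object ids and raise ValueError otherwise:

import ast

def check_input_sketch(data):
    """Parse an id list like "[867, 5309]" (or a single int) into a list of ints."""
    data = data.strip()
    if not data:
        raise ValueError("no input")
    parsed = ast.literal_eval(data)
    if isinstance(parsed, int):
        return [parsed]
    if isinstance(parsed, (list, tuple)) and all(isinstance(i, int) for i in parsed):
        return list(parsed)
    raise ValueError("expected a list of integer object ids")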
Example #3
import asyncio
import os
from threading import Thread

from app import mail
from config import setup_logger
from flask import current_app
from flask_mail import Message
from sendgrid import SendGridAPIClient
from sendgrid.helpers.mail import Content, From, Mail, To

email_logger = setup_logger("email_logger", "email.log")


def send_async_email(app, msg):
    with app.app_context():
        mail.send(msg)


def send_email_localhost(
    subject, sender, sender_name, recipients, text_body, html_body
):
    msg = Message(subject, sender=sender, recipients=recipients)
    msg.body = text_body
    msg.html = html_body
    Thread(
        target=send_async_email, args=(current_app._get_current_object(), msg),
    ).start()


def send_real_email(
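
To contrast with send_email_localhost above, a minimal sketch of sending the same message through the sendgrid v6 client already imported in this example; the function name, parameter list, and the SENDGRID_API_KEY config key are assumptions:

def send_real_email_sketch(
    subject, sender, sender_name, recipients, text_body, html_body
):
    # Hypothetical: API key kept in the Flask config under SENDGRID_API_KEY
    sg = SendGridAPIClient(current_app.config["SENDGRID_API_KEY"])
    message = Mail(
        from_email=From(sender, sender_name),
        to_emails=[To(r) for r in recipients],
        subject=subject,
        plain_text_content=text_body,
        html_content=html_body,
    )
    response = sg.send(message)
    email_logger.info("SendGrid responded with status %s", response.status_code)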
Example #4
def main(args):
  
  # Set reproducible random seed
  tf.set_random_seed(1234)
  
  # Directories
  # Get name
  split = FLAGS.load_dir.split('/')
  if split[-1]:
    name = split[-1]
  else:
    name = split[-2]
    
  # Get parent directory
  split = FLAGS.load_dir.split("/" + name)
  parent_dir = split[0]

  test_dir = '{}/{}/test'.format(parent_dir, name)
  test_summary_dir = test_dir + '/summary'

  # Clear the test log directory
  if (FLAGS.reset is True) and os.path.exists(test_dir):
    shutil.rmtree(test_dir) 
  if not os.path.exists(test_summary_dir):
    os.makedirs(test_summary_dir)
  
  # Logger
  conf.setup_logger(logger_dir=test_dir, name="logger_test.txt")
  logger.info("name: " + name)
  logger.info("parent_dir: " + parent_dir)
  logger.info("test_dir: " + test_dir)
  
  # Load hyperparameters from train run
  conf.load_or_save_hyperparams()
  
  # Get dataset hyperparameters
  logger.info('Using dataset: {}'.format(FLAGS.dataset))
  
  # Dataset
  dataset_size_test  = conf.get_dataset_size_test(FLAGS.dataset)
  num_classes        = conf.get_num_classes(FLAGS.dataset)
  create_inputs_test = conf.get_create_inputs(FLAGS.dataset, mode="test")

  
  #----------------------------------------------------------------------------
  # GRAPH - TEST
  #----------------------------------------------------------------------------
  logger.info('BUILD TEST GRAPH')
  g_test = tf.Graph()
  with g_test.as_default():
    # Get global_step
    global_step = tf.train.get_or_create_global_step()

    num_batches_test = int(dataset_size_test / FLAGS.batch_size)

    # Get data
    input_dict = create_inputs_test()
    batch_x = input_dict['image']
    batch_labels = input_dict['label']
    
    # AG 10/12/2018: Split batch for multi gpu implementation
    # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
    # See: https://github.com/naturomics/CapsNet-
    # Tensorflow/blob/master/dist_version/distributed_train.py
    splits_x = tf.split(
        axis=0, 
        num_or_size_splits=FLAGS.num_gpus, 
        value=batch_x)
    splits_labels = tf.split(
        axis=0, 
        num_or_size_splits=FLAGS.num_gpus, 
        value=batch_labels)
    
    # Build architecture
    build_arch = conf.get_dataset_architecture(FLAGS.dataset)
    # for baseline
    #build_arch = conf.get_dataset_architecture('baseline')
    
    #--------------------------------------------------------------------------
    # MULTI GPU - TEST
    #--------------------------------------------------------------------------
    # Calculate the logits for each model tower
    tower_logits = []
    reuse_variables = None
    for i in range(FLAGS.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('tower_%d' % i) as scope:
          with slim.arg_scope([slim.variable], device='/cpu:0'):
            loss, logits = tower_fn(
                build_arch, 
                splits_x[i], 
                splits_labels[i], 
                scope, 
                num_classes, 
                reuse_variables=reuse_variables, 
                is_train=False)

          # Don't reuse variable for first GPU, but do reuse for others
          reuse_variables = True
          
          # Keep track of logits for each tower
          tower_logits.append(logits)
          
          # Loss for each tower
          tf.summary.histogram("test_logits", logits)
    
    # Combine logits from all towers
    logits = tf.concat(tower_logits, axis=0)
    
    # Calculate metrics
    test_loss = mod.spread_loss(logits, batch_labels)
    test_acc = met.accuracy(logits, batch_labels)
    
    # Prepare predictions and one-hot labels
    test_probs = tf.nn.softmax(logits=logits)
    test_labels_oh = tf.one_hot(batch_labels, num_classes)
    
    # Group metrics together
    # See: https://cs230-stanford.github.io/tensorflow-model.html
    test_metrics = {'loss' : test_loss,
                   'labels' : batch_labels, 
                   'labels_oh' : test_labels_oh,
                   'logits' : logits,
                   'probs' : test_probs,
                   'acc' : test_acc,
                   }
    
    # Reset and read operations for streaming metrics go here
    test_reset = {}
    test_read = {}
    
    tf.summary.scalar("test_loss", test_loss)
    tf.summary.scalar("test_acc", test_acc)
      
    # Saver
    saver = tf.train.Saver(max_to_keep=None)
    
    # Set summary op
    test_summary = tf.summary.merge_all()
    
  
    #--------------------------------------------------------------------------
    # SESSION - TEST
    #--------------------------------------------------------------------------
    #sess_test = tf.Session(
    #    config=tf.ConfigProto(allow_soft_placement=True, 
    #                          log_device_placement=False), 
    #    graph=g_test)
    # Perry: added in for RTX 2070 incompatibility workaround
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.allow_growth = True
    sess_test = tf.Session(config=config, graph=g_test)

   
    
    #sess_test.run(tf.local_variables_initializer())
    #sess_test.run(tf.global_variables_initializer())

    summary_writer = tf.summary.FileWriter(
        test_summary_dir, 
        graph=sess_test.graph)


    ckpts_to_test = []
    load_dir_chechpoint = os.path.join(FLAGS.load_dir, "train", "checkpoint")
    
    # Evaluate the latest ckpt in dir
    if FLAGS.ckpt_name is None:
      latest_ckpt = tf.train.latest_checkpoint(load_dir_chechpoint)
      ckpts_to_test.append(latest_ckpt)

    # Evaluate all ckpts in dir  
    elif FLAGS.ckpt_name == "all":
      # Get list of files in directory and sort by date created
      filenames = os.listdir(load_dir_chechpoint)
      regex = re.compile(r'.*\.index')
      filenames = filter(regex.search, filenames)
      data_ckpts = (os.path.join(load_dir_chechpoint, fn) for fn in filenames)
      data_ckpts = ((os.stat(path), path) for path in data_ckpts)

      # regular files, insert creation date
      data_ckpts = ((stat[ST_CTIME], path) for stat, path in data_ckpts 
                    if S_ISREG(stat[ST_MODE]))
      data_ckpts = sorted(data_ckpts)
      # remove ".index"
      ckpts_to_test = [path[:-6] for ctime, path in data_ckpts]
        
    # Evaluate ckpt specified by name
    else:
      ckpt_name = os.path.join(load_dir_chechpoint, FLAGS.ckpt_name)
      ckpts_to_test.append(ckpt_name)    
      
      
    #--------------------------------------------------------------------------
    # MAIN LOOP
    #--------------------------------------------------------------------------
    # Run testing on checkpoints
    for ckpt in ckpts_to_test:
      saver.restore(sess_test, ckpt)
          
      # Reset accumulators
      accuracy_sum = 0
      loss_sum = 0
      sess_test.run(test_reset)

      for i in range(num_batches_test):
        
        test_metrics_v, test_summary_str_v = sess_test.run(
            [test_metrics, test_summary])
        
        # Update
        accuracy_sum += test_metrics_v['acc']
        loss_sum += test_metrics_v['loss']

        ckpt_num = re.split('-', ckpt)[-1]
        logger.info('TEST ckpt-{}'.format(ckpt_num) 
              + ' bch-{:d}'.format(i) 
              + ' cum_acc: {:.2f}%'.format(accuracy_sum/(i+1)*100) 
              + ' cum_loss: {:.4f}'.format(loss_sum/(i+1)) 
               )

      ave_acc = accuracy_sum / num_batches_test
      ave_loss = loss_sum / num_batches_test
  
      logger.info('TEST ckpt-{}'.format(ckpt_num) 
            + ' avg_acc: {:.2f}%'.format(ave_acc*100) 
            + ' avg_loss: {:.4f}'.format(ave_loss))

      logger.info("Write Test Summary")
      summary_test = tf.Summary()
      summary_test.value.add(tag="test_acc", simple_value=ave_acc)
      summary_test.value.add(tag="test_loss", simple_value=ave_loss)
      summary_writer.add_summary(summary_test, ckpt_num)
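
met.accuracy above is an external helper; a plausible minimal sketch, assuming it compares the argmax of the logits against the integer labels and returns a scalar accuracy tensor:

def accuracy_sketch(logits, labels):
  """Fraction of examples whose argmax prediction matches the label."""
  predictions = tf.argmax(logits, axis=-1, output_type=tf.int64)
  correct = tf.equal(predictions, tf.cast(labels, tf.int64))
  return tf.reduce_mean(tf.cast(correct, tf.float32))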
Example #5
def main(args):
    """Run training and validation.
  
  1. Build graphs
      1.1 Training graph to run on multiple GPUs
      1.2 Validation graph to run on multiple GPUs
  2. Configure sessions
      2.1 Train
      2.2 Validate
  3. Main loop
      3.1 Train
      3.2 Write summary
      3.3 Save model
      3.4 Validate model
      
  Author:
    Ashley Gritzman
  """

    # Set reproducible random seed
    tf.set_random_seed(1234)

    # Directories
    train_dir, train_summary_dir = conf.setup_train_directories()

    # Logger
    conf.setup_logger(logger_dir=train_dir, name="logger_train.txt")

    # Hyperparameters
    conf.load_or_save_hyperparams(train_dir)

    # Get dataset hyperparameters
    logger.info('Using dataset: {}'.format(FLAGS.dataset))
    dataset_size_train = conf.get_dataset_size_train(FLAGS.dataset)
    dataset_size_val = conf.get_dataset_size_validate(FLAGS.dataset)
    build_arch = conf.get_dataset_architecture(FLAGS.dataset)
    num_classes = conf.get_num_classes(FLAGS.dataset)
    create_inputs_train = conf.get_create_inputs(FLAGS.dataset, mode="train")
    create_inputs_val = conf.get_create_inputs(FLAGS.dataset, mode="validate")

    #*****************************************************************************
    # 1. BUILD GRAPHS
    #*****************************************************************************

    #----------------------------------------------------------------------------
    # GRAPH - TRAIN
    #----------------------------------------------------------------------------
    logger.info('BUILD TRAIN GRAPH')
    g_train = tf.Graph()
    with g_train.as_default(), tf.device('/cpu:0'):

        # Get global_step
        global_step = tf.train.get_or_create_global_step()

        # Get batches per epoch
        num_batches_per_epoch = int(dataset_size_train / FLAGS.batch_size)

        # In response to a question on OpenReview, Hinton et al. wrote the
        # following:
        # "We use an exponential decay with learning rate: 3e-3,
        # decay_steps: 20000, decay rate: 0.96."
        # https://openreview.net/forum?id=HJWLfGWRb&noteId=ryxTPFDe2X
        lrn_rate = tf.train.exponential_decay(learning_rate=FLAGS.lrn_rate,
                                              global_step=global_step,
                                              decay_steps=20000,
                                              decay_rate=0.96)
        tf.summary.scalar('learning_rate', lrn_rate)
        opt = tf.train.AdamOptimizer(learning_rate=lrn_rate)

        # Get batch from data queue. Batch size is FLAGS.batch_size, which is then
        # divided across multiple GPUs
        input_dict = create_inputs_train()
        batch_x = input_dict['image']
        batch_labels = input_dict['label']

        # AG 03/10/2018: Split batch for multi gpu implementation
        # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
        # See: https://github.com/naturomics/CapsNet-Tensorflow/blob/master/
        # dist_version/distributed_train.py
        splits_x = tf.split(axis=0,
                            num_or_size_splits=FLAGS.num_gpus,
                            value=batch_x)
        splits_labels = tf.split(axis=0,
                                 num_or_size_splits=FLAGS.num_gpus,
                                 value=batch_labels)

        #--------------------------------------------------------------------------
        # MULTI GPU - TRAIN
        #--------------------------------------------------------------------------
        # Calculate the gradients for each model tower
        tower_grads = []
        tower_losses = []
        tower_logits = []
        reuse_variables = None
        for i in range(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i) as scope:
                    logger.info('TOWER %d' % i)
                    #with slim.arg_scope([slim.model_variable, slim.variable],
                    # device='/cpu:0'):
                    with slim.arg_scope([slim.variable], device='/cpu:0'):
                        loss, logits = tower_fn(
                            build_arch,
                            splits_x[i],
                            splits_labels[i],
                            scope,
                            num_classes,
                            reuse_variables=reuse_variables,
                            is_train=True)

                    # Don't reuse variable for first GPU, but do reuse for others
                    reuse_variables = True

                    # Compute gradients for one GPU
                    grads = opt.compute_gradients(loss)

                    # Keep track of the gradients across all towers.
                    tower_grads.append(grads)

                    # Keep track of losses and logits for each tower
                    tower_logits.append(logits)
                    tower_losses.append(loss)

                    # Loss for each tower
                    tf.summary.scalar("loss", loss)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grad = average_gradients(tower_grads)

        # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-
        # gradients-in-tensorflow-when-updating
        grad_check = ([
            tf.check_numerics(g, message='Gradient NaN Found!')
            for g, _ in grad if g is not None
        ] + [tf.check_numerics(loss, message='Loss NaN Found')])

        # Apply the gradients to adjust the shared variables
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)

        # Calculate mean loss
        loss = tf.reduce_mean(tower_losses)

        # Calculate accuracy
        logits = tf.concat(tower_logits, axis=0)
        acc = met.accuracy(logits, batch_labels)

        # Prepare predictions and one-hot labels
        probs = tf.nn.softmax(logits=logits)
        labels_oh = tf.one_hot(batch_labels, num_classes)

        # Group metrics together
        # See: https://cs230-stanford.github.io/tensorflow-model.html
        trn_metrics = {
            'loss': loss,
            'labels': batch_labels,
            'labels_oh': labels_oh,
            'logits': logits,
            'probs': probs,
            'acc': acc,
        }

        # Reset and read operations for streaming metrics go here
        trn_reset = {}
        trn_read = {}

        # Logging
        tf.summary.scalar('trn_loss', loss)
        tf.summary.scalar('trn_acc', acc)

        # Set Saver
        # AG 26/09/2018: Save all variables including Adam so that we can continue
        # training from where we left off
        # max_to_keep=None should keep all checkpoints
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

        # Display number of parameters
        train_params = np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]).astype(np.int32)
        logger.info('Trainable Parameters: {}'.format(train_params))

        # Set summary op
        trn_summary = tf.summary.merge_all()

    #----------------------------------------------------------------------------
    # GRAPH - VALIDATION
    #----------------------------------------------------------------------------
    logger.info('BUILD VALIDATION GRAPH')
    g_val = tf.Graph()
    with g_val.as_default():
        # Get global_step
        global_step = tf.train.get_or_create_global_step()

        num_batches_val = int(dataset_size_val / FLAGS.batch_size *
                              FLAGS.val_prop)

        # Get data
        input_dict = create_inputs_val()
        batch_x = input_dict['image']
        batch_labels = input_dict['label']

        # AG 10/12/2018: Split batch for multi gpu implementation
        # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
        # See: https://github.com/naturomics/CapsNet-
        # Tensorflow/blob/master/dist_version/distributed_train.py
        splits_x = tf.split(axis=0,
                            num_or_size_splits=FLAGS.num_gpus,
                            value=batch_x)
        splits_labels = tf.split(axis=0,
                                 num_or_size_splits=FLAGS.num_gpus,
                                 value=batch_labels)

        #--------------------------------------------------------------------------
        # MULTI GPU - VALIDATE
        #--------------------------------------------------------------------------
        # Calculate the logits for each model tower
        tower_logits = []
        reuse_variables = None
        for i in range(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i) as scope:
                    with slim.arg_scope([slim.variable], device='/cpu:0'):
                        loss, logits = tower_fn(
                            build_arch,
                            splits_x[i],
                            splits_labels[i],
                            scope,
                            num_classes,
                            reuse_variables=reuse_variables,
                            is_train=False)

                    # Don't reuse variable for first GPU, but do reuse for others
                    reuse_variables = True

                    # Keep track of logits for each tower
                    tower_logits.append(logits)

                    # Loss for each tower
                    tf.summary.histogram("val_logits", logits)

        # Combine logits from all towers
        logits = tf.concat(tower_logits, axis=0)

        # Calculate metrics
        val_loss = mod.spread_loss(logits, batch_labels)
        val_acc = met.accuracy(logits, batch_labels)

        # Prepare predictions and one-hot labels
        val_probs = tf.nn.softmax(logits=logits)
        val_labels_oh = tf.one_hot(batch_labels, num_classes)

        # Group metrics together
        # See: https://cs230-stanford.github.io/tensorflow-model.html
        val_metrics = {
            'loss': val_loss,
            'labels': batch_labels,
            'labels_oh': val_labels_oh,
            'logits': logits,
            'probs': val_probs,
            'acc': val_acc,
        }

        # Reset and read operations for streaming metrics go here
        val_reset = {}
        val_read = {}

        tf.summary.scalar("val_loss", val_loss)
        tf.summary.scalar("val_acc", val_acc)

        # Saver
        saver = tf.train.Saver(max_to_keep=None)

        # Set summary op
        val_summary = tf.summary.merge_all()

    #****************************************************************************
    # 2. SESSIONS
    #****************************************************************************

    #----- SESSION TRAIN -----#
    # Session settings
    sess_train = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                  log_device_placement=False),
                            graph=g_train)

    # Debugger
    # AG 05/06/2018: Debugging using either command line or TensorBoard
    if FLAGS.debugger is not None:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        sess_train = tf_debug.TensorBoardDebugWrapperSession(
            sess_train, FLAGS.debugger)

    with g_train.as_default():
        sess_train.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        # Restore previous checkpoint
        # AG 26/09/2018: where should this go???
        if FLAGS.load_dir is not None:
            prev_step = load_training(saver, sess_train, FLAGS.load_dir)
        else:
            prev_step = 0

    # Create summary writer, and write the train graph
    summary_writer = tf.summary.FileWriter(train_summary_dir,
                                           graph=sess_train.graph)

    #----- SESSION VALIDATION -----#
    sess_val = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False),
                          graph=g_val)
    with g_val.as_default():
        sess_val.run([
            tf.local_variables_initializer(),
            tf.global_variables_initializer()
        ])

    #****************************************************************************
    # 3. MAIN LOOP
    #****************************************************************************
    SUMMARY_FREQ = 100
    SAVE_MODEL_FREQ = num_batches_per_epoch  # 500
    VAL_FREQ = num_batches_per_epoch  # 500
    PROFILE_FREQ = 5

    for step in range(prev_step, FLAGS.epoch * num_batches_per_epoch + 1):
        #for step in range(0,3):
        # AG 23/05/2018: limit number of iterations for testing
        # for step in range(100):
        epoch_decimal = step / num_batches_per_epoch
        epoch = int(np.floor(epoch_decimal))

        # TF queue would pop batch until no file
        try:
            # TRAIN
            with g_train.as_default():

                # With profiling
                if (FLAGS.profile is True) and ((step % PROFILE_FREQ) == 0):
                    logger.info("Train with Profiling")
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                # Without profiling
                else:
                    run_options = None
                    run_metadata = None

                # Reset streaming metrics
                if step % (num_batches_per_epoch / 4) == 1:
                    logger.info("Reset streaming metrics")
                    sess_train.run([trn_reset])

                # MAIN RUN
                tic = time.time()
                train_op_v, trn_metrics_v, trn_summary_v = sess_train.run(
                    [train_op, trn_metrics, trn_summary],
                    options=run_options,
                    run_metadata=run_metadata)
                toc = time.time()

                # Read streaming metrics
                trn_read_v = sess_train.run(trn_read)

                # Write summary for profiling
                if run_options is not None:
                    summary_writer.add_run_metadata(run_metadata,
                                                    'step{:d}'.format(step))

                # Logging
                logger.info('TRN' + ' e-{:d}'.format(epoch) +
                            ' stp-{:d}'.format(step) +
                            ' {:.2f}s'.format(toc - tic) +
                            ' loss: {:.4f}'.format(trn_metrics_v['loss']) +
                            ' acc: {:.2f}%'.format(trn_metrics_v['acc'] * 100))

        except KeyboardInterrupt:
            sess_train.close()
            sess_val.close()
            sys.exit()

        except tf.errors.InvalidArgumentError as e:
            logger.warning('Iteration %d contains NaN gradients; discarding.' %
                           step)
            logger.error(str(e))
            continue

        else:
            # WRITE SUMMARY
            if (step % SUMMARY_FREQ) == 0:
                logger.info("Write Train Summary")
                with g_train.as_default():
                    # Summaries from graph
                    summary_writer.add_summary(trn_summary_v, step)

            # SAVE MODEL
            if (step % SAVE_MODEL_FREQ) == 100:
                logger.info("Save Model")
                with g_train.as_default():
                    train_checkpoint_dir = train_dir + '/checkpoint'
                    if not os.path.exists(train_checkpoint_dir):
                        os.makedirs(train_checkpoint_dir)

                    # Save ckpt from train session
                    ckpt_path = os.path.join(train_checkpoint_dir,
                                             'model.ckpt')
                    saver.save(sess_train, ckpt_path, global_step=step)

            # VALIDATE MODEL
            if (step % VAL_FREQ) == 100:
                #----- Validation -----#
                with g_val.as_default():
                    logger.info("Start Validation")

                    # Restore ckpt to val session
                    latest_ckpt = tf.train.latest_checkpoint(
                        train_checkpoint_dir)
                    saver.restore(sess_val, latest_ckpt)

                    # Reset accumulators
                    accuracy_sum = 0
                    loss_sum = 0
                    sess_val.run(val_reset)

                    for i in range(num_batches_val):
                        val_metrics_v, val_summary_str_v = sess_val.run(
                            [val_metrics, val_summary])

                        # Update
                        accuracy_sum += val_metrics_v['acc']
                        loss_sum += val_metrics_v['loss']

                        # Read
                        val_read_v = sess_val.run(val_read)

                        # Get checkpoint number
                        ckpt_num = re.split('-', latest_ckpt)[-1]

                        # Logging
                        logger.info('VAL ckpt-{}'.format(ckpt_num) +
                                    ' bch-{:d}'.format(i) +
                                    ' cum_acc: {:.2f}%'.format(accuracy_sum /
                                                               (i + 1) * 100) +
                                    ' cum_loss: {:.4f}'.format(loss_sum /
                                                               (i + 1)))

                    # Average across batches
                    ave_acc = accuracy_sum / num_batches_val
                    ave_loss = loss_sum / num_batches_val

                    logger.info('VAL ckpt-{}'.format(ckpt_num) +
                                ' avg_acc: {:.2f}%'.format(ave_acc * 100) +
                                ' avg_loss: {:.4f}'.format(ave_loss))

                    logger.info("Write Val Summary")
                    summary_val = tf.Summary()
                    summary_val.value.add(tag="val_acc", simple_value=ave_acc)
                    summary_val.value.add(tag="val_loss",
                                          simple_value=ave_loss)
                    summary_writer.add_summary(summary_val, step)

    # Close (main loop)
    sess_train.close()
    sess_val.close()
    sys.exit()
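
average_gradients is the usual multi-tower synchronization helper referenced in the training graph above; a minimal sketch, assuming tower_grads is a list (one entry per GPU) of (gradient, variable) pairs as returned by opt.compute_gradients:

def average_gradients_sketch(tower_grads):
    """Average each variable's gradient across towers (the sync point)."""
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars is ((grad_gpu0, var), (grad_gpu1, var), ...)
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars if g is not None]
        if not grads:
            average_grads.append((None, grad_and_vars[0][1]))
            continue
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads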
Example #6
    for t in vocabulary:
        if collection_tfs[t] == 0:
            collection_tfs[t] = 1
        prob_t_condit_D = collection_tfs[t] / collection_total_terms
        prob_t_condit_Dq = get_prob_t_condition_Dq(t)
        clt += prob_t_condit_Dq * log(prob_t_condit_Dq / prob_t_condit_D)
    return clt


if __name__ == '__main__':
    c = config.get_paths()
    index_path = c[sys.argv[1]]
    query_file_path = sys.argv[2]
    save_path = sys.argv[3]

    config.setup_logger('querydifficulty')

    ix = index.open_dir(index_path, readonly=True)
    LOGGER.info('Index path: ' + index_path)
    ix_reader = ix.reader()

    vocabulary = []
    db_tfs = defaultdict(int)
    db_total_terms = 0
    with open(c['db_tfs'], 'r') as fr:
        for line in fr:
            parts = line.rsplit(',', 1)
            db_tfs[parts[0]] = int(parts[1])
            db_total_terms += int(parts[1])
            vocabulary.append(parts[0])
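
The loop at the top of this example accumulates a query-clarity style score; a compact sketch of the same computation as a standalone function, assuming per-term collection counts and a query language model probability function are supplied:

from math import log

def clarity_score_sketch(vocabulary, collection_tfs, collection_total_terms,
                         prob_t_given_query):
    """Sum over terms of P(t|Dq) * log(P(t|Dq) / P(t|D))."""
    clt = 0.0
    for t in vocabulary:
        p_t_d = max(collection_tfs[t], 1) / collection_total_terms
        p_t_dq = prob_t_given_query(t)
        if p_t_dq > 0:                      # skip terms absent from the query model
            clt += p_t_dq * log(p_t_dq / p_t_d)
    return clt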
Example #7
def reply(bot, update):
    chat_id = update.message.chat_id
    message = update.message.text

    logger.info(f'Received message from ID {chat_id} - Message: {message}')

    answer = dialogflow_api.get_answer(GOOGLE_PROJECT_ID, chat_id, message,
                                       'ru')
    update.message.reply_text(answer)

    logger.info(f'Message sent to ID {chat_id} - Message: {answer}')


if __name__ == "__main__":
    load_dotenv()
    TELEGRAM_BOT_TOKEN = os.getenv('TELEGRAM_BOT_TOKEN')
    GOOGLE_PROJECT_ID = os.getenv('GOOGLE_PROJECT_ID')

    config.setup_logger(logger)
    logger.info('Bot is up and running.')

    updater = Updater(TELEGRAM_BOT_TOKEN)

    dp = updater.dispatcher
    dp.add_handler(CommandHandler('start', start))
    dp.add_handler(CommandHandler('help', reply_to_help))
    dp.add_handler(MessageHandler(Filters.text, reply))

    updater.start_polling()
    updater.idle()
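
dialogflow_api.get_answer is not shown; a minimal sketch of what such a helper could look like with the google-cloud-dialogflow v2 client, assuming one Dialogflow session per Telegram chat_id:

from google.cloud import dialogflow

def get_answer_sketch(project_id, session_id, text, language_code):
    """Send one text turn to Dialogflow and return the fulfillment text."""
    client = dialogflow.SessionsClient()
    session = client.session_path(project_id, str(session_id))
    query_input = dialogflow.QueryInput(
        text=dialogflow.TextInput(text=text, language_code=language_code))
    response = client.detect_intent(
        request={"session": session, "query_input": query_input})
    return response.query_result.fulfillment_text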
Example #8
__url__ = "https://www.gnu.org/licenses/agpl-3.0.en.html"
__author__ = "Ludee; christian-rli"
__issue__ = "https://github.com/OpenEnergyPlatform/examples/issues/52"
__version__ = "v0.5.0"

from config import setup_logger  # TODO: doesn't work as desired yet. Please fix! Low Priority!
from mastr_wind_download import download_power_unit, download_unit_wind, download_unit_wind_eeg
from mastr_wind_process import make_wind

import time

"""version"""
DATA_VERSION = '0.8'

"""logging"""
log = setup_logger()
start_time = time.time()
log.info(f'Script started with data version: {DATA_VERSION}.')

"""OEP"""
# metadata = oep_session()

"""MaStR Wind"""
# download_power_unit()
# download_unit_wind()
# download_unit_wind_eeg()
make_wind()

"""close"""
log.info('...Script successfully executed in {:.2f} seconds.'
         .format(time.time() - start_time))
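
The setup_logger variants on this page differ in signature; a minimal sketch of the zero-argument form used in this example, assuming it returns a logger that writes to both console and a file (names and format are assumptions):

import logging

def setup_logger_sketch(name="mastr", logfile="mastr.log"):
    """Return a logger with console and file handlers, configured only once."""
    logger = logging.getLogger(name)
    if not logger.handlers:                  # avoid adding duplicate handlers
        logger.setLevel(logging.INFO)
        fmt = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
        for handler in (logging.StreamHandler(), logging.FileHandler(logfile)):
            handler.setFormatter(fmt)
            logger.addHandler(handler)
    return logger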
Example #9
        fc += 1
        if fc % 1000 == 0:
            print(fc, dpath)
        did = config.get_article_id_from_file_name(dname)
        with open(dpath, 'r', encoding='utf-8') as fo:
            dcont = fo.read()
        try:
            if did not in wiki13_title_count:
                raise LookupError(
                    'Filename \'{}\' not in title-count list'.format(did))
            writer.add_document(title=wiki13_title_count[did]['title'],
                                articleID=did,
                                body=dcont,
                                count=wiki13_title_count[did]['count'],
                                xpath=dpath)
        except Exception as e:
            LOGGER.error(dpath + '  ' + str(e) + '\n')

    writer.commit()
    return


if __name__ == '__main__':
    c = config.get_paths()
    if len(sys.argv) >= 4:
        config.setup_logger('{}-{}-{}_build'.format(sys.argv[1], sys.argv[2],
                                                    sys.argv[3]))
        build_index_wiki13(c[sys.argv[1]], c[sys.argv[2]], c[sys.argv[3]])
    else:
        print('dir_alias, index_alias, count_alias is required!')
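
The writer used above comes from a Whoosh index; a minimal sketch of how such an index and writer could be created, with field names matching the add_document call (the exact schema is an assumption):

import os
from whoosh import index
from whoosh.fields import ID, NUMERIC, STORED, TEXT, Schema

def open_writer_sketch(index_dir):
    """Create or open a Whoosh index whose fields match add_document above."""
    schema = Schema(title=TEXT(stored=True),
                    articleID=ID(stored=True, unique=True),
                    body=TEXT,
                    count=NUMERIC(stored=True),
                    xpath=STORED)
    if not index.exists_in(index_dir):
        os.makedirs(index_dir, exist_ok=True)
        ix = index.create_in(index_dir, schema)
    else:
        ix = index.open_dir(index_dir)
    return ix.writer()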
Example #10
from http import HTTPStatus
from flask import Flask, jsonify, request

import config
import constants
import decorators

app = Flask(__name__)

logger = config.setup_logger()


@decorators.log(logger, constants.SQUARE)
def square(number):
    value = int(number)

    return value**2


def get_message(name):
    return "Hello " + name


@app.route("/", methods=["POST"])
def message():
    data = request.json

    try:
        logger.debug('Data received: {}'.format(data))
        message = get_message(data['name'])
        logger.debug('Message: {}'.format(message))
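
decorators.log is not shown; a minimal sketch of a decorator with that call signature (a logger plus an operation tag), assuming it logs the wrapped function's arguments and result:

import functools

def log_sketch(logger, operation):
    """Log inputs and outputs of the decorated function under an operation tag."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            logger.debug("%s called with args=%s kwargs=%s", operation, args, kwargs)
            result = func(*args, **kwargs)
            logger.debug("%s returned %s", operation, result)
            return result
        return wrapper
    return decorator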
Example #11
def main():
    ''' Main '''

    # no. of samples added and number of samples being collected
    ROWS_ADDED = 0
    NSAMPLES = 1000

    # setup socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    HOST = '127.0.0.1'
    PORT = 5556
    sock.bind((HOST, PORT))
    sock.listen(1)

    landmark_list = landmarkList_pb2.LandmarkList()
    setup_logger()

    logging.info("Waiting for keypoint generator..")
    # Establish connection
    conn, addr = sock.accept()

    actual_hand = int(
        input("Enter the hand for which you are collecting \
    gesture data:\n0) left \t1) right\n"))

    gesture = input(
        "Enter the name of the gesture for which you are capturing data, \
    (a simple one word description of the orientation of your hand) :\n")

    f = open("gestop/data/static_gestures_data.csv", 'a+')
    #set pointer at beginning of file
    f.seek(0)

    # The string which is written to the dataset
    DATASET_STR = ''

    # If the file is empty, add the headers at the top of the file
    if f.read() == '':
        DATASET_STR += dataset_headers()

    while ROWS_ADDED < NSAMPLES:
        data = conn.recv(4096)

        try:
            landmark_list.ParseFromString(data)
        except google.protobuf.message.DecodeError:  # Incorrect data format
            continue
        landmarks = []
        for lmark in landmark_list.landmark:
            landmarks.append({'x': lmark.x, 'y': lmark.y, 'z': lmark.z})

        # Handedness - true if right hand, false if left
        handedness = landmark_list.handedness

        # Add a row to the dataset
        row_str, ROWS_ADDED = add_row(landmarks, handedness, gesture,
                                      actual_hand, ROWS_ADDED)
        DATASET_STR += row_str

        ROWS_ADDED += 1
        #simple loading bar
        print(str(ROWS_ADDED) + '/' + str(NSAMPLES) + '\t|' + ('-' * int(
            (50 * ROWS_ADDED) / NSAMPLES)) + '>',
              end='\r')

    conn.close()
    sock.close()

    # Writing data to file at once instead of in for loop for performance reasons.
    f.write(DATASET_STR)
    f.close()
    logging.info("1000 rows of data has been successfully collected.")
Example #12
#!/usr/bin/env python3
# -*- coding: utf-8 -*-  

import os
import time

from Logger import Logger
from Setup import ConnectionSetup
from config import settings, setup_logger

log = setup_logger()
lockpath = os.path.join(settings.locks, settings.lockname)


def check_dirs_and_files():
    # log
    if not os.path.exists(settings.logs):
        os.mkdir(settings.logs, 0o000755)
    if not os.path.exists(os.path.join(settings.logs, settings.exceptionlog)):
        file = open(os.path.join(settings.logs, settings.exceptionlog), 'w')
        file.write("<exceptions></exceptions>")
        file.close()
    # lock
    if not os.path.exists(settings.locks):
        os.mkdir(settings.locks, 0o000755)
    # records
    if not os.path.exists(settings.records):
        os.mkdir(settings.records, 0o000755)


def obtain_lock():
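
A minimal sketch of a lock-file style obtain_lock built on the lockpath defined at module level; the exact semantics (fail if another instance already holds the lock) are an assumption:

def obtain_lock_sketch():
    """Create the lock file exclusively; return False if another run holds it."""
    try:
        fd = os.open(lockpath, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o000644)
    except FileExistsError:
        log.warning("Lock %s exists, another instance may be running.", lockpath)
        return False
    with os.fdopen(fd, 'w') as lockfile:
        lockfile.write(str(os.getpid()))
    return True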
Example #13
def main(args):
    # Set reproducible random seed
    tf.set_random_seed(1234)

    # Directories
    # Get name
    split = FLAGS.load_dir.split('/')
    if split[-1]:
        name = split[-1]
    else:
        name = split[-2]

    # Get parent directory
    split = FLAGS.load_dir.split("/" + name)
    parent_dir = split[0]

    test_dir = '{}/{}/reconstructions'.format(parent_dir, name)
    test_summary_dir = test_dir + '/summary'

    # Clear the test log directory
    if (FLAGS.reset is True) and os.path.exists(test_dir):
        shutil.rmtree(test_dir)
    if not os.path.exists(test_summary_dir):
        os.makedirs(test_summary_dir)

    # Logger
    conf.setup_logger(logger_dir=test_dir, name="logger_test.txt")
    logger.info("name: " + name)
    logger.info("parent_dir: " + parent_dir)
    logger.info("test_dir: " + test_dir)

    # Load hyperparameters from train run
    conf.load_or_save_hyperparams()

    # Get dataset hyperparameters
    logger.info('Using dataset: {}'.format(FLAGS.dataset))

    # Dataset
    dataset_size_test = conf.get_dataset_size_test(FLAGS.dataset)
    num_classes = conf.get_num_classes(FLAGS.dataset)
    # train mode for random sampling
    create_inputs_test = conf.get_create_inputs(FLAGS.dataset, mode="train")

    # ----------------------------------------------------------------------------
    # GRAPH - TEST
    # ----------------------------------------------------------------------------
    logger.info('BUILD TEST GRAPH')
    g_test = tf.Graph()
    with g_test.as_default():
        tf.train.get_or_create_global_step()
        # Get data
        input_dict = create_inputs_test()
        batch_x = input_dict['image']
        batch_labels = input_dict['label']

        # Build architecture
        build_arch = conf.get_dataset_architecture(FLAGS.dataset)
        # for baseline
        # build_arch = conf.get_dataset_architecture('baseline')

        # --------------------------------------------------------------------------
        # MULTI GPU - TEST
        # --------------------------------------------------------------------------
        # Calculate the logits for each model tower
        with tf.device('/gpu:0'):
            with tf.name_scope('tower_0') as scope:
                with slim.arg_scope([slim.variable], device='/cpu:0'):
                    loss, logits, recon, cf_recon = tower_fn(
                        build_arch,
                        batch_x,
                        batch_labels,
                        scope,
                        num_classes,
                        reuse_variables=tf.AUTO_REUSE,
                        is_train=False)

                # Collect the input and reconstructed images for this tower
                recon_images = tf.reshape(recon, batch_x.get_shape())
                cf_recon_images = tf.reshape(cf_recon, batch_x.get_shape())
                images = {
                    "reconstructed_images": recon_images,
                    "reconstructed_images_zeroed_background": cf_recon_images,
                    "input": batch_x
                }
        saver = tf.train.Saver(max_to_keep=None)

        # --------------------------------------------------------------------------
        # SESSION - TEST
        # --------------------------------------------------------------------------
        # sess_test = tf.Session(
        #    config=tf.ConfigProto(allow_soft_placement=True,
        #                          log_device_placement=False),
        #    graph=g_test)
        # Perry: added as a workaround for an RTX 2070 incompatibility
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        config.gpu_options.allow_growth = True
        sess_test = tf.Session(config=config, graph=g_test)

        # sess_test.run(tf.local_variables_initializer())
        # sess_test.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(test_summary_dir,
                                               graph=sess_test.graph)

        ckpts_to_test = []
        load_dir_checkpoint = os.path.join(FLAGS.load_dir, "train",
                                           "checkpoint")

        # Evaluate the latest ckpt in dir
        if FLAGS.ckpt_name is None:
            latest_ckpt = tf.train.latest_checkpoint(load_dir_checkpoint)
            ckpts_to_test.append(latest_ckpt)
        # Evaluate a specific ckpt in dir
        else:
            ckpt_name = os.path.join(load_dir_checkpoint, FLAGS.ckpt_name)
            ckpts_to_test.append(ckpt_name)

        # --------------------------------------------------------------------------
        # MAIN LOOP
        # --------------------------------------------------------------------------
        # Run testing on checkpoints
        for ckpt in ckpts_to_test:
            saver.restore(sess_test, ckpt)

            for i in range(dataset_size_test):
                out = sess_test.run([images])
                reconstructed_image, reconstructed_image_zeroed_background, input_img =\
                    out[0]["reconstructed_images"], out[0]["reconstructed_images_zeroed_background"], out[0]["input"]
                if reconstructed_image.shape[0] == 1:
                    reconstructed_image = np.squeeze(reconstructed_image,
                                                     axis=0)
                    reconstructed_image_zeroed_background = np.squeeze(
                        reconstructed_image_zeroed_background, axis=0)
                    input_img = np.squeeze(input_img, axis=0)
                if reconstructed_image.shape[-1] == 1:
                    reconstructed_image = np.squeeze(reconstructed_image,
                                                     axis=-1)
                    reconstructed_image_zeroed_background = np.squeeze(
                        reconstructed_image_zeroed_background, axis=-1)
                    input_img = np.squeeze(input_img, axis=-1)
                reconstructed_image = Image.fromarray(
                    (reconstructed_image * 255).astype('uint8'))
                reconstructed_image_zeroed_background = Image.fromarray(
                    (reconstructed_image_zeroed_background *
                     255).astype('uint8'))
                input_img = Image.fromarray((input_img * 255).astype('uint8'))
                fig = plt.figure(figsize=(1, 3))
                fig.add_subplot(1, 3, 1)
                plt.imshow(input_img)
                fig.add_subplot(1, 3, 2)
                plt.imshow(reconstructed_image)
                fig.add_subplot(1, 3, 3)
                plt.imshow(reconstructed_image_zeroed_background)
                plt.show()
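
# The loop above displays each input / reconstruction / zeroed-background triplet
# interactively with plt.show(). A non-interactive alternative (a sketch, not part
# of the original example) would save each comparison figure under the
# reconstructions directory instead, e.g. save_comparison(fig, test_dir, i):
import os
import matplotlib.pyplot as plt


def save_comparison(fig, out_dir, index):
    # Write the comparison figure to disk and release its memory.
    path = os.path.join(out_dir, 'reconstruction_{:05d}.png'.format(index))
    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)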
    with ix.reader() as ix_reader:
        pa = pt.Partitioner(ix, ix_reader)
        print('Partitioner initiated!')
        parts = pa.generate([0.98, 0.1])
        parts = list(parts)
        print('Parts created!')
        print('naive1 ({}, {})'.format(parts[0].name, parts[1].name))
        sol.generate_distance_distributions(
            cache=parts[0],
            disk=parts[1],
            save_path='/data/khodadaa/index/data',
            distance_type=['kld', 'avg-kld'])


if __name__ == '__main__':
    config.setup_logger('_recursive_enhance')
    # save_dir = sys.argv[1]
    # cache_distribution_path = sys.argv[2]
    # disk_distribution_path = sys.argv[3]
    # config.setup_logger(file_name='_enhance')
    #
    # cache_ranges = [(0.0, 1.0)]
    # disk_ranges = [(0.0, 0.2), (0.0, 1.0)]
    #
    # for c in cache_ranges:
    #     for d in disk_ranges:
    #         sol.naive1(cache_distribution_path=cache_distribution_path, disk_distribution_path=disk_distribution_path,
    #                    save_log_path=save_dir, use_column_with_index=2, cache_start_range=c[0], cache_end_range=c[1],
    #                    disk_start_range=d[0], disk_end_range=d[1], equal_add_delete=True)
    #         # sol.naive2(cache_distribution_path=cache_distribution_path, disk_distribution_path=disk_distribution_path,
    #         #            save_log_path=save_dir, change_fraction=0.17, cache_start_range=c[0], cache_end_range=c[1],