Example #1
from logging import LogRecord

from absl import logging
from absl.logging import PythonFormatter


class CustomPythonFormatter(PythonFormatter):
    """Logging formatter that adds colors and a custom format string.

    Adapted from https://stackoverflow.com/q/64263526 (author: MaKaNu).
    """

    blue = "\x1b[34;21m"
    green = "\x1b[32;21m"
    yellow = "\x1b[33;21m"
    red = "\x1b[31;21m"
    bold_red = "\x1b[31;1m"
    grey = "\x1b[38;5;246m"
    reset_color = "\x1b[0m"
    format_str = "%(levelname)s" + reset_color + " - %(funcName)s()\n »»» " + grey + \
                 " %(message)s \n" + reset_color
    date_format = "%Y-%m-%d %H:%M:%S"

    FORMATS = {
        "DEBUG":
        logging.PythonFormatter(fmt=blue + format_str, datefmt=date_format),
        "INFO":
        logging.PythonFormatter(fmt=green + format_str, datefmt=date_format),
        "WARNING":
        logging.PythonFormatter(fmt=yellow + format_str, datefmt=date_format),
        "ERROR":
        logging.PythonFormatter(fmt=red + format_str, datefmt=date_format),
        "FATAL":
        logging.PythonFormatter(fmt=bold_red + format_str, datefmt=date_format)
    }

    def format(self, record: LogRecord):
        formatter = self.FORMATS.get(record.levelname, self.FORMATS["INFO"])
        return formatter.format(record)
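A minimal usage sketch (hypothetical, not part of the original example): install the colored formatter on absl's handler, mirroring the setFormatter pattern used in Example #4 below.

from absl import app, logging


def main(argv):
    del argv  # unused
    # Replace absl's default formatter with the colored one defined above.
    logging.get_absl_handler().setFormatter(CustomPythonFormatter())
    logging.info('shown in green')
    logging.warning('shown in yellow')


if __name__ == '__main__':
    app.run(main)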
Example #2
def setUp(self):
  super().setUp()
  self.now_tuple = time.localtime(time.mktime(
      (1979, 10, 21, 18, 17, 16, 3, 15, 1)))
  self.new_prefix = lambda level: '(blah_prefix)'
  # Patch the time functions so tests can pin the log-prefix timestamp.
  mock.patch.object(time, 'time').start()
  mock.patch.object(time, 'localtime').start()
  self.addCleanup(mock.patch.stopall)  # stop the patches after each test
  self.record = std_logging.LogRecord(
      'name', std_logging.INFO, 'path', 12, 'A Message', [], False)
  self.formatter = logging.PythonFormatter()
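A hypothetical test method built on this fixture (it assumes the suite runs under absltest, so absl's flags are parsed before format() consults them):

def test_format_includes_message(self):
  # Pin the mocked clock to the fixed timestamp from setUp.
  time.time.return_value = time.mktime(self.now_tuple)
  time.localtime.return_value = self.now_tuple
  formatted = self.formatter.format(self.record)
  self.assertIn('A Message', formatted)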
Example #3
import logging

from absl import logging as absl_logging


def add_log_file(logfile):
  """Replicates logs to an additional logfile.

  The caller is responsible for closing the logfile.

  Args:
    logfile: Open file object to write logs to.
  """
  handler = logging.StreamHandler(logfile)
  handler.setFormatter(absl_logging.PythonFormatter())

  absl_logger = logging.getLogger('absl')
  absl_logger.addHandler(handler)
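A minimal usage sketch for the helper above (hypothetical path):

logfile = open('/tmp/replicated.log', 'w')
add_log_file(logfile)
absl_logging.info('This message also lands in /tmp/replicated.log.')
logfile.close()  # the caller owns the file handle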
Example #4
def main(argv):
    fmt = '[%(filename)s:%(lineno)s] %(message)s'
    formatter = logging.PythonFormatter(fmt)
    logging.get_absl_handler().setFormatter(formatter)
    del argv  # unused arg

    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    data_dir = None
    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        data_dir = FLAGS.data_dir
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    train_dataset_size = (ds_info.splits['train'].num_examples *
                          FLAGS.train_proportion)
    steps_per_epoch = int(train_dataset_size / batch_size)
    logging.info('Steps per epoch %s', steps_per_epoch)
    logging.info('Size of the dataset %s',
                 ds_info.splits['train'].num_examples)
    logging.info('Train proportion %s', FLAGS.train_proportion)
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    aug_params = {
        'augmix': FLAGS.augmix,
        'aug_count': FLAGS.aug_count,
        'augmix_depth': FLAGS.augmix_depth,
        'augmix_prob_coeff': FLAGS.augmix_prob_coeff,
        'augmix_width': FLAGS.augmix_width,
    }

    # Note that stateless_{fold_in,split} may incur a performance cost, but a
    # quick side-by-side test seemed to imply this was minimal.
    seeds = tf.random.experimental.stateless_split(
        [FLAGS.seed, FLAGS.seed + 1], 2)[:, 0]
    train_builder = ub.datasets.get(FLAGS.dataset,
                                    download_data=FLAGS.download_data,
                                    split=tfds.Split.TRAIN,
                                    seed=seeds[0],
                                    aug_params=aug_params,
                                    validation_percent=1. -
                                    FLAGS.train_proportion,
                                    data_dir=data_dir)
    train_dataset = train_builder.load(batch_size=batch_size)
    validation_dataset = None
    steps_per_validation = 0
    if FLAGS.train_proportion < 1.0:
        validation_builder = ub.datasets.get(FLAGS.dataset,
                                             split=tfds.Split.VALIDATION,
                                             validation_percent=1. -
                                             FLAGS.train_proportion,
                                             data_dir=data_dir)
        validation_dataset = validation_builder.load(batch_size=batch_size)
        validation_dataset = strategy.experimental_distribute_dataset(
            validation_dataset)
        steps_per_validation = validation_builder.num_examples // batch_size
    clean_test_builder = ub.datasets.get(FLAGS.dataset,
                                         split=tfds.Split.TEST,
                                         data_dir=data_dir)
    clean_test_dataset = clean_test_builder.load(batch_size=batch_size)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_datasets = {
        'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
    }
    steps_per_epoch = train_builder.num_examples // batch_size
    steps_per_eval = clean_test_builder.num_examples // batch_size
    num_classes = 100 if FLAGS.dataset == 'cifar100' else 10
    if FLAGS.corruptions_interval > 0:
        if FLAGS.dataset == 'cifar100':
            data_dir = FLAGS.cifar100_c_path
        corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
        for corruption_type in corruption_types:
            for severity in range(1, 6):
                dataset = ub.datasets.get(
                    f'{FLAGS.dataset}_corrupted',
                    corruption_type=corruption_type,
                    severity=severity,
                    split=tfds.Split.TEST,
                    data_dir=data_dir).load(batch_size=batch_size)
                test_datasets[f'{corruption_type}_{severity}'] = (
                    strategy.experimental_distribute_dataset(dataset))

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building ResNet model')
        model = ub.models.wide_resnet(input_shape=(32, 32, 3),
                                      depth=28,
                                      width_multiplier=10,
                                      num_classes=num_classes,
                                      l2=FLAGS.l2,
                                      hps=_extract_hyperparameter_dictionary(),
                                      seed=seeds[1],
                                      version=2)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale the learning rate and the decay epochs relative to
        # the vanilla settings (batch size 128, 200 training epochs).
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                           for start_epoch_str in FLAGS.lr_decay_epochs]
        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=1.0 -
                                            FLAGS.one_minus_momentum,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        }
        if validation_dataset:
            metrics.update({
                'validation/negative_log_likelihood':
                tf.keras.metrics.Mean(),
                'validation/accuracy':
                tf.keras.metrics.SparseCategoricalAccuracy(),
                'validation/ece':
                rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            })
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, 6):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        rm.metrics.ExpectedCalibrationError(
                            num_bins=FLAGS.num_bins))

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']

            if FLAGS.augmix and FLAGS.aug_count >= 1:
                # Index 0 of the augmix output is the unperturbed image; take
                # just one of the augmented copies that follow it.
                images = images[:, 1, ...]
            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.label_smoothing == 0.:
                    negative_log_likelihood = tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits, from_logits=True))
                else:
                    one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32),
                                                num_classes)
                    negative_log_likelihood = tf.reduce_mean(
                        tf.keras.losses.categorical_crossentropy(
                            one_hot_labels,
                            logits,
                            from_logits=True,
                            label_smoothing=FLAGS.label_smoothing))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss
                # Scale the loss, since the distribution strategy sum-reduces
                # gradients across replicas.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.softmax(logits)
            metrics['train/ece'].add_batch(probs, label=labels)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, logits)

        for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_split, dataset_name, num_steps):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            logits = model(images, training=False)
            probs = tf.nn.softmax(logits)
            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

            if dataset_name == 'clean':
                metrics[
                    f'{dataset_split}/negative_log_likelihood'].update_state(
                        negative_log_likelihood)
                metrics[f'{dataset_split}/accuracy'].update_state(
                    labels, probs)
                metrics[f'{dataset_split}/ece'].add_batch(probs, label=labels)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
                    probs, label=labels)

        for _ in tf.range(tf.cast(num_steps, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
    metrics.update({'train/ms_per_example': tf.keras.metrics.Mean()})

    train_iterator = iter(train_dataset)
    start_time = time.time()
    tb_callback = None
    if FLAGS.collect_profile:
        tb_callback = tf.keras.callbacks.TensorBoard(profile_batch=(100, 102),
                                                     log_dir=os.path.join(
                                                         FLAGS.output_dir,
                                                         'logs'))
        tb_callback.set_model(model)
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        if tb_callback:
            tb_callback.on_epoch_begin(epoch)
        train_start_time = time.time()
        train_step(train_iterator)
        ms_per_example = (time.time() - train_start_time) * 1e6 / batch_size
        metrics['train/ms_per_example'].update_state(ms_per_example)

        current_step = (epoch + 1) * steps_per_epoch
        max_steps = steps_per_epoch * FLAGS.train_epochs
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                       steps_per_sec, eta_seconds / 60, time_elapsed / 60))
        logging.info(message)
        if tb_callback:
            tb_callback.on_epoch_end(epoch)

        if validation_dataset:
            validation_iterator = iter(validation_dataset)
            test_step(validation_iterator, 'validation', 'clean',
                      steps_per_validation)
        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            logging.info('Starting to run eval at epoch: %s', epoch)
            test_start_time = time.time()
            test_step(test_iterator, 'test', dataset_name, steps_per_eval)
            ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
            metrics['test/ms_per_example'].update_state(ms_per_example)

            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'l2': FLAGS.l2,
        })
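A hypothetical entry point for the training script above, following the usual absl pattern (assumes "from absl import app" and the FLAGS definitions elsewhere in the module):

if __name__ == '__main__':
    app.run(main)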
Example #5
def setUp(self):
  super().setUp()
  formatter = logging.PythonFormatter()
  self.absl_handler = logging.ABSLHandler(formatter)
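A hypothetical assertion building on this fixture; it assumes ABSLHandler exposes the wrapped stream handler through its python_handler property, as absl-py does:

def test_formatter_is_attached(self):
  self.assertIsInstance(self.absl_handler.python_handler.formatter,
                        logging.PythonFormatter)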
Example #6
def add_gfile_logger(workdir, *, basename='train', level=python_logging.INFO):
    """Adds GFile file logger to Python logging handlers."""
    fh = GFileHandler(f'{workdir}/{basename}.log', 'a')
    fh.setLevel(level)
    fh.setFormatter(logging.PythonFormatter())
    python_logging.getLogger('').addHandler(fh)
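A minimal usage sketch (hypothetical workdir; GFileHandler is assumed to be defined elsewhere in the same module as a stream handler backed by tf.io.gfile):

add_gfile_logger('/tmp/workdir', basename='train')
logging.info('This message is mirrored to /tmp/workdir/train.log.')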