Code example #1
def test_loop(config,
              sess,
              summary_op,
              summary_writer,
              saver,
              restore_saver,
              directories,
              test_dict,
              exp_label,
              num_params,
              log,
              map_out='test_maps',
              num_batches=None,
              placeholders=False,
              checkpoint=None,
              save_weights=False,
              save_checkpoints=False,
              save_activities=False,
              save_gradients=False):
    """Run the model test loop."""
    if checkpoint is not None:
        restore_saver.restore(sess, checkpoint)
        print 'Restored checkpoint %s' % checkpoint
    if placeholders:
        test_images = placeholders['test']['images']
        test_labels = placeholders['test']['labels']
        test_batches = len(test_images) / config.test_batch_size
        test_batch_idx = np.arange(test_batches).reshape(-1, 1).repeat(
            config.test_batch_size)
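        # The index above tags each retained sample with its batch number;
        # the slicing below drops samples beyond the last full batch.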
        test_images = test_images[:len(test_batch_idx)]
        test_labels = test_labels[:len(test_batch_idx)]
        assert len(test_labels), 'Test labels not found.'
        assert len(test_images), 'Test images not found.'

        # Check that labels are appropriate shape
        tf_label_shape = test_dict['test_labels'].get_shape().as_list()
        np_label_shape = test_labels.shape
        if len(tf_label_shape) == 2 and len(np_label_shape) == 1:
            test_labels = test_labels[..., None]
        elif len(tf_label_shape) == len(np_label_shape):
            pass
        elif len(tf_label_shape) == 4 and np.all(
                np.array(tf_label_shape)[1:3] == np_label_shape[1:3]):
            pass
        else:
            raise RuntimeError('Mismatch label shape np: %s vs. tf: %s' %
                               (np_label_shape, tf_label_shape))

        # Loop through all the images
        if num_batches is not None:
            config.validation_steps = num_batches
        else:
            config.validation_steps = test_batches
        test_score, test_lo, it_test_dict, duration = validation_step(
            sequential=True,
            sess=sess,
            val_dict=test_dict,
            config=config,
            log=log,
            dict_image_key='test_images',
            dict_label_key='test_labels',
            eval_score_key='test_score',
            eval_loss_key='test_loss',
            map_im_key='test_proc_images',
            map_log_key='test_logits',
            map_lab_key='test_proc_labels',
            val_images=test_images,
            val_labels=test_labels,
            val_batch_idx=test_batch_idx,
            val_batches=test_batches)
        if hasattr(config,
                   'get_map') and config.get_map and map_out is not None:
            maps, arands = tf_fun.calculate_map(it_test_dict,
                                                exp_label,
                                                config,
                                                map_dir=map_out)
            return {
                'scores': test_score,
                'losses': test_lo,
                'maps': maps,
                'arands': arands,
                'exp_label': exp_label,
                'test_dict': it_test_dict,
                'duration': duration
            }
        else:
            return {
                'scores': test_score,
                'losses': test_lo,
                'exp_label': exp_label,
                'test_dict': it_test_dict,
                'duration': duration
            }
    else:
        test_score, test_lo, it_test_dict, duration = validation_step(
            sess=sess,
            val_dict=test_dict,
            config=config,
            log=log,
            dict_image_key='test_images',
            dict_label_key='test_labels',
            eval_score_key='test_score',
            eval_loss_key='test_loss',
            map_im_key='test_proc_images',
            map_log_key='test_logits',
            map_lab_key='test_proc_labels')
        return {
            'scores': test_score,
            'losses': test_lo,
            'exp_label': exp_label,
            'test_dict': it_test_dict,
            'duration': duration
        }
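A minimal invocation sketch for test_loop follows, in the same Python 2 / TF1 style. Graph construction is assumed to happen elsewhere; build_test_graph, config, log, the checkpoint path and the experiment label are hypothetical stand-ins, not names from the code above.

# Hypothetical usage sketch for test_loop. build_test_graph and the literal
# paths below are assumptions for illustration only.
import tensorflow as tf

test_dict, test_placeholders, saver, restore_saver = build_test_graph(config)  # assumed helper
with tf.Session() as sess:
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
    results = test_loop(
        config=config,
        sess=sess,
        summary_op=None,
        summary_writer=None,
        saver=saver,
        restore_saver=restore_saver,
        directories={'checkpoints': 'checkpoints'},
        test_dict=test_dict,
        exp_label='example_experiment',
        num_params=0,
        log=log,
        placeholders=test_placeholders,
        checkpoint='checkpoints/model.ckpt-1000')
    print 'Scores: %s, losses: %s' % (results['scores'], results['losses'])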
Code example #2
def training_loop(config,
                  coord,
                  sess,
                  summary_op,
                  summary_writer,
                  saver,
                  restore_saver,
                  threads,
                  directories,
                  train_dict,
                  val_dict,
                  exp_label,
                  num_params,
                  use_db,
                  log,
                  placeholders=False,
                  checkpoint=None,
                  save_weights=False,
                  save_checkpoints=False,
                  save_activities=False,
                  save_gradients=False):
    """Run the model training loop."""
    if checkpoint is not None:
        restore_saver.restore(sess, checkpoint)
        print 'Restored checkpoint %s' % checkpoint
    if not hasattr(config, 'early_stop'):
        config.early_stop = np.inf
    val_perf = np.asarray([np.inf])
    step = 0
    best_val_dict = None
    if save_weights:
        try:
            weight_dict = {v.name: v for v in tf.trainable_variables()}
            val_dict = dict(val_dict, **weight_dict)
        except Exception:
            raise RuntimeError('Failed to find weights to save.')
    else:
        weight_dict = None
    if hasattr(config, 'early_stop'):
        it_early_stop = config.early_stop
    else:
        it_early_stop = np.inf
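    # it_early_stop is the early-stopping patience counter: it is decremented
    # when the new validation loss does not beat any entry in val_perf and
    # reset to config.early_stop when it does.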
    if placeholders:
        train_images = placeholders['train']['images']
        val_images = placeholders['val']['images']
        train_labels = placeholders['train']['labels']
        val_labels = placeholders['val']['labels']
        train_batches = len(train_images) / config.train_batch_size
        train_batch_idx = np.arange(train_batches).reshape(-1, 1).repeat(
            config.train_batch_size)
        train_images = train_images[:len(train_batch_idx)]
        train_labels = train_labels[:len(train_batch_idx)]
        val_batches = len(val_images) / config.val_batch_size
        val_batch_idx = np.arange(val_batches).reshape(-1, 1).repeat(
            config.val_batch_size)
        val_images = val_images[:len(val_batch_idx)]
        val_labels = val_labels[:len(val_batch_idx)]

        # Check that labels are appropriate shape
        tf_label_shape = train_dict['train_labels'].get_shape().as_list()
        np_label_shape = train_labels.shape
        if len(tf_label_shape) == 2 and len(np_label_shape) == 1:
            train_labels = train_labels[..., None]
            val_labels = val_labels[..., None]
        elif len(tf_label_shape) == len(np_label_shape):
            pass
        else:
            raise RuntimeError('Mismatch label shape np: %s vs. tf: %s' %
                               (np_label_shape, tf_label_shape))

        # Start training
        for epoch in tqdm(range(config.epochs),
                          desc='Epoch',
                          total=config.epochs):
            for train_batch in range(train_batches):
                data_idx = train_batch_idx == train_batch
                it_train_images = train_images[data_idx]
                it_train_labels = train_labels[data_idx]
                if isinstance(it_train_images[0], basestring):
                    it_train_images = np.asarray([
                        data_to_tfrecords.load_image(im)
                        for im in it_train_images
                    ])
                feed_dict = {
                    train_dict['train_images']: it_train_images,
                    train_dict['train_labels']: it_train_labels
                }
                (train_score, train_loss, it_train_dict,
                 timer) = training_step(sess=sess,
                                        train_dict=train_dict,
                                        feed_dict=feed_dict)
                if step % config.validation_period == 0:
                    val_score, val_lo, it_val_dict, duration = validation_step(
                        sess=sess,
                        val_dict=val_dict,
                        config=config,
                        log=log,
                        val_images=val_images,
                        val_labels=val_labels,
                        val_batch_idx=val_batch_idx,
                        val_batches=val_batches)

                    # Save progress and important data
                    try:
                        val_check = np.where(val_lo < val_perf)[0]
                        if not len(val_check):
                            it_early_stop -= 1
                            print 'Deducted from early stop count.'
                        else:
                            it_early_stop = config.early_stop
                            best_val_dict = it_val_dict
                            print 'Reset early stop count.'
                        if it_early_stop <= 0:
                            print 'Early stop triggered. Ending early.'
                            print 'Best validation loss: %s' % np.min(val_perf)
                            return
                        save_progress(config=config,
                                      val_check=val_check,
                                      weight_dict=weight_dict,
                                      it_val_dict=it_val_dict,
                                      exp_label=exp_label,
                                      step=step,
                                      directories=directories,
                                      sess=sess,
                                      saver=saver,
                                      val_score=val_score,
                                      val_loss=val_lo,
                                      train_score=train_score,
                                      train_loss=train_loss,
                                      timer=duration,
                                      num_params=num_params,
                                      log=log,
                                      summary_op=summary_op,
                                      summary_writer=summary_writer,
                                      save_activities=save_activities,
                                      save_gradients=save_gradients,
                                      save_checkpoints=save_checkpoints)
                    except Exception as e:
                        log.info('Failed to save checkpoint: %s' % e)

                    # Training status and validation accuracy
                    val_status(log=log,
                               dt=datetime.now(),
                               step=step,
                               train_loss=train_loss,
                               rate=config.val_batch_size / duration,
                               timer=float(duration),
                               score_function=config.score_function,
                               train_score=train_score,
                               val_score=val_score,
                               val_loss=val_lo,
                               best_val_loss=np.min(val_perf),
                               summary_dir=directories['summaries'])
                else:
                    # Training status (report the training-step timer here;
                    # duration is only updated on validation steps)
                    train_status(log=log,
                                 dt=datetime.now(),
                                 step=step,
                                 train_loss=train_loss,
                                 rate=config.val_batch_size / timer,
                                 timer=float(timer),
                                 lr=it_train_dict['lr'],
                                 score_function=config.score_function,
                                 train_score=train_score)

                # End iteration
                val_perf = np.concatenate([val_perf, [val_lo]])
                step += 1

    else:
        try:
            while not coord.should_stop():
                (train_score, train_loss, it_train_dict,
                 duration) = training_step(sess=sess,
                                           config=config,
                                           train_dict=train_dict)
                io_start_time = time.time()
                if step % config.validation_period == 0:
                    val_score, val_lo, it_val_dict, duration = validation_step(
                        sess=sess, val_dict=val_dict, config=config, log=log)

                    # Save progress and important data
                    try:
                        val_check = np.where(val_lo < val_perf)[0]
                        if not len(val_check):
                            it_early_stop -= 1
                            print('Deducted from early stop count (%s/ %s).' %
                                  (it_early_stop, config.early_stop))
                        else:
                            it_early_stop = config.early_stop
                            best_val_dict = it_val_dict
                            print('Reset early stop count (%s/ %s).' %
                                  (it_early_stop, config.early_stop))
                        if it_early_stop <= 0:
                            print 'Early stop triggered. Ending early.'
                            print 'Best validation loss: %s' % np.min(val_perf)
                            break
                        val_perf = save_progress(
                            config=config,
                            val_check=val_check,
                            weight_dict=weight_dict,
                            it_val_dict=it_val_dict,
                            exp_label=exp_label,
                            step=step,
                            directories=directories,
                            sess=sess,
                            saver=saver,
                            val_score=val_score,
                            val_loss=val_lo,
                            val_perf=val_perf,
                            train_score=train_score,
                            train_loss=train_loss,
                            timer=duration,
                            num_params=num_params,
                            log=log,
                            use_db=use_db,
                            summary_op=summary_op,
                            summary_writer=summary_writer,
                            save_activities=save_activities,
                            save_gradients=save_gradients,
                            save_checkpoints=save_checkpoints)
                    except Exception as e:
                        log.info('Failed to save checkpoint: %s' % e)

                    # Training status and validation accuracy
                    val_status(log=log,
                               dt=datetime.now(),
                               step=step,
                               train_loss=train_loss,
                               rate=config.val_batch_size / duration,
                               timer=float(duration),
                               score_function=config.score_function,
                               train_score=train_score,
                               val_score=val_score,
                               val_loss=val_lo,
                               best_val_loss=np.min(val_perf),
                               summary_dir=directories['summaries'])
                else:
                    # Training status
                    io_duration = time.time() - io_start_time
                    train_status(log=log,
                                 dt=datetime.now(),
                                 step=step,
                                 train_loss=train_loss,
                                 rate=config.val_batch_size / duration,
                                 timer=float(duration),
                                 io_timer=float(io_duration),
                                 lr=it_train_dict['lr'],
                                 score_function=config.score_function,
                                 train_score=train_score)

                # End iteration
                step += 1
        except tf.errors.OutOfRangeError:
            log.info('Done training for %d epochs, %d steps.' %
                     (config.epochs, step))
            log.info('Saved to: %s' % directories['checkpoints'])
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()
    try:
        print 'Best %s loss: %s' % (config.val_loss_function, val_perf[0])
        if hasattr(config, 'get_map') and config.get_map:
            tf_fun.calculate_map(best_val_dict, exp_label, config)
    except Exception:
        print 'Best loss: %s' % val_perf[0]
    return
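All three loops reconcile the numpy label shape with the shape the TensorFlow placeholder expects before feeding, adding a trailing singleton dimension when the graph wants 2-D labels but the arrays are 1-D. Below is a self-contained sketch of that check; reconcile_labels and the example shapes are illustrative assumptions, not part of the code above.

# Minimal sketch of the label-shape check used by the loops above.
# reconcile_labels and the example shapes are illustrative assumptions.
import numpy as np

def reconcile_labels(np_labels, tf_label_shape):
    """Match numpy label rank to the TF placeholder rank."""
    np_label_shape = np_labels.shape
    if len(tf_label_shape) == 2 and len(np_label_shape) == 1:
        return np_labels[..., None]  # add a trailing singleton dimension
    elif len(tf_label_shape) == len(np_label_shape):
        return np_labels  # ranks already agree
    else:
        raise RuntimeError('Mismatch label shape np: %s vs. tf: %s' %
                           (np_label_shape, tf_label_shape))

print reconcile_labels(np.zeros(8), [None, 1]).shape  # -> (8, 1)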
Code example #3
def training_loop(
        config,
        coord,
        sess,
        summary_op,
        summary_writer,
        saver,
        restore_saver,
        threads,
        directories,
        train_dict,
        val_dict,
        exp_label,
        num_params,
        use_db,
        log,
        placeholders=False,
        checkpoint=None,
        save_weights=False,
        save_checkpoints=False,
        save_activities=False,
        save_gradients=False):
    """Run the model training loop."""
    if checkpoint is not None:
        restore_saver.restore(sess, checkpoint)
        print 'Restored checkpoint %s' % checkpoint
    if not hasattr(config, 'early_stop'):
        config.early_stop = np.inf
    val_perf = np.asarray([np.inf])
    step = 0
    best_val_dict = None
    if save_weights:
        try:
            weight_dict = {v.name: v for v in tf.trainable_variables()}
            val_dict = dict(
                val_dict,
                **weight_dict)
        except Exception:
            raise RuntimeError('Failed to find weights to save.')
    else:
        weight_dict = None
    if hasattr(config, 'early_stop'):
        it_early_stop = config.early_stop
    else:
        it_early_stop = np.inf

    if hasattr(config, "adaptive_train"):
        adaptive_train = config.adaptive_train
    else:
        adaptive_train = False
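    # adaptive_train doubles as a flag and a loss threshold: when set, the
    # epoch loop below exits once train_loss drops to or below this value.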
    if placeholders:
        train_images = placeholders['train']['images']
        val_images = placeholders['val']['images']
        train_labels = placeholders['train']['labels']
        val_labels = placeholders['val']['labels']
        train_batches = len(train_images) / config.train_batch_size
        train_batch_idx = np.arange(
            train_batches).reshape(-1, 1).repeat(
                config.train_batch_size)
        train_images = train_images[:len(train_batch_idx)]
        train_labels = train_labels[:len(train_batch_idx)]
        val_batches = len(val_images) / config.val_batch_size
        val_batch_idx = np.arange(
            val_batches).reshape(-1, 1).repeat(
                config.val_batch_size)
        val_images = val_images[:len(val_batch_idx)]
        val_labels = val_labels[:len(val_batch_idx)]

        # Check that labels are appropriate shape
        tf_label_shape = train_dict['train_labels'].get_shape().as_list()
        np_label_shape = train_labels.shape
        if len(tf_label_shape) == 2 and len(np_label_shape) == 1:
            train_labels = train_labels[..., None]
            val_labels = val_labels[..., None]
        elif len(tf_label_shape) == len(np_label_shape):
            pass
        else:
            raise RuntimeError(
                'Mismatch label shape np: %s vs. tf: %s' % (
                    np_label_shape,
                    tf_label_shape))

        # Start training
        train_losses = []
        train_logits = []
        for epoch in tqdm(
                range(config.epochs),
                desc='Epoch',
                total=config.epochs):
            for train_batch in range(train_batches):
                io_start_time = time.time()
                data_idx = train_batch_idx == train_batch
                it_train_images = train_images[data_idx]
                it_train_labels = train_labels[data_idx]
                if isinstance(it_train_images[0], basestring):
                    it_train_images = np.asarray(
                        [
                            data_to_tfrecords.load_image(im)
                            for im in it_train_images])
                feed_dict = {
                    train_dict['train_images']: it_train_images,
                    train_dict['train_labels']: it_train_labels
                }
                (
                    train_score,
                    train_loss,
                    it_train_dict,
                    timer) = training_step(
                    sess=sess,
                    train_dict=train_dict,
                    config=config,
                    feed_dict=feed_dict)
                train_losses.append(train_loss)
                if step % config.validation_period == 0:
                    val_score, val_lo, it_val_dict, duration = validation_step(
                        sess=sess,
                        val_dict=val_dict,
                        config=config,
                        log=log,
                        val_images=val_images,
                        val_labels=val_labels,
                        val_batch_idx=val_batch_idx,
                        val_batches=val_batches)

                    # Save progress and important data
                    try:
                        val_check = np.where(val_lo < val_perf)[0]
                        if not len(val_check):
                            it_early_stop -= 1
                            print 'Deducted from early stop count.'
                        else:
                            it_early_stop = config.early_stop
                            best_val_dict = it_val_dict
                            print 'Reset early stop count.'
                        if it_early_stop <= 0:
                            print 'Early stop triggered. Ending early.'
                            print 'Best validation loss: %s' % np.min(val_perf)
                            return
                        save_progress(
                            config=config,
                            val_check=val_check,
                            weight_dict=weight_dict,
                            it_val_dict=it_val_dict,
                            exp_label=exp_label,
                            step=step,
                            directories=directories,
                            sess=sess,
                            saver=saver,
                            val_score=val_score,
                            val_loss=val_lo,
                            train_score=train_score,
                            train_loss=train_loss,
                            timer=duration,
                            num_params=num_params,
                            log=log,
                            summary_op=summary_op,
                            summary_writer=summary_writer,
                            save_activities=save_activities,
                            save_gradients=save_gradients,
                            save_checkpoints=save_checkpoints)
                    except Exception as e:
                        log.info('Failed to save checkpoint: %s' % e)

                    # Hack to get the visualizations... clean this up later
                    if "BSDS500_test_orientation_viz" in config.experiment:
                        # from matplotlib import pyplot as plt;plt.plot(it_train_dict['train_logits'].squeeze(), "r", label="Perturb");plt.plot(it_train_dict['train_labels'].squeeze()[-6:], 'b', label="GT");plt.legend();plt.show()
                        # from matplotlib import pyplot as plt;plt.imshow((it_train_dict['impatch'].squeeze() + np.asarray([123.68, 116.78, 103.94])[None, None]).astype(np.uint8));plt.show()
                        # from matplotlib import pyplot as plt;dd = it_train_dict["grad0"];plt.imshow(np.abs(dd.squeeze()).mean(-1) / (np.abs(dd.squeeze()).std(-1) + 1e-4));plt.show()
                        # from matplotlib import pyplot as plt;dd = it_train_dict['mask'];plt.imshow(dd.squeeze().mean(-1));plt.show()
                        train_logits.append([it_train_dict["train_logits"].ravel()])
                        out_dir = "circuits_{}".format(config.out_dir)
                        py_utils.make_dir(out_dir)
                        out_target = os.path.join(out_dir, "{}_{}".format(config.model, config.train_dataset))
                        np.save("{}_optim".format(out_target), [sess.run(tf.trainable_variables())])  # , it_train_dict["conv"]])
                        np.save("{}_perf".format(out_target), train_losses)
                        np.save("{}_curves".format(out_target), train_logits)
                        np.save("{}_label".format(out_target), it_train_dict["train_labels"])
                    """
                    if config.model == "BSDS_inh_perturb":
                        np.save("inh_perturbs/optim", sess.run(tf.trainable_variables()[0]))
                        np.save("inh_perturbs/perf", train_losses)
                        np.save("inh_perturbs/curves", train_logits)
                        np.save("inh_perturbs/label", it_train_dict["train_labels"])

                    if config.model == "BSDS_exc_perturb":
                        np.save("exc_perturbs/optim", sess.run(tf.trainable_variables()[0]))
                        np.save("exc_perturbs/perf", train_losses)
                        np.save("exc_perturbs/curves", train_logits)
                        np.save("exc_perturbs/label", it_train_dict["train_labels"])
                    """

                    # Training status and validation accuracy
                    val_status(
                        log=log,
                        dt=datetime.now(),
                        step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration),
                        score_function=config.score_function,
                        train_score=train_score,
                        val_score=val_score,
                        val_loss=val_lo,
                        best_val_loss=np.min(val_perf),
                        summary_dir=directories['summaries'])
                else:
                    # Training status (report the training-step timer here;
                    # duration is only updated on validation steps)
                    io_duration = time.time() - io_start_time
                    train_status(
                        log=log,
                        dt=datetime.now(),
                        step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / timer,
                        timer=float(timer),
                        io_timer=float(io_duration),
                        lr=it_train_dict['lr'],
                        score_function=config.score_function,
                        train_score=train_score)

                # End iteration
                val_perf = np.concatenate([val_perf, [val_lo]])
                step += 1
                
                # Adaptive ending
                if adaptive_train and train_loss <= adaptive_train:
                    break
            if adaptive_train and train_loss <= adaptive_train:
                break


    else:
        try:
            while not coord.should_stop():
                (
                    train_score,
                    train_loss,
                    it_train_dict,
                    duration) = training_step(
                    sess=sess,
                    config=config,
                    train_dict=train_dict)
                io_start_time = time.time()
                if step % config.validation_period == 0:
                    val_score, val_lo, it_val_dict, duration = validation_step(
                        sess=sess,
                        val_dict=val_dict,
                        config=config,
                        log=log)

                    # Save progress and important data
                    try:
                        val_check = np.where(val_lo < val_perf)[0]
                        if not len(val_check):
                            it_early_stop -= 1
                            print 'Deducted from early stop count.'
                        else:
                            it_early_stop = config.early_stop
                            best_val_dict = it_val_dict
                            print 'Reset early stop count.'
                        if it_early_stop <= 0:
                            print 'Early stop triggered. Ending early.'
                            print 'Best validation loss: %s' % np.min(val_perf)
                            break
                        val_perf = save_progress(
                            config=config,
                            val_check=val_check,
                            weight_dict=weight_dict,
                            it_val_dict=it_val_dict,
                            exp_label=exp_label,
                            step=step,
                            directories=directories,
                            sess=sess,
                            saver=saver,
                            val_score=val_score,
                            val_loss=val_lo,
                            val_perf=val_perf,
                            train_score=train_score,
                            train_loss=train_loss,
                            timer=duration,
                            num_params=num_params,
                            log=log,
                            use_db=use_db,
                            summary_op=summary_op,
                            summary_writer=summary_writer,
                            save_activities=save_activities,
                            save_gradients=save_gradients,
                            save_checkpoints=save_checkpoints)
                    except Exception as e:
                        log.info('Failed to save checkpoint: %s' % e)

                    # Training status and validation accuracy
                    val_status(
                        log=log,
                        dt=datetime.now(),
                        step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration),
                        score_function=config.score_function,
                        train_score=train_score,
                        val_score=val_score,
                        val_loss=val_lo,
                        best_val_loss=np.min(val_perf),
                        summary_dir=directories['summaries'])
                else:
                    # Training status
                    io_duration = time.time() - io_start_time
                    train_status(
                        log=log,
                        dt=datetime.now(),
                        step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration),
                        io_timer=float(io_duration),
                        lr=it_train_dict['lr'],
                        score_function=config.score_function,
                        train_score=train_score)

                # End iteration
                step += 1
        except tf.errors.OutOfRangeError:
            log.info(
                'Done training for %d epochs, %d steps.' % (
                    config.epochs, step))
            log.info('Saved to: %s' % directories['checkpoints'])
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()
    print 'Best %s loss: %s' % (config.val_loss_function, val_perf[0])
    if hasattr(config, 'get_map') and config.get_map:
        tf_fun.calculate_map(best_val_dict, exp_label, config)
    return
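The placeholder code paths in all three loops share the same batch bookkeeping: an index array tags each sample with a batch number, samples past the last full batch are dropped, and a boolean mask pulls out one batch per iteration. A numpy-only sketch of that pattern follows; the array sizes are made-up example values (Python 2 integer division, as in the code above).

# Numpy-only sketch of the batch indexing used by the placeholder paths above.
# Sizes are illustrative; nothing here is read from a real config.
import numpy as np

images = np.arange(10)  # stand-in for 10 samples
batch_size = 3
num_batches = len(images) / batch_size  # 3 under Python 2 integer division
batch_idx = np.arange(num_batches).reshape(-1, 1).repeat(batch_size)
# batch_idx -> [0 0 0 1 1 1 2 2 2]; the tenth sample is dropped below
images = images[:len(batch_idx)]

for batch in range(num_batches):
    it_images = images[batch_idx == batch]  # boolean mask selects one batch
    print 'batch %d: %s' % (batch, it_images)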