def main(_=None):
    FLAGS = flags.FLAGS  # pylint: disable=invalid-name,redefined-outer-name
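    # Alias the parsed flags object as "config" (and expose it as FLAGS.config);
    # the two names are used interchangeably below.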
    config = FLAGS
    FLAGS.__dict__['config'] = config

    # Build the graph
    with tf.Graph().as_default():

        model_dict = model_config.get(FLAGS)
        data_dict = data_config.get(FLAGS)

        model = model_dict.model
        trainset = data_dict.trainset
        validset = data_dict.validset

        # Optimisation target
        validset = tools.maybe_convert_dataset(validset)
        trainset = tools.maybe_convert_dataset(trainset)

        t1 = model(trainset)
        t2 = model(validset)

        sess = tf.Session()
        saver = tf.train.Saver()
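        # Restore trained weights from the checkpoint given by the snapshot flag.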
        saver.restore(sess, FLAGS.snapshot)

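    # Collect results over each full split; the integer constants are the split
    # sizes (MNIST: 60,000 train / 10,000 test; SVHN: 73,257 train / 26,032 test;
    # CIFAR-10: 50,000 train / 10,000 test).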
    if config.dataset == 'mnist':
        _collect_results(sess, _capsule.save_op, validset,
                         10000 // FLAGS.batch_size)
        _collect_results(sess, _capsule.save_op, trainset,
                         60000 // FLAGS.batch_size)
    elif config.dataset == 'svhn':
        _collect_results(sess, _capsule.save_op, validset,
                         26032 // FLAGS.batch_size)

        _collect_results(sess, _capsule.save_op, trainset,
                         73257 // FLAGS.batch_size)
    elif config.dataset == 'cifar10':
        _collect_results(sess, _capsule.save_op, validset,
                         10000 // FLAGS.batch_size)

        _collect_results(sess, _capsule.save_op, trainset,
                         50000 // FLAGS.batch_size)
Example #2
def main(_=None):
  FLAGS = flags.FLAGS  # pylint: disable=invalid-name,redefined-outer-name
  config = FLAGS
  FLAGS.__dict__['config'] = config

  FLAGS.logdir = FLAGS.logdir.format(name=FLAGS.name)

  logdir = FLAGS.logdir
  logging.info('logdir: %s', logdir)

  if os.path.exists(logdir) and FLAGS.overwrite:
    logging.info('"overwrite" is set to True. Deleting logdir at "%s".', logdir)
    shutil.rmtree(logdir)

  # Build the graph
  with tf.Graph().as_default():

    model_dict = model_config.get(FLAGS)
    data_dict = data_config.get(FLAGS)

    lr = model_dict.lr
    opt = model_dict.opt
    model = model_dict.model
    trainset = data_dict.trainset
    validset = data_dict.validset

    lr = tf.convert_to_tensor(lr)
    tf.summary.scalar('learning_rate', lr)

    # Training setup
    global_step = tf.train.get_or_create_global_step()

    # Optimisation target
    validset = tools.maybe_convert_dataset(validset)
    trainset = tools.maybe_convert_dataset(trainset)
    target, gvs = model.make_target(trainset, opt)

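    # make_target may return precomputed (gradient, variable) pairs; otherwise
    # compute them from the loss.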
    if gvs is None:
      gvs = opt.compute_gradients(target)

    suppress_inf_and_nans = (config.grad_value_clip > 0
                             or config.grad_norm_clip > 0)
    report = tools.gradient_summaries(gvs, suppress_inf_and_nans)
    report['target'] = target
    valid_report = dict()

    gvs = tools.clip_gradients(gvs, value_clip=config.grad_value_clip,
                               norm_clip=config.grad_norm_clip)

    try:
      report.update(model.make_report(trainset))
      valid_report.update(model.make_report(validset))
    except AttributeError:
      logging.warning('Model %s has no "make_report" method.', str(model))
      raise

    plot_dict, plot_params = None, None
    if config.plot:
      try:
        plot_dict, plot_params = model.make_plot(trainset, 'train')
        valid_plot, valid_params = model.make_plot(validset, 'valid')

        plot_dict.update(valid_plot)
        if plot_params is not None:
          plot_params.update(valid_params)

      except AttributeError:
        logging.warning('Model %s has no "make_plot" method.', str(model))

    report = tools.scalar_logs(report, config.ema, 'train',
                               global_update=config.global_ema_update)
    report['lr'] = lr
    valid_report = tools.scalar_logs(
        valid_report, config.ema, 'valid',
        global_update=config.global_ema_update)

    reports_keys = sorted(report.keys())

    def _format(k):
      if k in ('lr', 'learning_rate'):
        return '.2E'
      return '.3f'

    report_template = ', '.join(['{}: {}{}:{}{}'.format(
        k, '{', k, _format(k), '}') for k in reports_keys])
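    # Produces one "<key>: <value>" entry per sorted report key,
    # e.g. "lr: 1.00E-03, target: 0.123".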

    logging.info('Trainable variables:')
    tools.log_variables_by_scope()

    # inspect gradients
    for g, v in gvs:
      if g is None:
        logging.warning('No gradient for variable: %s.', v.name)

    tools.log_num_params()

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if FLAGS.check_numerics:
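      # add_check_numerics_ops() makes the training step fail fast if any float
      # tensor in the graph contains NaN or Inf.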
      update_ops += [tf.add_check_numerics_ops()]

    with tf.control_dependencies(update_ops):
      train_step = opt.apply_gradients(gvs, global_step=global_step)

    sess_config = tf.ConfigProto()
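    # Allocate GPU memory on demand rather than reserving it all at start-up.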
    sess_config.gpu_options.allow_growth = True

    with tf.train.SingularMonitoredSession(
        hooks=create_hooks(FLAGS, plot_dict, plot_params),
        checkpoint_dir=logdir, config=sess_config) as sess:

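      # Read the current global step (non-zero when resuming from a checkpoint)
      # and run the update ops once.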
      train_itr, _ = sess.run([global_step, update_ops])
      train_tensors = [global_step, train_step]
      report_tensors = [report, valid_report]
      all_tensors = report_tensors + train_tensors

      while train_itr < config.max_train_steps:

        if train_itr % config.report_loss_steps == 0:
          report_vals, valid_report_vals, train_itr, _ = sess.run(all_tensors)

          logging.info('')
          logging.info('train:')
          logging.info('#%s: %s', train_itr,
                       report_template.format(**report_vals))

          logging.info('valid:')
          valid_logs = dict(report_vals)
          valid_logs.update(valid_report_vals)
          logging.info('#%s: %s', train_itr,
                       report_template.format(**valid_logs))

          vals_to_check = list(report_vals.values())
          if (np.isnan(vals_to_check).any()
              or np.isinf(vals_to_check).any()):
            logging.fatal('NaN in reports: %s; breaking...',
                          report_template.format(**report_vals))

        else:
          train_itr, _ = sess.run(train_tensors)
Example #3
def main(_=None):
  FLAGS = flags.FLAGS  # pylint: disable=invalid-name,redefined-outer-name
  config = FLAGS
  FLAGS.__dict__['config'] = config

  # Build the graph
  with tf.Graph().as_default():

    model_dict = model_config.get(FLAGS)
    data_dict = data_config.get(FLAGS)

    model = model_dict.model
    trainset = data_dict.trainset
    validset = data_dict.validset

    # Optimisation target
    validset = tools.maybe_convert_dataset(validset)
    trainset = tools.maybe_convert_dataset(trainset)

    train_tensors = model(trainset)
    valid_tensors = model(validset)

    sess = tf.Session()
    saver = tf.train.Saver()
    saver.restore(sess, FLAGS.snapshot)

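  # 10,000 and 60,000 are the MNIST test and training split sizes.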
  valid_results = _collect_results(sess, valid_tensors, validset,
                                   10000 // FLAGS.batch_size)

  train_results = _collect_results(sess, train_tensors, trainset,
                                   60000 // FLAGS.batch_size)

  results = AttrDict(train=train_results, valid=valid_results)

  # Linear classification
  print('Linear classification accuracy:')
  for k, v in results.items():
    print('\t{}: prior={:.04f}, posterior={:.04f}'.format(
        k, v.prior_acc.mean(), v.posterior_acc.mean()))

  # Unsupervised classification via clustering
  print('Bipartite matching classification accuracy:')
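  # A minimal, hypothetical sketch of what a bipartite-matching classifier such
  # as cluster_classify might compute (not necessarily this repo's code): map
  # each k-means cluster to a class via maximum-weight matching, then score.
  #
  #   import scipy.optimize
  #
  #   def matching_accuracy(features, labels, n_classes, kmeans):
  #     cluster_ids = kmeans.predict(features)
  #     counts = np.zeros((n_classes, n_classes), dtype=np.int64)
  #     for c, y in zip(cluster_ids, labels):
  #       counts[c, y] += 1
  #     rows, cols = scipy.optimize.linear_sum_assignment(-counts)  # Hungarian
  #     return counts[rows, cols].sum() / float(len(labels))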
  for field in 'posterior_pres prior_pres'.split():
    kmeans = sklearn.cluster.KMeans(
        n_clusters=10,
        precompute_distances=True,
        n_jobs=-1,
        max_iter=1000,
    ).fit(results.train[field])

    train_acc = cluster_classify(results.train[field], results.train.label, 10,
                                 kmeans)
    valid_acc = cluster_classify(results.valid[field], results.valid.label, 10,
                                 kmeans)

    print('\t{}: train_acc={:.04f}, valid_acc={:.04f}'.format(field, train_acc,
                                                              valid_acc))

  checkpoint_folder = osp.dirname(FLAGS.snapshot)
  figure_filename = osp.join(checkpoint_folder, FLAGS.tsne_figure_name)
  print('Saving t-SNE plot at "{}"'.format(figure_filename))
  make_tsne_plot(valid_results.posterior_pres, valid_results.label,
                 figure_filename)
Example #4
def main(_=None):
    FLAGS = flags.FLAGS  # pylint: disable=invalid-name,redefined-outer-name
    config = FLAGS
    FLAGS.__dict__['config'] = config

    FLAGS.logdir = FLAGS.logdir.format(name=FLAGS.name)

    logdir = FLAGS.logdir
    logging.info('logdir: %s', logdir)

    if os.path.exists(logdir) and FLAGS.overwrite:
        logging.info('"overwrite" is set to True. Deleting logdir at "%s".',
                     logdir)
        shutil.rmtree(logdir)

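    # Buffers for the wall-clock time and train/valid accuracy curves, one slot
    # per reporting step (up to 600 reports).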
    fig_time = np.zeros([600])
    fig_train = np.zeros([600])
    fig_valid = np.zeros([600])
    # Build the graph
    with tf.Graph().as_default():

        model_dict = model_config.get(FLAGS)
        data_dict = data_config.get(FLAGS)

        lr = model_dict.lr
        opt = model_dict.opt
        model = model_dict.model
        trainset = data_dict.trainset
        validset = data_dict.validset

        lr = tf.convert_to_tensor(lr)
        tf.summary.scalar('learning_rate', lr)

        # Training setup
        global_step = tf.train.get_or_create_global_step()

        # Optimisation target
        validset = tools.maybe_convert_dataset(validset)
        trainset = tools.maybe_convert_dataset(trainset)
        target, gvs = model.make_target(trainset, opt)

        if gvs is None:
            print('make_target returned no gradients; '
                  'falling back to opt.compute_gradients.')
            gvs = opt.compute_gradients(target)

        suppress_inf_and_nans = (config.grad_value_clip > 0
                                 or config.grad_norm_clip > 0)
        report = tools.gradient_summaries(gvs, suppress_inf_and_nans)
        report['target'] = target
        valid_report = dict()

        gvs = tools.clip_gradients(gvs,
                                   value_clip=config.grad_value_clip,
                                   norm_clip=config.grad_norm_clip)

        try:
            report.update(model.make_report(trainset))
            valid_report.update(model.make_report(validset))
        except AttributeError:
            logging.warning('Model %s has no "make_report" method.',
                            str(model))
            raise

        plot_dict, plot_params = None, None
        if config.plot:
            try:
                plot_dict, plot_params = model.make_plot(trainset, 'train')
                valid_plot, valid_params = model.make_plot(validset, 'valid')

                plot_dict.update(valid_plot)
                if plot_params is not None:
                    plot_params.update(valid_params)

            except AttributeError:
                logging.warning('Model %s has no "make_plot" method.',
                                str(model))

        report = tools.scalar_logs(report,
                                   config.ema,
                                   'train',
                                   global_update=config.global_ema_update)
        report['lr'] = lr
        valid_report = tools.scalar_logs(
            valid_report,
            config.ema,
            'valid',
            global_update=config.global_ema_update)

        reports_keys = sorted(report.keys())

        def _format(k):
            if k in ('lr', 'learning_rate'):
                return '.2E'
            return '.3f'

        report_template = ', '.join([
            '{}: {}{}:{}{}'.format(k, '{', k, _format(k), '}')
            for k in reports_keys
        ])

        logging.info('Trainable variables:')
        tools.log_variables_by_scope()

        # inspect gradients
        for g, v in gvs:
            if g is None:
                logging.warning('No gradient for variable: %s.', v.name)

        tools.log_num_params()

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if FLAGS.check_numerics:
            update_ops += [tf.add_check_numerics_ops()]

        with tf.control_dependencies(update_ops):
            train_step = opt.apply_gradients(gvs, global_step=global_step)

        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True

        #run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        #run_metadata = tf.RunMetadata()

        with tf.train.SingularMonitoredSession(hooks=create_hooks(
                FLAGS, plot_dict, plot_params),
                                               checkpoint_dir=logdir,
                                               config=sess_config) as sess:

            train_itr, _ = sess.run([global_step, update_ops])
            train_tensors = [global_step, train_step]
            report_tensors = [report, valid_report]
            all_tensors = report_tensors + train_tensors

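            # Reference point for the wall-clock timing curve.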
            start_time = time.time()
            while train_itr < config.max_train_steps:

                if train_itr % config.report_loss_steps == 0:

                    # Index of the current reporting interval.
                    i = train_itr // config.report_loss_steps
                    report_vals, valid_report_vals, train_itr, _ = sess.run(
                        all_tensors)
                    logging.info('')
                    logging.info('train:')
                    logging.info('#%s: %s', train_itr,
                                 report_template.format(**report_vals))

                    end_time = time.time()
                    fig_time[i] = end_time - start_time

                    logging.info('valid:')
                    valid_logs = dict(report_vals)
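                    # valid_logs holds the train report at this point; record
                    # the train accuracy before merging in the valid report.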
                    fig_train[i] = valid_logs['best_cls_acc']
                    valid_logs.update(valid_report_vals)
                    logging.info('#%s: %s', train_itr,
                                 report_template.format(**valid_logs))
                    fig_valid[i] = valid_logs['best_cls_acc']

                    vals_to_check = list(report_vals.values())

                    if (np.isnan(vals_to_check).any()
                            or np.isinf(vals_to_check).any()):
                        logging.fatal('NaN in reports: %s; breaking...',
                                      report_template.format(**report_vals))

                else:
                    # Disabled profiling branch: trace a single step with the
                    # run_options/run_metadata objects commented out above and
                    # dump a Chrome timeline.
                    # if train_itr == 10030:
                    #     train_itr, _ = sess.run(train_tensors, options=run_options,
                    #                             run_metadata=run_metadata)
                    #     tl = timeline.Timeline(run_metadata.step_stats)
                    #     ctf = tl.generate_chrome_trace_format()
                    #     tl_str = 'trainTL_' + config.dataset + '_' + str(config.batch_size)
                    #     with open(tl_str, 'w') as f:
                    #         f.write(ctf)
                    #     break
                    # else:
                    train_itr, _ = sess.run(train_tensors)
    fig_time_str = config.dataset + '_' + str(config.batch_size) + '_time.npy'
    fig_train_str = config.dataset + '_' + str(
        config.batch_size) + '_train.npy'
    fig_valid_str = config.dataset + '_' + str(
        config.batch_size) + '_valid.npy'

    np.save(fig_time_str, fig_time)
    np.save(fig_train_str, fig_train)
    np.save(fig_valid_str, fig_valid)
    print('Plotting train/valid accuracy against wall-clock time.')
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    lns1 = ax1.plot(fig_time, fig_train, label="train_accuracy")
    lns2 = ax2.plot(fig_time, fig_valid, 'r', label="valid_accuracy")
    ax1.set_xlabel('time')
    ax1.set_ylabel('accuracy')
    lns = lns1 + lns2
    labels = ["train", "valid"]
    plt.legend(lns, labels, loc=7)
    plt.savefig('./time_accuracy.jpg')