Code example #1
    def train(self, train_nimg, report_nimg):
        if FLAGS.eval_ckpt:
            self.eval_checkpoint(FLAGS.eval_ckpt)
            return
        batch = FLAGS.batch
        train_labeled = self.dataset.train_labeled.batch(batch).prefetch(16)
        train_labeled = train_labeled.make_one_shot_iterator().get_next()
        train_unlabeled = self.dataset.train_unlabeled.batch(batch).prefetch(
            16)
        train_unlabeled = train_unlabeled.make_one_shot_iterator().get_next()
        scaffold = tf.train.Scaffold(saver=tf.train.Saver(
            max_to_keep=FLAGS.keep_ckpt, pad_step_number=10))

        with tf.Session(config=utils.get_config()) as sess:
            self.session = sess
            self.cache_eval()

        with tf.train.MonitoredTrainingSession(
                scaffold=scaffold,
                checkpoint_dir=self.checkpoint_dir,
                config=utils.get_config(),
                save_checkpoint_steps=FLAGS.save_kimg << 10,
                save_summaries_steps=report_nimg - batch) as train_session:
            self.session = train_session._tf_sess()
            self.tmp.step = self.session.run(self.step)
            while self.tmp.step < train_nimg:
                loop = trange(self.tmp.step % report_nimg,
                              report_nimg,
                              batch,
                              leave=False,
                              unit='img',
                              unit_scale=batch,
                              desc='Epoch %d/%d' %
                              (1 + (self.tmp.step // report_nimg),
                               train_nimg // report_nimg))
                for _ in loop:
                    self.train_step(train_session, train_labeled,
                                    train_unlabeled)
                    while self.tmp.print_queue:
                        loop.write(self.tmp.print_queue.pop(0))
            while self.tmp.print_queue:
                print(self.tmp.print_queue.pop(0))
Code example #2
 def eval_mode(self, ckpt=None):
     self.session = tf.Session(config=utils.get_config())
     saver = tf.train.Saver()
     if ckpt is None:
         ckpt = utils.find_latest_checkpoint(self.checkpoint_dir)
     else:
         ckpt = os.path.abspath(ckpt)
     saver.restore(self.session, ckpt)
     self.tmp.step = self.session.run(self.step)
     print('Eval model %s at global_step %d' % (self.__class__.__name__, self.tmp.step))
     return self
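
A minimal usage sketch for eval_mode, assuming the surrounding class exposes the same ops attributes used in code example #4 (ops.classify_op and ops.x); the `model` and `images` names are placeholders, not part of the original code.

# Hypothetical usage: restore the latest checkpoint and classify a batch of
# images. `model` and `images` are placeholders.
model.eval_mode()  # or model.eval_mode('path/to/checkpoint')
probs = model.session.run(model.ops.classify_op,
                          feed_dict={model.ops.x: images})
predictions = probs.argmax(axis=1)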
Code example #3
 def get_size(self):
     """compute the size to of the dataset."""
     data = []
     with tf.Session(config=utils.get_config()) as session:
         it = self.parse().prefetch(16).make_one_shot_iterator().get_next()
         try:
             while 1:
                 data.append(session.run(it))
         except tf.errors.OutOfRangeError:
             pass
     return len(data)
Code example #4
File: train.py  Project: waltersharpWEI/mixmatch
    def eval_stats(self, batch=None, feed_extra=None, classify_op=None):
        def collect_samples(data):
            data_it = data.batch(1).prefetch(16).make_one_shot_iterator().get_next()
            images, labels = [], []
            while 1:
                try:
                    v = sess_data.run(data_it)
                except tf.errors.OutOfRangeError:
                    break
                images.append(v['image'])
                labels.append(v['label'])

            images = np.concatenate(images, axis=0)
            labels = np.concatenate(labels, axis=0)
            return images, labels

        if 'test' not in self.tmp.cache:
            with tf.Graph().as_default(), tf.Session(config=utils.get_config()) as sess_data:
                self.tmp.cache.test = collect_samples(self.dataset.test)
                self.tmp.cache.valid = collect_samples(self.dataset.valid)
                self.tmp.cache.train_labeled = collect_samples(self.dataset.eval_labeled)

        batch = batch or FLAGS.batch
        classify_op = self.ops.classify_op if classify_op is None else classify_op
        accuracies = []
        for subset in ('train_labeled', 'valid', 'test'):
            images, labels = self.tmp.cache[subset]

            predicted = np.concatenate([
                self.session.run(classify_op, feed_dict={
                    self.ops.x: images[x:x + batch], **(feed_extra or {})})
                for x in range(0, images.shape[0], batch)
            ], axis=0)
            accuracies.append((predicted.argmax(1) == labels).mean() * 100)
        self.train_print('kimg %-5d  accuracy train/valid/test  %.2f  %.2f  %.2f' %
                         tuple([self.tmp.step >> 10] + accuracies))
        
        time_budget = FLAGS.time_budget
        # NOTE: `start` is the training start time set in train() (see code
        # example #6 from this fork); it must be visible here, e.g. as a
        # module-level value, for these checks to run.
        elapsed = time.clock() - start
        if time_budget is not None:
            if elapsed >= time_budget:
                print(self.tmp.step)
                print(elapsed)
                print(accuracies[2])
                exit(0)
        target_accuracy = FLAGS.target_accuracy
        if target_accuracy is not None:
            if accuracies[2] >= target_accuracy * 100.0:
                print(self.tmp.step)
                print(elapsed)
                print(accuracies[2], target_accuracy)
                exit(0)
        return np.array(accuracies, 'f')
Code example #5
def main(argv):
    del argv
    utils.setup_tf()
    dataset = DATASETS[FLAGS.dataset]()
    with tf.Session(config=utils.get_config()) as sess:
        hashes = (collect_hashes(sess, 'labeled', dataset.eval_labeled),
                  collect_hashes(sess, 'unlabeled', dataset.eval_unlabeled),
                  collect_hashes(sess, 'validation', dataset.valid),
                  collect_hashes(sess, 'test', dataset.test))
    print('Overlap matrix (should be an almost perfect diagonal matrix with counts).')
    groups = 'labeled unlabeled validation test'.split()
    fmt = '%-10s %10s %10s %10s %10s'
    print(fmt % tuple([''] + groups))
    for p, x in enumerate(hashes):
        overlaps = [len(x & y) for y in hashes]
        print(fmt % tuple([groups[p]] + overlaps))
Code example #6
File: train.py  Project: waltersharpWEI/mixmatch
    def train(self, train_nimg, report_nimg):
        start = time.clock()  # training start time, used by the time-budget check in eval_stats()
        if FLAGS.eval_ckpt:
            self.eval_checkpoint(FLAGS.eval_ckpt)
            return
        batch = FLAGS.batch
        # The target-loss stopping criterion is disabled in this fork (the
        # original check is commented out), so a placeholder is always printed.
        print("Target loss : 0.0")
        target_accuracy = FLAGS.target_accuracy
        if target_accuracy is not None:
            print("Target accuracy: %.2f" % target_accuracy)
        else:
            print("Target accuracy: 0.0")
        with self.graph.as_default():
            train_labeled = self.dataset.train_labeled.batch(batch).prefetch(16)
            train_labeled = train_labeled.make_one_shot_iterator().get_next()
            sigma = FLAGS.sigma
            print("Sigma: %d" % sigma)
            train_unlabeled = self.dataset.train_unlabeled.batch(batch).prefetch(sigma)
            train_unlabeled = train_unlabeled.make_one_shot_iterator().get_next()
            scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=FLAGS.keep_ckpt,
                                                              pad_step_number=10))

            with tf.train.MonitoredTrainingSession(
                    scaffold=scaffold,
                    checkpoint_dir=self.checkpoint_dir,
                    config=utils.get_config(),
                    save_checkpoint_steps=FLAGS.save_kimg << 10,
                    save_summaries_steps=report_nimg - batch) as train_session:
                self.session = train_session._tf_sess()
                self.tmp.step = self.session.run(self.step)
                while self.tmp.step < train_nimg:
                    loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
                                  leave=False, unit='img', unit_scale=batch,
                                  desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg), train_nimg // report_nimg))
                    for _ in loop:
                        self.train_step(train_session, train_labeled, train_unlabeled)
                        while self.tmp.print_queue:
                            loop.write(self.tmp.print_queue.pop(0))
                while self.tmp.print_queue:
                    print(self.tmp.print_queue.pop(0))
Code example #7
def compute_mean_std(data: tf.data.Dataset):
    data = data.map(lambda x: x['image']).batch(1024).prefetch(1)
    data = data.make_one_shot_iterator().get_next()
    count = 0
    stats = []
    with tf.Session(config=utils.get_config()) as sess:
        def iterator():
            while True:
                try:
                    yield sess.run(data)
                except tf.errors.OutOfRangeError:
                    break

        for batch in tqdm(iterator(), unit='kimg', desc='Computing dataset mean and std'):
            ratio = batch.shape[0] / 1024.
            count += ratio
            stats.append((batch.mean((0, 1, 2)) * ratio, (batch ** 2).mean((0, 1, 2)) * ratio))
    mean = sum(x[0] for x in stats) / count
    sigma = sum(x[1] for x in stats) / count - mean ** 2
    std = np.sqrt(sigma)
    print('Mean %s  Std: %s' % (mean, std))
    return mean, std
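
A short usage sketch for compute_mean_std; the `raw_dataset` name is illustrative, and the normalization map is an assumption about how the returned statistics would typically be applied, not part of the original code.

# Hypothetical usage: `raw_dataset` is assumed to be a tf.data.Dataset whose
# elements are dicts with a float32 'image' key, as compute_mean_std expects.
mean, std = compute_mean_std(raw_dataset)
# Standardize images with the computed per-channel statistics (illustrative).
normalized = raw_dataset.map(
    lambda x: dict(x, image=(x['image'] - mean.astype('f')) / std.astype('f')))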
Code example #8
def memoize(dataset: tf.data.Dataset) -> tf.data.Dataset:
    data = []
    with tf.Graph().as_default(), tf.Session(config=utils.get_config()) as session:
        dataset = dataset.prefetch(16)
        it = dataset.make_one_shot_iterator().get_next()
        try:
            while 1:
                data.append(session.run(it))
        except tf.errors.OutOfRangeError:
            pass
    images = np.stack([x['image'] for x in data])
    labels = np.stack([x['label'] for x in data])

    def tf_get(index):
        def get(index):
            return images[index], labels[index]

        image, label = tf.py_func(get, [index], [tf.float32, tf.int64])
        return dict(image=image, label=label)

    dataset = tf.data.Dataset.range(len(data)).repeat()
    dataset = dataset.shuffle(len(data) if len(data) < FLAGS.shuffle else FLAGS.shuffle)
    return dataset.map(tf_get)
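
memoize above materializes a small dataset in memory and returns an endlessly repeating, shuffled view of it; the hedged sketch below shows one plausible way to consume it (the `small_split` name and batch size are illustrative).

# Hypothetical usage: `small_split` stands for any small tf.data.Dataset of
# {'image', 'label'} dicts that fits in memory.
cached = memoize(small_split)
batch_op = cached.batch(64).prefetch(16).make_one_shot_iterator().get_next()
# Each session.run(batch_op) then yields a freshly shuffled batch of
# image/label tensors.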
Code example #9
File: data.py  Project: muskanmahajan37/mma
    def __init__(self,
                 name,
                 graph,
                 train_filenames,
                 test_filenames,
                 parse_fn=record_parse,
                 augment=(lambda x: x, lambda x: x),
                 height=32,
                 width=32,
                 colors=3,
                 nclass=10,
                 mean=0,
                 std=1,
                 p_labeled=None,
                 p_unlabeled=None):
        self.name = name
        self.graph = graph
        self.session = tf.Session(config=utils.get_config(), graph=self.graph)
        self.images, self.labels = self.dataset_numpy(train_filenames,
                                                      parse_fn)
        self.ntrain = self.images.shape[0]
        with self.graph.as_default():
            self.test = default_parse(dataset(test_filenames), parse_fn)
        self.height = height
        self.width = width
        self.colors = colors
        self.nclass = nclass
        self.augment = augment

        self.all_indices = None  # all indices used here. None means using all data
        self.labeled_indices, self.unlabeled_indices = None, None
        self.no_label_indices = None

        self.mean = mean
        self.std = std
        self.p_labeled = p_labeled
        self.p_unlabeled = p_unlabeled
Code example #10
    def memoize(self):
        """Call before parsing, since it calls for parse inside."""
        data = []
        with tf.Session(config=utils.get_config()) as session:
            it = self.parse().prefetch(16).make_one_shot_iterator().get_next()
            try:
                while 1:
                    data.append(session.run(it))
            except tf.errors.OutOfRangeError:
                pass
        images = np.stack([x['image'] for x in data])
        labels = np.stack([x['label'] for x in data])

        def tf_get(index, image_shape):
            def get(index):
                return images[index], labels[index]

            image, label = tf.py_func(get, [index], [tf.float32, tf.int64])
            return dict(image=tf.reshape(image, image_shape), label=label, index=index)

        return self.__class__(tf.data.Dataset.range(len(data)),
                              parse_fn=tf_get,
                              augment_fn=self.augment_fn,
                              image_shape=self.image_shape)
Code example #11
File: train.py  Project: muskanmahajan37/mma
    def train_for_contGrow(self, train_nimg, past_nimg, report_nimg, grow_nimg,
                           grow_size, max_labeled_size):
        """Function for training the model.

    Args:
      train_nimg: will train for train_nimg/batch iterations
      past_nimg: has previously trained for train_nimg/batch iterations
      report_nimg: report results every report_nimg samples
      grow_nimg: grow every grow_nimg samples
      grow_size: number of samples to query each time
      max_labeled_size: maximum labelling budget
    """
        if max_labeled_size == -1:
            max_labeled_size = self.dataset.labeled_indices.size + self.dataset.unlabeled_indices.size

        if grow_nimg > 0:
            print('grow_kimg:', grow_nimg >> 10)
            print('grow_by: ', FLAGS.grow_by)
            print('grow_size:', grow_size)
        else:
            grow_nimg = train_nimg
            print('Will not grow.')
        print('----')

        if FLAGS.eval_ckpt:
            self.eval_checkpoint(FLAGS.eval_ckpt)
            return
        batch = FLAGS.batch
        scaffold = tf.train.Scaffold(saver=tf.train.Saver(
            max_to_keep=FLAGS.keep_ckpt, pad_step_number=10))
        with tf.train.MonitoredTrainingSession(
                scaffold=scaffold,
                checkpoint_dir=self.checkpoint_dir,
                config=utils.get_config(),
                save_checkpoint_steps=FLAGS.save_kimg << 10,
                save_summaries_steps=report_nimg - batch) as train_session:

            self.session = train_session._tf_sess()
            self.tmp.step = self.session.run(self.step)

            need_update = True
            while self.tmp.step < train_nimg:
                if grow_nimg > 0 and (self.tmp.step -
                                      past_nimg) % grow_nimg == 0:
                    # Grow
                    with self.dataset.graph.as_default():
                        labeled_indices = utils.fixlen_to_idx(
                            self.session.run(self.ops.label_index))
                        self.dataset.generate_labeled_and_unlabeled(
                            list(labeled_indices))
                        # Get unlabeled data
                        unlabeled_data = tf.data.Dataset.from_tensor_slices(self.dataset.unlabeled_indices) \
                            .map(self.dataset.tf_get) \
                            .map(self.dataset.augment[1]) \
                            .batch(batch) \
                            .prefetch(16) \
                            .make_one_shot_iterator() \
                            .get_next()  # not shuffled, not repeated
                    need_update |= self.grow_labeled(FLAGS.grow_by, grow_size,
                                                     max_labeled_size,
                                                     unlabeled_data)
                if need_update:
                    # If we need to update the labeled and unlabeled set to be used for training
                    need_update = False
                    labeled_indices = utils.fixlen_to_idx(
                        self.session.run(self.ops.label_index))
                    self.dataset.generate_labeled_and_unlabeled(
                        list(labeled_indices))
                    with self.dataset.graph.as_default():
                        train_labeled = tf.data.Dataset.from_tensor_slices(self.dataset.labeled_indices) \
                            .repeat() \
                            .shuffle(FLAGS.shuffle) \
                            .map(self.dataset.tf_get) \
                            .map(self.dataset.augment[0]) \
                            .batch(batch).prefetch(16)
                        train_labeled = train_labeled.make_one_shot_iterator().get_next()
                        train_unlabeled = tf.data.Dataset.from_tensor_slices(self.dataset.unlabeled_indices) \
                            .repeat() \
                            .shuffle(FLAGS.shuffle) \
                            .map(self.dataset.tf_get) \
                            .map(self.dataset.augment[1]) \
                            .batch(batch) \
                            .prefetch(16)
                        train_unlabeled = train_unlabeled.make_one_shot_iterator().get_next()
                    print('# of labeled/unlabeled samples to be used:',
                          self.dataset.labeled_indices.size,
                          self.dataset.unlabeled_indices.size)
                # The actual training
                loop = trange(self.tmp.step % report_nimg,
                              report_nimg,
                              batch,
                              leave=False,
                              unit='img',
                              unit_scale=batch,
                              desc='Epoch %d/%d' %
                              (1 + (self.tmp.step // report_nimg),
                               train_nimg // report_nimg))
                for _ in loop:
                    self.train_step(train_session, train_labeled,
                                    train_unlabeled)
                    while self.tmp.print_queue:
                        loop.write(self.tmp.print_queue.pop(0))
            while self.tmp.print_queue:
                print(self.tmp.print_queue.pop(0))
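
For orientation, a hedged sketch of how train_for_contGrow might be invoked, with argument names taken from the docstring above; the `model` instance and every value are placeholders rather than settings from the project.

# Hypothetical call; all values are illustrative placeholders.
model.train_for_contGrow(train_nimg=1 << 26,     # total images to train on
                         past_nimg=0,            # images already trained on
                         report_nimg=1 << 16,    # report every 64 kimg
                         grow_nimg=1 << 20,      # query new labels every 1024 kimg
                         grow_size=100,          # samples queried per growth step
                         max_labeled_size=-1)    # -1 = use the whole data pool as budget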