Esempio n. 1
0
    def get_tfrecords(self):
        """Build the input tensors for the dataset named by ``self.dataset``.

        Returns:
            ((xtrain_l, ytrain_l), xtrain, (xtest, ytest)) where ``xtrain_l`` /
            ``ytrain_l`` are the labeled (partial) training records, ``xtrain``
            holds all training records (without labels), and ``xtest`` /
            ``ytest`` are the test records. All loaders use
            ``self.batch_size``.

        Raises:
            ValueError: if ``self.dataset`` is not 'CIFAR10', 'SVHN' or 'MNIST'
                (previously this surfaced as an UnboundLocalError at return).
        """
        if self.dataset in ('CIFAR10', 'SVHN'):
            # Both loader modules expose the same inputs/unlabeled_inputs API.
            if self.dataset == 'CIFAR10':
                from cifar10 import inputs, unlabeled_inputs
            else:
                from svhn import inputs, unlabeled_inputs
            xtrain_l, ytrain_l = inputs(batch_size=self.batch_size,
                                        train=True,
                                        validation=False,
                                        shuffle=True)
            xtrain = unlabeled_inputs(batch_size=self.batch_size,
                                      validation=False,
                                      shuffle=True)
            xtest, ytest = inputs(batch_size=self.batch_size,
                                  train=False,
                                  validation=False,
                                  shuffle=True)
        elif self.dataset == 'MNIST':
            from mnist import inputs
            xtrain, _ = inputs(self.batch_size, 'train')  # labels discarded
            xtrain_l, ytrain_l = inputs(self.batch_size, 'train_labeled')
            xtest, ytest = inputs(self.batch_size, 'test')
        else:
            raise ValueError('unsupported dataset: %r' % (self.dataset,))

        return (xtrain_l, ytrain_l), xtrain, (xtest, ytest)
Esempio n. 2
0
    def get_tfrecords(self, idxes=None):
        """Build the labeled, unlabeled and test input tensors.

        For ``c.FLAGS.dataset`` in {'CIFAR10', 'SVHN', 'MNIST'} the
        dataset-specific loader modules are used. Any other dataset is read
        from TFRecord files through ``self.inputs``; that path also sets
        ``self.n_batches_train``, ``self.n_batches_test``, ``self.n_labeled``
        and ``self.n_train``.

        Args:
            idxes: indices of the TEST folds in k-fold cross validation,
                i.e. a list of i with 0 <= i < k+1. If None, training and test
                data come from different sources and ``self.path_root_tests``
                must be set.

        Returns:
            ((xtrain_l, ytrain_l), xtrain, (xtest, ytest)) where ``xtrain``
            holds all training records and the ``*_l`` tensors hold the
            labeled (partial) records.

        Raises:
            ValueError: if ``idxes`` is None but ``self.path_root_tests`` is not
                given, or if both are given (ambiguous test data source).
        """
        if c.FLAGS.dataset in ('CIFAR10', 'SVHN'):
            # Both loader modules expose the same inputs/unlabeled_inputs API.
            if c.FLAGS.dataset == 'CIFAR10':
                from cifar10 import inputs, unlabeled_inputs
            else:
                from svhn import inputs, unlabeled_inputs
            xtrain_l, ytrain_l = inputs(batch_size=c.BATCH_SIZE, train=True, validation=False, shuffle=True)
            xtrain = unlabeled_inputs(batch_size=c.BATCH_SIZE_UL, validation=False, shuffle=True)
            xtest, ytest = inputs(batch_size=c.BATCH_SIZE_TEST, train=False, validation=False, shuffle=True)
        elif c.FLAGS.dataset == 'MNIST':
            from mnist import inputs
            xtrain_l, ytrain_l = inputs(c.BATCH_SIZE, 'train_labeled')
            xtrain, _ = inputs(c.BATCH_SIZE_UL, 'train')
            xtest, ytest = inputs(c.BATCH_SIZE_TEST, 'test')
        else:
            indices_train, indices_test = _split_k_into_train_and_test(self._k, idxes)
            paths_train = self.paths_tfrecord_train

            if idxes is None:
                # Training and test data come from different sources.
                if self.path_root_tests is None:
                    raise ValueError('path_root_tests was not given.')
                # Must be bound as ``paths_test``: it is read when building the
                # test inputs below.
                # NOTE(review): attribute is ``path_tfrecord_test`` (singular),
                # unlike ``paths_tfrecord_train`` -- confirm it exists.
                paths_test = self.path_tfrecord_test
            else:
                # Cross validation over a single data source.
                if self.path_root_tests is not None:
                    raise ValueError('selection of test data source is confusing since both path_root_tests and idxes are given.')
                paths_test = self.paths_tfrecord_train

            print('... reading TFRecords for train')
            (xtrain_l, ytrain_l), n_train_l = self.inputs(paths_train, indices_train, batch_size=c.BATCH_SIZE)

            print('... reading TFRecords for test')
            (xtest, ytest), n_test = self.inputs(paths_test, indices_test, batch_size=c.BATCH_SIZE_TEST)

            if IS_UNLABELED_ENABLE:
                print('... reading TFRecords of unlabeled')
                # NOTE(review): ``indices_all`` is not defined in this method;
                # unless it is a module-level name this raises NameError.
                (xtrain, _), n_train_u = self.inputs(paths_train, indices_all, batch_size=c.BATCH_SIZE_UL)
            else:
                xtrain = xtrain_l
                n_train_u = 0

            # CAUTION: in the SSL setting x_l and x_ul are fed to the model in
            # parallel and independently. n_batches_train below is used as the
            # number of training iterations per epoch, but it is no longer
            # valid when c.BATCH_SIZE and c.BATCH_SIZE_UL differ; the number of
            # iterations per epoch should instead be given as an argument, as
            # in the original VAT code.
            self.n_batches_train = int((n_train_l + n_train_u) / c.BATCH_SIZE)
            self.n_batches_test = int(n_test / c.BATCH_SIZE_TEST)
            print(' n_batches_train: %d, # n_batches_test : %d' % (self.n_batches_train, self.n_batches_test))
            # TODO: the way n_labeled is set is inconsistent across datasets.
            self.n_labeled = n_train_l
            self.n_train = n_train_l + n_train_u

        return (xtrain_l, ytrain_l), xtrain, (xtest, ytest)
Esempio n. 3
0
def main(_):
    """Train the CNN semi-supervised and evaluate it periodically.

    Builds labeled, unlabeled and evaluation input pipelines on the CPU, one
    training graph and two evaluation graphs that share variables under the
    "CNN" scope, then trains for ``FLAGS.num_epochs`` epochs under a
    ``tf.train.Supervisor``. From ``FLAGS.epoch_decay_start`` on, the learning
    rate decays linearly with the remaining epochs and momentum switches from
    ``FLAGS.mom1`` to ``FLAGS.mom2``. Every ``FLAGS.eval_freq`` epochs (and
    after the last epoch) evaluation losses are averaged, printed, and written
    as summaries when ``FLAGS.log_dir`` is set.

    Args:
        _: unused (presumably argv from ``tf.app.run``).
    """
    print(FLAGS.epsilon, FLAGS.top_bn)
    numpy.random.seed(seed=FLAGS.seed)
    tf.set_random_seed(numpy.random.randint(1234))
    with tf.Graph().as_default() as g:
        with tf.device("/cpu:0"):
            images, labels = inputs(batch_size=FLAGS.batch_size,
                                    train=True,
                                    validation=FLAGS.validation,
                                    shuffle=True)
            ul_images = unlabeled_inputs(batch_size=FLAGS.ul_batch_size,
                                         validation=FLAGS.validation,
                                         shuffle=True)

            images_eval_train, labels_eval_train = inputs(
                batch_size=FLAGS.eval_batch_size,
                train=True,
                validation=FLAGS.validation,
                shuffle=True)
            ul_images_eval_train = unlabeled_inputs(
                batch_size=FLAGS.eval_batch_size,
                validation=FLAGS.validation,
                shuffle=True)

            images_eval_test, labels_eval_test = inputs(
                batch_size=FLAGS.eval_batch_size,
                train=False,
                validation=FLAGS.validation,
                shuffle=True)

        with tf.device(FLAGS.device):
            lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")
            mom = tf.placeholder(tf.float32, shape=[], name="momentum")
            with tf.variable_scope("CNN") as scope:
                # Build training graph
                loss, train_op, global_step = build_training_graph(
                    images, labels, ul_images, lr, mom)
                # Eval graphs reuse the training variables.
                scope.reuse_variables()
                losses_eval_train = build_eval_graph(images_eval_train,
                                                     labels_eval_train,
                                                     ul_images_eval_train)
                # The test images double as the "unlabeled" input here.
                losses_eval_test = build_eval_graph(images_eval_test,
                                                    labels_eval_test,
                                                    images_eval_test)

            init_op = tf.global_variables_initializer()

        if not FLAGS.log_dir:
            logdir = None
            writer_train = None
            writer_test = None
        else:
            logdir = FLAGS.log_dir
            writer_train = tf.summary.FileWriter(FLAGS.log_dir + "/train", g)
            writer_test = tf.summary.FileWriter(FLAGS.log_dir + "/test", g)

        saver = tf.train.Saver(tf.global_variables())
        sv = tf.train.Supervisor(is_chief=True,
                                 logdir=logdir,
                                 init_op=init_op,
                                 init_feed_dict={
                                     lr: FLAGS.learning_rate,
                                     mom: FLAGS.mom1
                                 },
                                 saver=saver,
                                 global_step=global_step,
                                 summary_op=None,
                                 summary_writer=None,
                                 save_model_secs=150,
                                 recovery_wait_secs=0)

        def _run_eval(sess, losses, writer, prefix):
            """Average every tensor in ``losses`` over the eval set; print/log it."""
            # Fix the key order once so each fetched value stays paired with
            # its own name (no reliance on two dicts iterating alike).
            keys = list(losses.keys())
            tensors = [losses[key] for key in keys]
            # Floor division keeps the count an int (range() needs it on
            # Python 3); run at least one batch to avoid dividing by zero.
            n_iter = max(1, NUM_EVAL_EXAMPLES // FLAGS.eval_batch_size)
            totals = [0] * len(keys)
            for _ in range(n_iter):
                values = sess.run(tensors)
                totals = [t + v for t, v in zip(totals, values)]
            summary = tf.Summary()
            current_global_step = sess.run(global_step)
            for key, total in zip(keys, totals):
                print(prefix + key, total / n_iter)
                summary.value.add(tag=key, simple_value=total / n_iter)
            if writer is not None:
                writer.add_summary(summary, current_global_step)

        print("Training...")
        with sv.managed_session() as sess:
            for ep in range(FLAGS.num_epochs):
                if sv.should_stop():
                    break

                if ep < FLAGS.epoch_decay_start:
                    feed_dict = {lr: FLAGS.learning_rate, mom: FLAGS.mom1}
                else:
                    decayed_lr = (
                        (FLAGS.num_epochs - ep) /
                        float(FLAGS.num_epochs -
                              FLAGS.epoch_decay_start)) * FLAGS.learning_rate
                    feed_dict = {lr: decayed_lr, mom: FLAGS.mom2}

                sum_loss = 0
                start = time.time()
                for i in range(FLAGS.num_iter_per_epoch):
                    _, batch_loss, _ = sess.run([train_op, loss, global_step],
                                                feed_dict=feed_dict)
                    sum_loss += batch_loss
                end = time.time()
                print("Epoch:", ep, "CE_loss_train:",
                      sum_loss / FLAGS.num_iter_per_epoch, "elapsed_time:",
                      end - start)

                if (ep + 1) % FLAGS.eval_freq == 0 or ep + 1 == FLAGS.num_epochs:
                    _run_eval(sess, losses_eval_train, writer_train, "train-")
                    _run_eval(sess, losses_eval_test, writer_test, "test-")

            saver.save(sess, sv.save_path, global_step=global_step)
        sv.stop()
Esempio n. 4
0
def main(_):
    """Train the classifier, or only draw adversarial images.

    Inputs come from ``d.get_data`` for CelebA (labeled images reused as the
    unlabeled input) or from ``inputs`` / ``unlabeled_inputs`` otherwise.
    For ``FLAGS.method == 'lvat'`` the encoder variables are restored from
    ``CKPT_AE``; the classifier and optimizer variables are always freshly
    initialized.

    If ``FLAGS.draw_adv_img`` is set, one batch of inputs, adversarial
    examples and reconstructions is drawn and the process exits. Otherwise
    the model is trained for ``FLAGS.num_epochs`` epochs with linear learning
    rate decay from ``FLAGS.epoch_decay_start``, evaluated via ``test`` from
    ``FLAGS.eval_start`` on, and checkpointed to ``FLAGS.log__dir`` every 10
    epochs and after the final epoch.

    Args:
        _: unused (presumably argv from ``tf.app.run``).
    """
    print(FLAGS.epsilon, FLAGS.top_bn)
    np.random.seed(seed=FLAGS.seed)
    tf.set_random_seed(np.random.randint(1234))
    with tf.Graph().as_default():
        with tf.device("/cpu:0"):

            if FLAGS.data_set == 'CelebA':
                (images, labels), (_, _), (_, _) = d.get_data(batch_size=FLAGS.batch_size, image_size=FLAGS.img_size)

                (images_eval_train, labels_eval_train), (_, _), (images_eval_test, labels_eval_test) = \
                    d.get_data(batch_size=FLAGS.eval_batch_size, image_size=FLAGS.img_size)

                # No separate unlabeled split: reuse the labeled images.
                ul_images = images
                ul_images_eval_train = images_eval_train
            else:
                images, labels = inputs(batch_size=FLAGS.batch_size,
                                        train=True,
                                        validation=FLAGS.validation,
                                        shuffle=True)
                ul_images = unlabeled_inputs(batch_size=FLAGS.ul_batch_size,
                                             validation=FLAGS.validation,
                                             shuffle=True)

                images_eval_train, labels_eval_train = inputs(batch_size=FLAGS.eval_batch_size,
                                                              train=True,
                                                              validation=FLAGS.validation,
                                                              shuffle=True)
                ul_images_eval_train = unlabeled_inputs(batch_size=FLAGS.eval_batch_size,
                                                        validation=FLAGS.validation,
                                                        shuffle=True)

                images_eval_test, labels_eval_test = inputs(batch_size=FLAGS.eval_batch_size,
                                                            train=False,
                                                            validation=FLAGS.validation,
                                                            shuffle=True)

        lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")
        mom = tf.placeholder(tf.float32, shape=[], name="momentum")

        loss, train_op, x_adv, x_reconst = build_training_graph(images, labels, ul_images, lr, mom)

        # Eval graphs are only needed when actually training.
        if not FLAGS.draw_adv_img:
            losses_eval_train, _ = build_eval_graph(images_eval_train, labels_eval_train, ul_images_eval_train)
            losses_eval_test, results = build_eval_graph(images_eval_test, labels_eval_test, images_eval_test)

        saver = tf.train.Saver()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.allocator_type = 'BFC'
        sess = tf.Session(config=config)

        if FLAGS.method == 'lvat':
            print('-------------------------------------------')
            print("... restore the variables from frozen model.")
            u.restore(sess, SCOPE_ENCODER, CKPT_AE)
            print('-------------------------------------------')

        # The classifier is always initialized from scratch; restoring it
        # from FLAGS.log__dir for draw_adv_img was disabled (dead branch removed).
        print("... init the variables for the classifier to be trained.")

        # GLOBAL_VARIABLES is the non-deprecated name of the same collection.
        classifier_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=SCOPE_CLASSIFIER)
        print('classifier_vars:', classifier_vars)
        op_init = tf.variables_initializer(classifier_vars)

        optimizer_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='scope_optimizer')
        print('optimizer_vars:', optimizer_vars)
        op_init_optimiser = tf.variables_initializer(optimizer_vars)
        sess.run([op_init, op_init_optimiser], feed_dict={lr: FLAGS.learning_rate, mom: FLAGS.mom1})

        tf.train.start_queue_runners(sess=sess)

        if FLAGS.draw_adv_img:

            print('... skip training')

            _x, _x_adv, _x_reconst = sess.run([ul_images, x_adv, x_reconst])
            _N = 7
            n_figures = math.floor(FLAGS.ul_batch_size // _N)
            print(n_figures)

            for i in range(n_figures):
                draw_x(_x, _x_reconst, _x_adv, n_x=_N, offset=i, show_reconst=(FLAGS.method == 'lvat'),
                       filename='ep_%s_%.2f_%d' % (FLAGS.method, FLAGS.epsilon, i))

            sys.exit('exit draw_adv_img')

        else:

            print('... start training')

            for ep in range(FLAGS.num_epochs):
                if ep < FLAGS.epoch_decay_start:
                    feed_dict = {lr: FLAGS.learning_rate, mom: FLAGS.mom1}
                else:
                    decayed_lr = ((FLAGS.num_epochs - ep) / float(
                        FLAGS.num_epochs - FLAGS.epoch_decay_start)) * FLAGS.learning_rate
                    feed_dict = {lr: decayed_lr, mom: FLAGS.mom2}

                sum_loss = 0
                start = time.time()
                for i in tqdm(range(FLAGS.num_iter_per_epoch), leave=False):

                    _, batch_loss = sess.run([train_op, loss], feed_dict=feed_dict)

                    sum_loss += batch_loss
                end = time.time()
                print("Epoch:", ep, "CE_loss_train:", sum_loss / FLAGS.num_iter_per_epoch, "elapsed_time:", end - start, flush=True)

                if (ep >= FLAGS.eval_start) and ((ep + 1) % FLAGS.eval_freq == 0 or ep + 1 == FLAGS.num_epochs):

                    test(sess, losses_eval_train, ep, "train-")
                    test(sess, losses_eval_test, ep, "test-")

                # Checkpoint every 10 epochs and always after the last epoch,
                # so the fully trained model is never lost.
                if ep % 10 == 0 or ep + 1 == FLAGS.num_epochs:
                    print("Model saved in file: %s" % saver.save(sess, FLAGS.log__dir + '/model.ckpt'))
    return
Esempio n. 5
0
def main(_):
    """Experimental training loop with a persistent per-example direction ``u``.

    Labeled and evaluation batches come from the queue-based ``inputs`` /
    ``unlabeled_inputs`` pipelines. The unlabeled training batch and its
    direction ``ul_u`` are instead fed through placeholders from NumPy arrays.
    ``train_images.npy`` is loaded, reshaped to (-1, 32, 32, 3) and truncated
    to ``FLAGS.batch_size`` examples, and each example gets a random
    L2-normalized direction. After every step the direction returned by the
    graph (``ul_u_updated``) is written back, so it persists across iterations.
    Evaluation feeds a fresh random direction batch for each evaluation pass.

    Args:
        _: unused (presumably argv from ``tf.app.run``).
    """
    print(FLAGS.epsilon, FLAGS.top_bn)
    np.random.seed(seed=FLAGS.seed)
    tf.set_random_seed(np.random.randint(1234))
    with tf.Graph().as_default() as g:
        with tf.device("/cpu:0"):
            images, labels = inputs(batch_size=FLAGS.batch_size,
                                    train=True,
                                    validation=FLAGS.validation,
                                    shuffle=True)
            # Unlabeled training images are fed from NumPy (see below)
            # instead of the unlabeled_inputs(...) queue.
            ul_images = tf.placeholder(shape=images.shape, dtype=tf.float32)

            images_eval_train, labels_eval_train = inputs(batch_size=FLAGS.eval_batch_size,
                                                          train=True,
                                                          validation=FLAGS.validation,
                                                          shuffle=True)
            ul_images_eval_train = unlabeled_inputs(batch_size=FLAGS.eval_batch_size,
                                                    validation=FLAGS.validation,
                                                    shuffle=True)

            images_eval_test, labels_eval_test = inputs(batch_size=FLAGS.eval_batch_size,
                                                        train=False,
                                                        validation=FLAGS.validation,
                                                        shuffle=True)

            def placeholder_like(x, name=None):
                """Float32 placeholder with the same static shape as ``x``."""
                return tf.placeholder(shape=x.shape, dtype=tf.float32, name=name)

            def random_sphere_numpy(shape):
                """Gaussian sample of ``shape``, L2-normalized per leading-axis entry."""
                n = np.random.normal(size=shape)
                proj_shape = tuple([n.shape[0]] + [1 for _ in range(len(shape) - 1)])
                return n / np.linalg.norm(n.reshape((n.shape[0], -1)), axis=1).reshape(proj_shape)

            print(ul_images.shape)
            # Directions are fed from NumPy so they can persist between steps.
            ul_u = placeholder_like(ul_images, "ul_u")
            ul_u_eval_train = placeholder_like(ul_images_eval_train, "ul_u_eval_train")
            ul_u_eval_test = placeholder_like(images_eval_test, "ul_u_eval_test")

        with tf.device(FLAGS.device):
            lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")
            mom = tf.placeholder(tf.float32, shape=[], name="momentum")
            with tf.variable_scope("CNN") as scope:
                # Build training graph
                loss, train_op, global_step, ul_u_updated = build_training_graph(
                    images, labels, ul_images, ul_u, lr, mom)
                # Eval graphs reuse the training variables.
                scope.reuse_variables()
                losses_eval_train = build_eval_graph(images_eval_train, labels_eval_train, ul_images_eval_train, ul_u_eval_train)
                # The test graph uses images_eval_test as its unlabeled input,
                # so test evaluation does not depend on ul_images_eval_train.
                losses_eval_test = build_eval_graph(images_eval_test, labels_eval_test, images_eval_test, ul_u_eval_test)

            init_op = tf.global_variables_initializer()

        if not FLAGS.log_dir:
            logdir = None
            writer_train = None
            writer_test = None
        else:
            logdir = FLAGS.log_dir
            writer_train = tf.summary.FileWriter(FLAGS.log_dir + "/train", g)
            writer_test = tf.summary.FileWriter(FLAGS.log_dir + "/test", g)

        saver = tf.train.Saver(tf.global_variables())
        sv = tf.train.Supervisor(
            is_chief=True,
            logdir=logdir,
            init_op=init_op,
            init_feed_dict={lr: FLAGS.learning_rate, mom: FLAGS.mom1},
            saver=saver,
            global_step=global_step,
            summary_op=None,
            summary_writer=None,
            save_model_secs=150, recovery_wait_secs=0)

        ul_images_np = np.load("train_images.npy").reshape((-1, 32, 32, 3))
        print("TRUNCATING UL DATA")
        ul_images_np = ul_images_np[:FLAGS.batch_size]
        ul_u_np = random_sphere_numpy(ul_images_np.shape)
        print(ul_images_np.shape, ul_u_np.shape)

        def _run_eval(sess, losses, writer, prefix, feed_dict):
            """Average every tensor in ``losses`` over the eval set; print/log it."""
            # Fix the key order once so each fetched value stays paired with
            # its own name (no reliance on two dicts iterating alike).
            keys = list(losses.keys())
            tensors = [losses[key] for key in keys]
            # Floor division keeps the count an int (range() needs it on
            # Python 3); run at least one batch to avoid dividing by zero.
            n_iter = max(1, NUM_EVAL_EXAMPLES // FLAGS.eval_batch_size)
            totals = [0] * len(keys)
            for _ in range(n_iter):
                values = sess.run(tensors, feed_dict=feed_dict)
                totals = [t + v for t, v in zip(totals, values)]
            summary = tf.Summary()
            current_global_step = sess.run(global_step)
            for key, total in zip(keys, totals):
                print(prefix + key, total / n_iter)
                summary.value.add(tag=key, simple_value=total / n_iter)
            if writer is not None:
                writer.add_summary(summary, current_global_step)

        print("Training...")
        with sv.managed_session() as sess:
            for ep in range(FLAGS.num_epochs):
                if sv.should_stop():
                    break

                if ep < FLAGS.epoch_decay_start:
                    feed_dict = {lr: FLAGS.learning_rate, mom: FLAGS.mom1}
                else:
                    decayed_lr = ((FLAGS.num_epochs - ep) / float(
                        FLAGS.num_epochs - FLAGS.epoch_decay_start)) * FLAGS.learning_rate
                    feed_dict = {lr: decayed_lr, mom: FLAGS.mom2}

                sum_loss = 0
                start = time.time()
                for i in range(FLAGS.num_iter_per_epoch):
                    # The unlabeled data is truncated to batch_size examples,
                    # so every step uses all of them (random selection via
                    # np.random.choice(len(ul_images_np), size=FLAGS.batch_size,
                    # replace=False) is disabled).
                    picked = np.arange(FLAGS.batch_size)
                    feed_dict[ul_images] = ul_images_np[picked]
                    feed_dict[ul_u] = ul_u_np[picked]
                    ul_u_updated_np, _, batch_loss, _ = sess.run(
                        [ul_u_updated, train_op, loss, global_step],
                        feed_dict=feed_dict)
                    delta = ul_u_updated_np - ul_u_np[picked]
                    # Diagnostics: overall change of u this step, and the first
                    # components of the updated u for example 0.
                    print(np.linalg.norm(delta), ul_u_updated_np.reshape((FLAGS.batch_size, -1))[0, :3])
                    # Persist the updated directions for the next step.
                    ul_u_np[picked] = ul_u_updated_np
                    sum_loss += batch_loss
                end = time.time()
                print("Epoch:", ep, "CE_loss_train:", sum_loss / FLAGS.num_iter_per_epoch, "elapsed_time:", end - start)

                if (ep + 1) % FLAGS.eval_freq == 0 or ep + 1 == FLAGS.num_epochs:
                    # One random direction batch per eval pass, reused for every
                    # eval iteration of that pass.
                    _run_eval(sess, losses_eval_train, writer_train, "train-",
                              {ul_u_eval_train: random_sphere_numpy(ul_u_eval_train.shape.as_list())})
                    _run_eval(sess, losses_eval_test, writer_test, "test-",
                              {ul_u_eval_test: random_sphere_numpy(ul_u_eval_test.shape.as_list())})

            saver.save(sess, sv.save_path, global_step=global_step)
        sv.stop()