def main(args):
    """Evaluate the network on the test split for every saved epoch.

    Args:
        args: command-line arguments; ``args[1]`` is the dataset name.

    Raises:
        ValueError: if the test set is smaller than one batch.

    For each epoch ``e`` in ``range(cfg.epoch)`` the checkpoint
    ``cfg.logdir/model.ckpt-<num_batches_per_epoch_train * e>`` is restored,
    all full test batches are run, per-batch accuracy is written as a summary
    to ``cfg.test_logdir``, and the mean accuracy is printed.
    """
    assert len(args) == 2 and isinstance(args[1], str)
    dataset_name = args[1]
    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)

    with tf.Graph().as_default():
        # set_random_seed seeds the *current default* graph, so it must be
        # called inside this context to affect the ops built below.
        tf.set_random_seed(1234)

        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        num_batches_test = int(dataset_size_test / cfg.batch_size)
        if num_batches_test == 0:
            raise ValueError(
                'num_batches_test is 0: the test set (%d samples) is smaller '
                'than one batch (%d).' % (dataset_size_test, cfg.batch_size))

        batch_x, batch_labels = create_inputs()
        output = net.build_arch(batch_x,
                                coord_add,
                                is_train=False,
                                num_classes=num_classes)
        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()

        summary_op = tf.summary.merge(
            [tf.summary.scalar('accuracy', batch_acc)])

        step = 0
        with tf.Session() as sess:
            # Initialize local variables before starting the queue runners;
            # presumably required by the input pipeline's epoch limit
            # (epochs=cfg.epoch). It is a no-op when there are none.
            sess.run(tf.local_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # graph=None on purpose: writing sess.graph is huge.
            summary_writer = tf.summary.FileWriter(cfg.test_logdir, graph=None)
            try:
                for epoch in range(cfg.epoch):
                    # NOTE(review): assumes checkpoints are named exactly
                    # model.ckpt-<step>; names that embed the loss value
                    # (model-<loss>.ckpt-<step>) would not be found here.
                    ckpt = os.path.join(
                        cfg.logdir,
                        'model.ckpt-%d' % (num_batches_per_epoch_train * epoch))
                    saver.restore(sess, ckpt)

                    accuracy_sum = 0
                    for _ in range(num_batches_test):
                        batch_acc_v, summary_str = sess.run(
                            [batch_acc, summary_op])
                        print('%d batches are tested.' % step)
                        summary_writer.add_summary(summary_str, step)
                        accuracy_sum += batch_acc_v
                        step += 1

                    ave_acc = accuracy_sum / num_batches_test
                    print('the average accuracy is %f' % ave_acc)
            finally:
                summary_writer.close()
                # Queue runners may be blocked on a full queue; ask them to
                # stop before waiting for them, otherwise join can wait forever.
                coord.request_stop()
                coord.join(threads)
def main(args):
    """Run the network over the test input and dump predictions and labels.

    Builds the inference graph, restores a checkpoint via
    ``cfg.restore_model``, runs ``num_batches_test`` batches and writes every
    prediction and label (as Python list reprs) to ``predicts_and_labels.txt``
    in the current working directory.

    Args:
        args: not used in this body.

    NOTE(review): ``dataset_name``, ``recognize_data_dir``,
    ``recognize_labels_txt_keywords``, ``ckpt`` and ``num_batches_test`` are
    not defined in this function; presumably they are module-level globals.
    Confirm this, otherwise the call raises NameError.
    """
    # 1. Configure GPU mode (session config comes from cfg.set_gpu()).
    session_config = cfg.set_gpu()

    with tf.Graph().as_default():

        # 2. Set the random seed; look up the coordinate offsets, the class
        #    count and the label map.
        tf.set_random_seed(1234)
        coord_add = cfg.get_coord_add(dataset_name)
        num_classes = cfg.get_num_classes(dataset_name)
        labels_txt = cfg.search_keyword_files(recognize_data_dir, recognize_labels_txt_keywords)
        # NOTE(review): labels_maps is never used below.
        labels_maps = cfg.read_label_txt_to_dict(labels_txt[0])


        with tf.Session(config=session_config) as sess:

            create_inputs = cfg.get_create_inputs(dataset_name, is_train=False, epochs=cfg.epoch)
            batch_x, batch_labels = create_inputs()


            # 3. Build the network.
            output, pose_out = net.build_arch(batch_x, coord_add, is_train=False, num_classes=num_classes)
            tf.logging.debug(pose_out.get_shape())
            results, labels = net.batch_results_and_labels(output, batch_labels)

            # 4. Initialize variables and start the input threads. This must
            #    come after the network is built (presumably so the network's
            #    variables exist when the initializer runs).
            coord, threads = cfg.init_variables_and_start_thread(sess)

            # 5. Restore the model from `ckpt`.
            cfg.restore_model(sess, ckpt)

            # 6. Collect all predictions and labels into flat lists.
            np_predicts_list = []
            np_lables_list = []
            for i in range(num_batches_test):
                np_results,np_labels = sess.run(
                    [results, labels])
                print(np_results)
                print(np_labels)
                np_predicts_list.extend(np_results)
                np_lables_list.extend(np_labels)

            # The file stores str() reprs of the lists, one section each.
            np_predicts_list_str = str(np_predicts_list)
            np_lables_list_str = str(np_lables_list)
            with open('predicts_and_labels.txt','w') as f:
                f.write('predicts\r\n')
                f.write(np_predicts_list_str + '\r\n')
                f.write('labels\r\n')
                f.write(np_lables_list_str + '\r\n')

            cfg.stop_threads(coord,threads)
# Example No. 3
# 0
def main(args):
    """Evaluate the newest checkpoint under ``ckpt`` on half of the test set.

    Runs ``int(dataset_size_test / cfg.batch_size * 0.5)`` test batches through
    the network restored from ``tf.train.latest_checkpoint(ckpt)``, printing
    the running accuracy sum after each batch and the mean at the end.

    Args:
        args: not used.

    Raises:
        ValueError: if the evaluated part of the test set is smaller than one
            batch.
        IOError: if no checkpoint is found under ``ckpt``.

    NOTE(review): ``dataset_name`` and ``ckpt`` are not defined in this
    function; presumably they are module-level globals. Confirm this.
    """
    coord_add = get_coord_add(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)

    with tf.Graph().as_default():
        # set_random_seed seeds the *current default* graph, so it must be
        # called inside this context to affect the ops built below.
        tf.set_random_seed(1234)

        num_batches_test = int(dataset_size_test / cfg.batch_size * 0.5)
        if num_batches_test == 0:
            raise ValueError(
                'num_batches_test is 0: half of the test set (%d samples) is '
                'smaller than one batch (%d).' %
                (dataset_size_test, cfg.batch_size))
        batch_x, batch_labels = create_inputs()
        output, pose_out = net.build_arch(batch_x,
                                          coord_add,
                                          is_train=False,
                                          num_classes=num_classes)
        tf.logging.debug(pose_out.get_shape())

        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()
        # device_count={'GPU': 0} asks TF to create no GPU devices, so this
        # presumably runs on CPU and the gpu_options below have no effect.
        session_config = tf.ConfigProto(
            device_count={'GPU': 0},
            gpu_options={
                'allow_growth': 1,
                'visible_device_list': '0'
            },
            allow_soft_placement=True)
        with tf.Session(config=session_config) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                model_file = tf.train.latest_checkpoint(ckpt)
                # latest_checkpoint returns None when nothing is found; fail
                # with a clear message instead of deep inside restore().
                if model_file is None:
                    raise IOError('No checkpoint found in %s' % ckpt)
                saver.restore(sess, model_file)

                accuracy_sum = 0
                for _ in range(num_batches_test):
                    accuracy_sum += sess.run(batch_acc)
                    # Running total after each batch.
                    print(accuracy_sum)

                ave_acc = accuracy_sum / num_batches_test
                print('the average accuracy is %f' % ave_acc)
            finally:
                # Stop the queue runners before joining them so a thread
                # blocked on a full queue cannot hang the join.
                coord.request_stop()
                coord.join(threads)
# Example No. 4
# 0
def main(args):
    """Evaluate one saved epoch and plot original vs. reconstructed images.

    Args:
        args: command-line arguments; ``args[1]`` is the dataset name and
            ``args[2]`` the model name (sub-directory of ``cfg.logdir``).

    Restores the epoch-45 checkpoint from
    ``cfg.logdir/<model_name>/<dataset_name>/`` and runs ``num_batches_test``
    batches. Accuracy, loss and image summaries go to
    ``cfg.test_logdir/<model_name>/<dataset_name>/``. Every input and its
    reconstruction are plotted with ``plot_imgs``, and the mean accuracy is
    printed. If no matching checkpoint exists the epoch is skipped.
    """
    assert len(args) == 3 and isinstance(args[1], str) and isinstance(
        args[2], str)
    dataset_name = args[1]
    model_name = args[2]

    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)
    # Checkpoints are read from ckpt_dir; summaries are written to summary_dir.
    model_subdir = '/{}/{}/'.format(model_name, dataset_name)
    ckpt_dir = cfg.logdir + model_subdir
    summary_dir = cfg.test_logdir + model_subdir

    with tf.Graph().as_default():
        # set_random_seed seeds the *current default* graph, so it must be
        # called inside this context to affect the ops built below.
        tf.set_random_seed(1234)

        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        # Deliberately tiny for quick visual inspection; the full-size value
        # would be int(dataset_size_test / cfg.batch_size * 0.1).
        num_batches_test = 2

        batch_x, batch_labels = create_inputs()
        batch_squash = tf.divide(batch_x, 255.)
        batch_x_norm = slim.batch_norm(batch_x,
                                       center=False,
                                       is_training=False,
                                       trainable=False)
        output, pose_out = net.build_arch(batch_x_norm,
                                          coord_add,
                                          is_train=False,
                                          num_classes=num_classes)
        tf.logging.debug(pose_out.get_shape())

        batch_acc = net.test_accuracy(output, batch_labels)
        m_op = tf.constant(0.9)
        loss, spread_loss, mse, recon_img_squash = net.spread_loss(
            output, pose_out, batch_squash, batch_labels, m_op)
        tf.summary.scalar('spread_loss', spread_loss)
        tf.summary.scalar('reconstruction_loss', mse)
        tf.summary.scalar('all_loss', loss)
        data_size = int(batch_x.get_shape()[1])
        # Images are reshaped to square single-channel [B, d, d, 1]; the
        # reconstruction (from inputs divided by 255) is scaled back by 255.
        recon_img = tf.multiply(
            tf.reshape(recon_img_squash,
                       shape=[cfg.batch_size, data_size, data_size, 1]), 255.)
        orig_img = tf.reshape(batch_x,
                              shape=[cfg.batch_size, data_size, data_size, 1])
        tf.summary.image('orig_image', orig_img)
        tf.summary.image('recon_image', recon_img)
        saver = tf.train.Saver()

        step = 0

        tf.summary.scalar('accuracy', batch_acc)
        summary_op = tf.summary.merge_all()

        with tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            if not os.path.exists(summary_dir):
                os.makedirs(summary_dir)
            # Writing sess.graph makes the event file large.
            summary_writer = tf.summary.FileWriter(summary_dir,
                                                   graph=sess.graph)

            files = os.listdir(ckpt_dir)
            try:
                for epoch in range(45, 46):
                    # Checkpoint names embed the training loss
                    # (e.g. model-0.3764.ckpt-1718), so match on the
                    # '.ckpt-<step>.index' suffix instead of a fixed name.
                    ckpt_suffix = '.ckpt-%d.index' % (
                        num_batches_per_epoch_train * epoch)
                    ckpt = None
                    for fname in files:
                        if fname.endswith(ckpt_suffix):
                            # Strip '.index' to get the checkpoint prefix.
                            ckpt = os.path.join(ckpt_dir,
                                                fname[:-len('.index')])
                    if ckpt is None:
                        # Without this check a missing file raised NameError
                        # (or silently reused a previous epoch's weights).
                        print('no checkpoint for epoch %d, skipped.' % epoch)
                        continue
                    saver.restore(sess, ckpt)

                    accuracy_sum = 0
                    for i in range(num_batches_test):
                        batch_acc_v, summary_str, orig_image, recon_image = \
                            sess.run([batch_acc, summary_op, orig_img,
                                      recon_img])
                        print('%d batches are tested.' % step)
                        summary_writer.add_summary(summary_str, step)

                        accuracy_sum += batch_acc_v

                        step += 1
                        # Display original/reconstructed images in matplotlib.
                        plot_imgs(orig_image, i, 'ori')
                        plot_imgs(recon_image, i, 'rec')

                    ave_acc = accuracy_sum / num_batches_test
                    print('the average accuracy is %f' % ave_acc)
            finally:
                summary_writer.close()
                # Stop the queue runners before joining them so a thread
                # blocked on a full queue cannot hang the join.
                coord.request_stop()
                coord.join(threads)
# Example No. 5
# 0
def main(args):
    """Evaluate every saved epoch of a 'caps' or 'cnn_baseline' model.

    Args:
        args: command-line arguments; ``args[1]`` is the dataset name and
            ``args[2]`` the model name, ``"caps"`` or ``"cnn_baseline"``.

    Raises:
        ValueError: for an unknown model name, or if 10% of the test set is
            smaller than one batch.

    For each epoch in ``range(1, cfg.epoch)`` the matching checkpoint in
    ``cfg.logdir/<model_name>/<dataset_name>/`` is restored and 10% of the
    test batches are evaluated. Per-batch accuracy is logged to
    ``cfg.test_logdir/<model_name>/<dataset_name>/`` and the mean accuracy is
    printed. Epochs without a checkpoint are skipped.
    """
    assert len(args) == 3 and isinstance(args[1], str) and isinstance(args[2], str)
    dataset_name = args[1]
    model_name = args[2]
    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(
        dataset_name, is_train=False, epochs=cfg.epoch)
    # Checkpoints are read from ckpt_dir; summaries are written to summary_dir.
    model_subdir = '/{}/{}/'.format(model_name, dataset_name)
    ckpt_dir = cfg.logdir + model_subdir
    summary_dir = cfg.test_logdir + model_subdir

    with tf.Graph().as_default():
        # set_random_seed seeds the *current default* graph, so it must be
        # called inside this context to affect the ops built below.
        tf.set_random_seed(1234)

        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        # Only 10% of the test set is evaluated per checkpoint.
        num_batches_test = int(dataset_size_test / cfg.batch_size * 0.1)
        if num_batches_test == 0:
            raise ValueError(
                'num_batches_test is 0: 10% of the test set (%d samples) is '
                'smaller than one batch (%d).' %
                (dataset_size_test, cfg.batch_size))

        batch_x, batch_labels = create_inputs()
        batch_x = slim.batch_norm(batch_x, center=False, is_training=False, trainable=False)
        if model_name == "caps":
            output, _ = net.build_arch(batch_x, coord_add,
                                       is_train=False, num_classes=num_classes)
        elif model_name == "cnn_baseline":
            output = net.build_arch_baseline(batch_x,
                                             is_train=False, num_classes=num_classes)
        else:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError("Please select model from 'caps' or 'cnn_baseline' as the secondary argument of eval.py!")
        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()

        summary_op = tf.summary.merge(
            [tf.summary.scalar('accuracy', batch_acc)])

        step = 0
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            if not os.path.exists(summary_dir):
                os.makedirs(summary_dir)
            # Writing sess.graph makes the event file large.
            summary_writer = tf.summary.FileWriter(summary_dir, graph=sess.graph)

            files = os.listdir(ckpt_dir)
            try:
                for epoch in range(1, cfg.epoch):
                    # Checkpoint names embed the training loss
                    # (model-<loss>.ckpt-<step>), so match on the
                    # '.ckpt-<step>.index' suffix instead of a fixed name.
                    ckpt_suffix = '.ckpt-%d.index' % (
                        num_batches_per_epoch_train * epoch)
                    ckpt = None
                    for fname in files:
                        if fname.endswith(ckpt_suffix):
                            # Strip '.index' to get the checkpoint prefix.
                            ckpt = os.path.join(ckpt_dir, fname[:-len('.index')])
                    if ckpt is None:
                        # Without this check a missing file raised NameError
                        # on the first epoch or silently re-evaluated the
                        # previous epoch's weights.
                        print('no checkpoint for epoch %d, skipped.' % epoch)
                        continue
                    saver.restore(sess, ckpt)

                    accuracy_sum = 0
                    for _ in range(num_batches_test):
                        batch_acc_v, summary_str = sess.run([batch_acc, summary_op])
                        print('%d batches are tested.' % step)
                        summary_writer.add_summary(summary_str, step)
                        accuracy_sum += batch_acc_v
                        step += 1

                    ave_acc = accuracy_sum / num_batches_test
                    print('the average accuracy is %f' % ave_acc)
            finally:
                summary_writer.close()
                # Only part of the test queue is consumed, so the queue runners
                # may be blocked on a full queue; request a stop first or
                # join() can wait forever.
                coord.request_stop()
                coord.join(threads)
# Example No. 6
# 0
def main(args):
    """Train the capsule network and validate it after every epoch.

    Args:
        args: command-line arguments; ``args[1]`` is the dataset name.

    Raises:
        ValueError: if ``cfg.network`` or ``cfg.loss_fn`` is not recognised.

    Training and test batches come from two input queues, selected per
    ``sess.run`` by the ``use_train_data`` placeholder. A checkpoint named
    ``model-<loss>.ckpt-<step>`` is saved to ``cfg.logdir`` whenever ``step``
    is a multiple of the epoch length. Training summaries are written every
    10 steps. At the end of each epoch, loss and accuracy averaged over the
    test batches are printed. The spread-loss margin ``m`` is annealed
    linearly from 0.2 towards 0.9.
    """
    assert len(args) == 2 and isinstance(args[1], str)
    dataset_name = args[1]
    logger.info('Using dataset: {}'.format(dataset_name))
    coord_add = get_coord_add(dataset_name)
    num_classes = get_num_classes(dataset_name)

    dataset_size = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=True,
                                      epochs=cfg.epoch)
    test_inputs = get_create_inputs(dataset_name,
                                    is_train=False,
                                    epochs=cfg.epoch)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # set_random_seed seeds the *current default* graph, so it must be
        # called inside this context to affect the training graph.
        tf.set_random_seed(1234)

        # Global step, incremented by apply_gradients below.
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        # Batches per epoch (a trailing partial batch is not counted).
        num_batches_per_epoch = int(dataset_size / cfg.batch_size)
        num_batches_test = int(dataset_size_test / cfg.batch_size)
        # Training and validation summaries are merged separately.
        summaries = []
        valid_sum = []
        # Constant learning rate. An alternative that was tried:
        # max(exponential_decay(1e-2, global_step, num_batches_per_epoch, 0.8),
        # 1e-5).
        opt = tf.train.AdamOptimizer(learning_rate=0.001)
        # The boolean placeholder picks which queue feeds the graph.
        train_q = create_inputs()
        test_q = test_inputs()
        use_train_data = tf.placeholder(dtype=tf.bool, shape=())
        batch_x, batch_labels = tf.cond(use_train_data,
                                        true_fn=lambda: train_q,
                                        false_fn=lambda: test_q)
        # Spread-loss margin, fed on every run so it can be annealed.
        m_op = tf.placeholder(dtype=tf.float32, shape=())
        with tf.device('/gpu:0'):
            # Variables stay on the CPU; computation is placed on the GPU.
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                norm_batch_x = tf.contrib.layers.batch_norm(batch_x,
                                                            is_training=True)

                # Select network architecture.
                if cfg.network == 'conv':
                    import capsnet_em as net
                    output = net.build_arch(norm_batch_x,
                                            coord_add,
                                            is_train=True,
                                            num_classes=num_classes)
                elif cfg.network == 'fc':
                    import capsnet_fc as net
                    output = net.build_arch(norm_batch_x,
                                            is_train=True,
                                            num_classes=num_classes)
                else:
                    # '%s' is required: without it the % operator raised
                    # "not all arguments converted" instead of this error.
                    raise ValueError('Invalid network architecture: %s' %
                                     cfg.network)

                # Select loss function.
                if cfg.loss_fn == 'spread':
                    loss = net.spread_loss(output, batch_labels, m_op)
                elif cfg.loss_fn == 'margin':
                    loss = net.margin_loss(output, batch_labels)
                elif cfg.loss_fn == 'cross_en':
                    loss = net.cross_entropy_loss(output, batch_labels)
                else:
                    raise ValueError('Invalid loss function: %s' % cfg.loss_fn)

                acc = net.accuracy(output, batch_labels)
            # Compute gradients and guard them (and the loss) against NaN/Inf.
            grad = opt.compute_gradients(loss)
            # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating
            grad_check = [tf.check_numerics(g, message='Gradient NaN Found!')
                          for g, _ in grad] + \
                         [tf.check_numerics(loss, message='Loss NaN Found')]
        # The same loss/acc tensors are logged under different tags depending
        # on which queue fed them.
        summaries.append(tf.summary.scalar('loss', loss))
        summaries.append(tf.summary.scalar('acc', acc))
        valid_sum.append(tf.summary.scalar('val_loss', loss))
        valid_sum.append(tf.summary.scalar('val_acc', acc))
        # Apply gradients only after the NaN checks and the UPDATE_OPS
        # (e.g. batch-norm statistics) have run.
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)
        # Session settings.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.gpu_frac)
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False,
                                                gpu_options=gpu_options))
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        # Saver: skip Adam's optimizer variables (names containing 'Adam').
        var_to_save = [
            v for v in tf.global_variables() if 'Adam' not in v.name
        ]
        saver = tf.train.Saver(var_list=var_to_save, max_to_keep=cfg.epoch)
        # Report parameter counts.
        total_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in var_to_save
        ]).astype(np.int32)
        train_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]).astype(np.int32)
        logger.info('Total Parameters: {}'.format(total_p))
        logger.info('Trainable Parameters: {}'.format(train_p))

        # To resume from a snapshot:
        #   saver.restore(sess, os.path.join(cfg.logdir, 'model.ckpt-4680'))
        summary_op = tf.summary.merge(summaries)
        valid_sum_op = tf.summary.merge(valid_sum)
        # Start the input queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # graph=None on purpose: writing sess.graph is huge.
        summary_writer = tf.summary.FileWriter(cfg.logdir, graph=None)
        # Main loop.
        m_min = 0.2
        m_max = 0.9
        m = m_min
        for step in range(cfg.epoch * num_batches_per_epoch):
            if (step % num_batches_per_epoch) == 0:
                tic = time.time()
                progbar = tf.keras.utils.Progbar(
                    num_batches_per_epoch, verbose=(1 if cfg.progbar else 0))
            # The TF queue keeps popping batches until the input files run out.
            try:
                _, loss_value, acc_value = sess.run([train_op, loss, acc],
                                                    feed_dict={
                                                        use_train_data: True,
                                                        m_op: m
                                                    })
                progbar.update((step % num_batches_per_epoch),
                               values=[('loss', loss_value),
                                       ('acc', acc_value)])
            except KeyboardInterrupt:
                sess.close()
                sys.exit()
            except tf.errors.InvalidArgumentError:
                # Raised by check_numerics: skip this step's update.
                logger.warning(
                    '%d iteration contains NaN gradients. Discard.' % step)
                continue
            # Write training summaries.
            if step % 10 == 0:
                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           use_train_data: True,
                                           m_op: m
                                       })
                summary_writer.add_summary(summary_str, step)
            # Epoch-wise linear annealing of the margin, then checkpoint.
            if (step % num_batches_per_epoch) == 0:
                if step > 0:
                    m += (m_max - m_min) / (cfg.epoch * 0.6)
                    if m > m_max:
                        m = m_max
                ckpt_path = os.path.join(
                    cfg.logdir, 'model-{0:.4f}.ckpt'.format(loss_value))
                saver.save(sess, ckpt_path, global_step=step)

            # End of epoch: evaluate on the test queue.
            if ((step + 1) % num_batches_per_epoch) == 0:
                toc = time.time()
                val_loss_value, val_acc_value = (0.0, 0.0)
                for i in range(num_batches_test):
                    val_batch = sess.run([loss, acc],
                                         feed_dict={
                                             use_train_data: False,
                                             m_op: m
                                         })
                    val_loss_batch, val_acc_batch = val_batch
                    val_loss_value += val_loss_batch / num_batches_test
                    val_acc_value += val_acc_batch / num_batches_test
                # NOTE(review): this pulls one extra test batch, so the logged
                # val summaries reflect a single batch, not the averages above.
                valid_sum_str = sess.run(valid_sum_op,
                                         feed_dict={
                                             use_train_data: False,
                                             m_op: m
                                         })
                summary_writer.add_summary(valid_sum_str, step)
                print('\nEpoch %d/%d in ' %
                      (step // num_batches_per_epoch + 1, cfg.epoch) +
                      '%.1fs' % (toc - tic) + ' - loss: %f' % val_loss_value +
                      ' - acc: %f' % val_acc_value)
        # Stop the queue runners before joining them: a queue that has not
        # been drained leaves its threads blocked, and join() would wait
        # forever.
        coord.request_stop()
        coord.join(threads)
        summary_writer.close()
        sess.close()
# Example No. 7
# 0
def main(args):
    """Evaluate every saved epoch on 10% of the test set.

    Args:
        args: command-line arguments; ``args[1]`` is the dataset name.

    Raises:
        ValueError: if 10% of the test set is smaller than one batch.

    For each epoch in ``range(1, cfg.epoch)`` the checkpoint in ``cfg.logdir``
    whose name ends in ``.ckpt-<num_batches_per_epoch_train * epoch>`` is
    restored and evaluated. Per-batch accuracy is printed and logged to
    ``cfg.test_logdir``, and the mean accuracy is printed. Epochs without a
    checkpoint are skipped.
    """
    assert len(args) == 2 and isinstance(args[1], str)
    dataset_name = args[1]
    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)

    with tf.Graph().as_default():
        # set_random_seed seeds the *current default* graph, so it must be
        # called inside this context to affect the ops built below.
        tf.set_random_seed(1234)

        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        # Only 10% of the test set is evaluated per checkpoint.
        num_batches_test = int(dataset_size_test / cfg.batch_size * 0.1)
        if num_batches_test == 0:
            raise ValueError(
                'num_batches_test is 0: 10% of the test set (%d samples) is '
                'smaller than one batch (%d).' %
                (dataset_size_test, cfg.batch_size))

        batch_x, batch_labels = create_inputs()
        batch_x = slim.batch_norm(batch_x,
                                  center=False,
                                  is_training=False,
                                  trainable=False)
        output, _ = net.build_arch(batch_x,
                                   coord_add,
                                   is_train=False,
                                   num_classes=num_classes)
        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()

        summary_op = tf.summary.merge(
            [tf.summary.scalar('accuracy', batch_acc)])

        step = 0
        with tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # Writing sess.graph makes the event file large.
            summary_writer = tf.summary.FileWriter(
                cfg.test_logdir, graph=sess.graph)

            files = os.listdir(cfg.logdir)
            try:
                for epoch in range(1, cfg.epoch):
                    # Checkpoint names may embed the training loss
                    # (model-<loss>.ckpt-<step>), so match on the
                    # '.ckpt-<step>.index' suffix instead of a fixed name.
                    ckpt_suffix = '.ckpt-%d.index' % (
                        num_batches_per_epoch_train * epoch)
                    ckpt = None
                    for fname in files:
                        if fname.endswith(ckpt_suffix):
                            # Strip '.index' to get the checkpoint prefix.
                            ckpt = os.path.join(cfg.logdir,
                                                fname[:-len('.index')])
                    if ckpt is None:
                        # Without this check a missing file raised NameError
                        # on the first epoch or silently re-evaluated the
                        # previous epoch's weights.
                        print('no checkpoint for epoch %d, skipped.' % epoch)
                        continue
                    saver.restore(sess, ckpt)

                    accuracy_sum = 0
                    for _ in range(num_batches_test):
                        batch_acc_v, summary_str = sess.run(
                            [batch_acc, summary_op])
                        print('%d batches are tested.' % step)
                        summary_writer.add_summary(summary_str, step)
                        # %f, not %d: %d truncated a fractional accuracy.
                        print('batch accuracy: %f' % batch_acc_v)

                        accuracy_sum += batch_acc_v
                        step += 1

                    ave_acc = accuracy_sum / num_batches_test
                    print('the average accuracy is %f' % ave_acc)
            finally:
                summary_writer.close()
                # Only part of the test queue is consumed, so the queue runners
                # may be blocked on a full queue; request a stop first or
                # join() can wait forever.
                coord.request_stop()
                coord.join(threads)
# Example No. 8
# 0
def main(args):
    """Evaluate every saved epoch of a 'caps' or 'cnn_baseline' model.

    Args:
        args: command-line arguments; ``args[1]`` is the dataset name and
            ``args[2]`` the model name, ``"caps"`` or ``"cnn_baseline"``.

    Raises:
        ValueError: for an unknown model name, or if 10% of the test set is
            smaller than one batch.

    The session uses ``device_count={'GPU': 0}``. For each epoch in
    ``range(1, cfg.epoch)`` the matching checkpoint in
    ``cfg.logdir/<model_name>/<dataset_name>/`` is restored and 10% of the
    test batches are evaluated. Per-batch accuracy is logged to
    ``cfg.test_logdir/<model_name>/<dataset_name>/`` and the mean accuracy is
    printed. Epochs without a checkpoint are skipped.
    """
    assert len(args) == 3 and isinstance(args[1], str) and isinstance(
        args[2], str)
    dataset_name = args[1]
    model_name = args[2]
    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)
    # Checkpoints are read from ckpt_dir; summaries are written to summary_dir.
    model_subdir = '/{}/{}/'.format(model_name, dataset_name)
    ckpt_dir = cfg.logdir + model_subdir
    summary_dir = cfg.test_logdir + model_subdir

    with tf.Graph().as_default():
        # set_random_seed seeds the *current default* graph, so it must be
        # called inside this context to affect the ops built below.
        tf.set_random_seed(1234)

        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        # Only 10% of the test set is evaluated per checkpoint.
        num_batches_test = int(dataset_size_test / cfg.batch_size * 0.1)
        if num_batches_test == 0:
            raise ValueError(
                'num_batches_test is 0: 10% of the test set (%d samples) is '
                'smaller than one batch (%d).' %
                (dataset_size_test, cfg.batch_size))

        batch_x, batch_labels = create_inputs()
        batch_x = slim.batch_norm(batch_x,
                                  center=False,
                                  is_training=False,
                                  trainable=False)
        if model_name == "caps":
            output, _ = net.build_arch(batch_x,
                                       coord_add,
                                       is_train=False,
                                       num_classes=num_classes)
        elif model_name == "cnn_baseline":
            output = net.build_arch_baseline(batch_x,
                                             is_train=False,
                                             num_classes=num_classes)
        else:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError("Please select model from 'caps' or 'cnn_baseline' as the secondary argument of eval.py!")
        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()

        summary_op = tf.summary.merge(
            [tf.summary.scalar('accuracy', batch_acc)])

        # device_count={'GPU': 0} asks TF to create no GPU devices, so this
        # presumably runs on CPU and the gpu_options below have no effect.
        session_config = tf.ConfigProto(
            device_count={'GPU': 0},
            gpu_options={
                'allow_growth': 1,
                'visible_device_list': '0'
            },
            allow_soft_placement=True)

        step = 0
        with tf.Session(config=session_config) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            if not os.path.exists(summary_dir):
                os.makedirs(summary_dir)
            # Writing sess.graph makes the event file large.
            summary_writer = tf.summary.FileWriter(summary_dir,
                                                   graph=sess.graph)

            files = os.listdir(ckpt_dir)
            try:
                for epoch in range(1, cfg.epoch):
                    # Checkpoint names embed the training loss
                    # (model-<loss>.ckpt-<step>), so match on the
                    # '.ckpt-<step>.index' suffix instead of a fixed name.
                    ckpt_suffix = '.ckpt-%d.index' % (
                        num_batches_per_epoch_train * epoch)
                    ckpt = None
                    for fname in files:
                        if fname.endswith(ckpt_suffix):
                            # Strip '.index' to get the checkpoint prefix.
                            ckpt = os.path.join(ckpt_dir,
                                                fname[:-len('.index')])
                    if ckpt is None:
                        # Without this check a missing file raised NameError
                        # on the first epoch or silently re-evaluated the
                        # previous epoch's weights.
                        print('no checkpoint for epoch %d, skipped.' % epoch)
                        continue
                    saver.restore(sess, ckpt)

                    accuracy_sum = 0
                    for _ in range(num_batches_test):
                        batch_acc_v, summary_str = sess.run(
                            [batch_acc, summary_op])
                        print('%d batches are tested.' % step)
                        summary_writer.add_summary(summary_str, step)
                        accuracy_sum += batch_acc_v
                        step += 1

                    ave_acc = accuracy_sum / num_batches_test
                    print('the average accuracy is %f' % ave_acc)
            finally:
                summary_writer.close()
                # Only part of the test queue is consumed, so the queue runners
                # may be blocked on a full queue; request a stop first or
                # join() can wait forever.
                coord.request_stop()
                coord.join(threads)
def main(args):
    """Train the capsule network (spread loss only) on one dataset.

    Builds the input pipeline and the ``net.build_arch`` graph, then runs
    ``cfg.epoch * num_batches_per_epoch`` Adam steps on the spread loss.
    The spread-loss margin ``m`` starts at 0.2 and, at every epoch boundary
    after the first, is raised by ``(0.9 - 0.2) / (cfg.epoch * 0.6)``
    (capped at 0.9). At every epoch boundary (including step 0) a checkpoint
    ``model-<loss>.ckpt-<step>`` is written to ``cfg.logdir``; scalar
    summaries are written every 10 steps.

    Args:
        args: argv-style list; ``args[1]`` must be the dataset name (str).
    """
    assert len(args) == 2 and isinstance(args[1], str)
    dataset_name = args[1]
    logger.info('Using dataset: {}'.format(dataset_name))
    coord_add = get_coord_add(dataset_name)
    dataset_size = get_dataset_size_train(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=True,
                                      epochs=cfg.epoch)
    # Set reproducible random seed.
    tf.set_random_seed(1234)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        num_batches_per_epoch = int(dataset_size / cfg.batch_size)
        summaries = []
        # Exponentially decaying learning rate (x0.66 every 200 steps),
        # floored at 1e-5.
        lrn_rate = tf.maximum(
            tf.train.exponential_decay(1e-3, global_step, 2e2, 0.66), 1e-5)
        summaries.append(tf.summary.scalar('learning_rate', lrn_rate))
        opt = tf.train.AdamOptimizer(lrn_rate)
        # Get batch from data queue.
        batch_x, batch_labels = create_inputs()
        # Spread-loss margin; fed on every run so it can be annealed.
        m_op = tf.placeholder(dtype=tf.float32, shape=())
        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                output = net.build_arch(batch_x,
                                        coord_add,
                                        is_train=True,
                                        num_classes=num_classes)
                loss = net.spread_loss(output, batch_labels, m_op)
            grad = opt.compute_gradients(loss)
        summaries.append(tf.summary.scalar('spread_loss', loss))
        train_op = opt.apply_gradients(grad, global_step=global_step)

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        # Don't save redundant Adam slot variables.
        var_to_save = [
            v for v in tf.global_variables() if 'Adam' not in v.name
        ]
        saver = tf.train.Saver(var_list=var_to_save, max_to_keep=cfg.epoch)
        total_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in var_to_save
        ]).astype(np.int32)
        train_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]).astype(np.int32)
        logger.info('Total Parameters: {}'.format(total_p))
        logger.info('Trainable Parameters: {}'.format(train_p))

        summary_op = tf.summary.merge(summaries)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        summary_writer = tf.summary.FileWriter(
            cfg.logdir, graph=None)  # graph = sess.graph, huge!

        m_min = 0.2
        m_max = 0.9
        m = m_min
        try:
            for step in range(cfg.epoch * num_batches_per_epoch):
                tic = time.time()
                write_summary = step % 10 == 0
                fetches = [train_op, loss]
                if write_summary:
                    # Fetch the summary in the SAME run as the update. A
                    # separate sess.run(summary_op) dequeues an extra batch:
                    # it reports the loss of a different batch and drains the
                    # input pipeline (built with epochs=cfg.epoch) early.
                    fetches.append(summary_op)
                results = sess.run(fetches, feed_dict={m_op: m})
                loss_value = results[1]
                logger.info('%d iteration finishs in ' % step + '%f second' %
                            (time.time() - tic) + ' loss=%f' % loss_value)
                assert not np.isnan(loss_value), 'loss is nan'
                if write_summary:
                    summary_writer.add_summary(results[2], step)
                # Epoch-wise linear annealing of the margin, then checkpoint.
                if (step % num_batches_per_epoch) == 0:
                    if step > 0:
                        m = min(m + (m_max - m_min) / (cfg.epoch * 0.6),
                                m_max)
                    ckpt_path = os.path.join(
                        cfg.logdir,
                        'model-{}.ckpt'.format(round(loss_value, 4)))
                    saver.save(sess, ckpt_path, global_step=step)
        finally:
            # Flush summaries and stop the queue runners. Without
            # request_stop(), coord.join() waits for every runner thread to
            # finish on its own, which never happens if one is blocked.
            summary_writer.close()
            coord.request_stop()
            try:
                coord.join(threads)
            finally:
                sess.close()
Exemplo n.º 10
0
def main(args):
    """Evaluate a trained capsule network with reconstruction on test data.

    ``args[1]`` is the dataset name and ``args[2]`` the model name.
    Checkpoints are looked up in ``cfg.logdir + '/<model>/<dataset>/'`` and
    summaries are written to ``cfg.test_logdir + '/<model>/<dataset>/'``.
    For each selected epoch, the checkpoint whose name ends in
    ``.ckpt-<num_batches_per_epoch_train * epoch>`` is restored,
    ``num_batches_test`` batches are evaluated, and the original and
    reconstructed images of each batch are plotted with ``plot_imgs``.

    Args:
        args: argv-style list ``[prog, dataset_name, model_name]``.

    Raises:
        FileNotFoundError: if no checkpoint is found for a selected epoch.
    """
    assert len(args) == 3 and isinstance(args[1], str) and isinstance(args[2], str)
    dataset_name = args[1]
    model_name = args[2]

    # Set reproducible random seed.
    tf.set_random_seed(1234)

    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(
        dataset_name, is_train=False, epochs=cfg.epoch)

    sub_dir = '/{}/{}/'.format(model_name, dataset_name)
    ckpt_dir = cfg.logdir + sub_dir
    test_dir = cfg.test_logdir + sub_dir

    with tf.Graph().as_default():
        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        num_batches_test = 2  # int(dataset_size_test / cfg.batch_size * 0.1)

        batch_x, batch_labels = create_inputs()
        batch_squash = tf.divide(batch_x, 255.)
        batch_x_norm = slim.batch_norm(batch_x, center=False, is_training=False, trainable=False)
        output, pose_out = net.build_arch(batch_x_norm, coord_add,
                                          is_train=False, num_classes=num_classes)
        tf.logging.debug(pose_out.get_shape())

        batch_acc = net.test_accuracy(output, batch_labels)
        # Constant spread-loss margin for evaluation.
        m_op = tf.constant(0.9)
        loss, spread_loss, mse, recon_img_squash = net.spread_loss(
            output, pose_out, batch_squash, batch_labels, m_op)
        tf.summary.scalar('spread_loss', spread_loss)
        tf.summary.scalar('reconstruction_loss', mse)
        tf.summary.scalar('all_loss', loss)
        data_size = int(batch_x.get_shape()[1])
        # Scale the reconstruction back by 255 (inputs were divided by 255).
        recon_img = tf.multiply(tf.reshape(recon_img_squash, shape=[
                                cfg.batch_size, data_size, data_size, 1]), 255.)
        orig_img = tf.reshape(batch_x, shape=[
            cfg.batch_size, data_size, data_size, 1])
        tf.summary.image('orig_image', orig_img)
        tf.summary.image('recon_image', recon_img)
        saver = tf.train.Saver()

        step = 0

        tf.summary.scalar('accuracy', batch_acc)
        summary_op = tf.summary.merge_all()

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            if not os.path.exists(test_dir):
                os.makedirs(test_dir)
            summary_writer = tf.summary.FileWriter(
                test_dir, graph=sess.graph)  # graph=sess.graph, huge!

            files = os.listdir(ckpt_dir)
            try:
                for epoch in range(14, 15):
                    # Checkpoint file names embed the training loss, so match
                    # on the step suffix only; the last match in listing
                    # order wins (as before).
                    index_suffix = '.ckpt-%d.index' % (num_batches_per_epoch_train * epoch)
                    matches = [f for f in files if f.endswith(index_suffix)]
                    if not matches:
                        # Previously `ckpt` was left unbound (or silently
                        # reused from the previous epoch).
                        raise FileNotFoundError(
                            'No checkpoint ending in {!r} found in {}'.format(
                                index_suffix, ckpt_dir))
                    ckpt = os.path.join(ckpt_dir, matches[-1][:-len('.index')])
                    saver.restore(sess, ckpt)

                    accuracy_sum = 0
                    for i in range(num_batches_test):
                        batch_acc_v, summary_str, orig_image, recon_image = sess.run(
                            [batch_acc, summary_op, orig_img, recon_img])
                        print('%d batches are tested.' % step)
                        summary_writer.add_summary(summary_str, step)

                        accuracy_sum += batch_acc_v

                        step += 1
                        # display original/reconstructed images in matplotlib
                        plot_imgs(orig_image, i, 'ori')
                        plot_imgs(recon_image, i, 'rec')

                    ave_acc = accuracy_sum / num_batches_test
                    print('the average accuracy is %f' % ave_acc)
            finally:
                # Flush summaries and shut the input queue runners down
                # before the session closes.
                summary_writer.close()
                coord.request_stop()
                coord.join(threads)
Exemplo n.º 11
0
def test_model(n_tests, x_test, y_test, ang_min, ang_max):
    """Evaluate the restored capsule model on rotated test images.

    Builds the evaluation graph, restores the latest checkpoint from
    ``cfg.logdir + '/caps/mnist'`` and evaluates the whole test set
    ``n_tests`` times. Before feeding, each image is passed through
    ``utils.create_inputs_mnist_rot_excl_range(img, label, ang_min, ang_max)``.
    Full batches are scored with ``net.test_accuracy_sum``. A final partial
    batch, if any, is padded up to ``cfg.batch_size`` with copies of the last
    test image, and only its real samples are counted. On the first run, the
    first rotated and unrotated batches are saved as PNGs under ``cfg.logdir``.

    Args:
        n_tests: number of evaluation runs; must be >= 1.
        x_test: test images, indexed ``x_test[i, :, :, :]``.
        y_test: labels aligned with ``x_test``.
        ang_min: lower rotation-range bound forwarded to the rotation helper.
        ang_max: upper rotation-range bound forwarded to the rotation helper.

    Returns:
        ``(mean_acc, var_acc)``: mean and population variance of the per-run
        accuracies (both are also printed).

    Raises:
        ValueError: if ``n_tests`` < 1.
    """
    if n_tests < 1:
        raise ValueError('n_tests must be >= 1, got {}'.format(n_tests))

    # Placeholders for input data and the targets
    x_input = tf.placeholder(tf.float32, (None, *IMG_DIM), name='Input')
    y_target = tf.placeholder(tf.int32, [None, ], name='Target')

    # NOTE(review): dataset_name is read from module scope, not a parameter.
    coord_add = get_coord_add(dataset_name)
    sample_batch = tf.identity(x_input)
    batch_labels = tf.identity(y_target)
    batch_x = slim.batch_norm(sample_batch, center=False, is_training=False, trainable=False)
    # NOTE(review): build_arch is called with is_train=True although this is
    # an evaluation — confirm that is intended.
    output, pose_out = net.build_arch(batch_x, coord_add, is_train=True,
                                      num_classes=NCLASSES)
    batch_acc_sum = net.test_accuracy_sum(output, batch_labels)
    batch_pred = net.test_predict(output, batch_labels)

    saver = tf.train.Saver()

    def _rotate(images, labels):
        """Return a float64 array of *images* passed through the rotation helper."""
        rotated = np.empty(images.shape)
        for j in range(images.shape[0]):
            rotated[j, :, :, :] = utils.create_inputs_mnist_rot_excl_range(
                images[j, :, :, :], labels[j], ang_min, ang_max)
        return rotated

    nImg = x_test.shape[0]
    batch_size = int(cfg.batch_size)
    nBatches = int(nImg / batch_size)

    accuracies = []

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        model_path = cfg.logdir + '/caps/mnist'
        saver.restore(sess, tf.train.latest_checkpoint(model_path))

        for n in range(n_tests):
            print('\nTest %d/%d' % (n + 1, n_tests))

            print('-' * 30 + 'Begin: testing' + '-' * 30)
            acc = 0
            k = 0  # test images evaluated so far in this run

            for i in range(nBatches):
                x = x_test[i * batch_size: (i + 1) * batch_size, :, :, :]
                y = y_test[i * batch_size: (i + 1) * batch_size]
                xr = _rotate(x, y)
                k += x.shape[0]

                batch_acc_v = sess.run(batch_acc_sum, feed_dict={x_input: xr, y_target: y})
                acc += batch_acc_v

                # Just checking what images we are feeding to the network
                if i == 0 and n == 0:
                    image = utils.combine_images(xr)
                    Image.fromarray(image.astype(np.uint8)).save(cfg.logdir + "/batch_rot.png")

                    image = utils.combine_images(x.astype(np.float64))
                    Image.fromarray(image.astype(np.uint8)).save(cfg.logdir + "/batch_init.png")

                sys.stdout.write(ERASE_LINE)
                sys.stdout.write("\r \r {0}%".format(int(100 * k / nImg)))
                sys.stdout.flush()
                time.sleep(0.001)

            # Remaining samples that do not fill a whole batch (none when
            # nImg is a multiple of batch_size — previously a batch made only
            # of padding was still evaluated in that case).
            n_left = nImg - k
            if n_left > 0:
                x = x_test[k:, :, :, :]
                y = y_test[k:]

                # duplicate the last sample to adjust the batch size
                n_tile = batch_size - n_left
                x_tile = np.tile(np.expand_dims(x_test[nImg - 1, :, :, :], 0), [n_tile, 1, 1, 1])
                y_tile = np.tile(y_test[nImg - 1], n_tile)

                x = np.concatenate((x, x_tile))
                y = np.concatenate((y, y_tile))
                xr = _rotate(x, y)

                batch_pred_v = sess.run(batch_pred, feed_dict={x_input: xr, y_target: y})
                # Count only the real (non-padding) samples.
                left_pred = np.asarray(batch_pred_v[:n_left], dtype=np.float32)
                acc += np.sum(left_pred)

                k += n_left

            sys.stdout.write(ERASE_LINE)
            sys.stdout.write("\r \r {0}%".format(str(100)))
            sys.stdout.flush()
            time.sleep(0.001)

            print('\n')
            print('-' * 30 + 'End: testing' + '-' * 30)

            acc_aver = acc / float(y_test.shape[0])

            print('Number of images: {}, Accuracy: {}'.format(k, acc_aver))

            accuracies.append(acc_aver)

    accuracies = np.asarray(accuracies, dtype=np.float64)
    mean_acc = float(accuracies.mean())
    var_acc = float(accuracies.var())  # population variance (divides by n_tests)

    print('\nTesting is finished!')
    print('Testing options:\nAngles range from {} to {}\tIs only 3 and 4: {}'.format(ang_min, ang_max, is_only_3_and_4))
    print('\nMean testing accuracy for {} runs: {}'.format(n_tests, mean_acc))
    print('Variance of testing accuracy for {} runs: {}'.format(n_tests, var_acc))
    return mean_acc, var_acc
Exemplo n.º 12
0
def main(args):
    """Train the capsule network with reconstruction loss on one dataset.

    ``args[1]`` names the dataset. The model is trained with Adam on the loss
    returned by ``net.spread_loss`` (spread + reconstruction) for up to
    ``cfg.epoch * num_batches_per_epoch + 1`` steps. Steps whose gradients or
    loss fail ``tf.check_numerics`` are logged and discarded. The spread-loss
    margin ``m`` starts at 0.2 and is raised linearly at every epoch boundary
    (capped at 0.9). At each epoch boundary where the step succeeded, a
    checkpoint ``model-<loss>.ckpt-<step>`` is written to
    ``cfg.logdir + '/caps/<dataset>/'``. Summaries go to ``.../train_log/``
    every 5 steps.

    Args:
        args: argv-style list; ``args[1]`` is the dataset name.
    """
    assert len(args) == 2 and isinstance(args[1], str)

    # Get dataset name
    dataset_name = args[1]  # mnist
    logger.info(f'Using dataset: {dataset_name}')

    # Set reproducible random seed
    tf.set_random_seed(1234)

    coord_add = get_coord_add(dataset_name)  # (3, 3, 2)
    dataset_size = get_dataset_size_train(dataset_name)  # 55,000
    num_classes = get_num_classes(dataset_name)  # 10
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=True,
                                      epochs=cfg.epoch)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Get global_step
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        num_batches_per_epoch = dataset_size // cfg.batch_size  # 1100

        opt = tf.train.AdamOptimizer()

        # Get batch from data queue
        batch_x, batch_labels = create_inputs()  # (50 28, 28, 1), (50,)

        # Spread-loss margin; fed on every run so it can be annealed.
        m_op = tf.placeholder(dtype=tf.float32, shape=())
        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                batch_squash = tf.divide(batch_x, 255.)
                batch_x = slim.batch_norm(batch_x,
                                          center=False,
                                          is_training=True,
                                          trainable=True)

                output, pose_out = net.build_arch(
                    batch_x, coord_add, is_train=True,
                    num_classes=num_classes)  # (50, 10), (50, 10, 18)
                tf.logging.debug(pose_out.get_shape())

                # Define loss = spread_loss + reconstruction loss
                loss, spread_loss, mse, _ = net.spread_loss(
                    output, pose_out, batch_squash, batch_labels, m_op)

                acc = net.test_accuracy(output, batch_labels)
                tf.summary.scalar('spread_loss', spread_loss)
                tf.summary.scalar('reconstruction_loss', mse)
                tf.summary.scalar('all_loss', loss)
                tf.summary.scalar('train_acc', acc)

            grad = opt.compute_gradients(loss)
            # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating
            grad_check = [
                tf.check_numerics(g, message='Gradient NaN Found!')
                for g, _ in grad if g is not None
            ] + [tf.check_numerics(loss, message='Loss NaN Found')]

        # Apply gradient only if every numeric check passed, and after the
        # batch-norm moving-average updates.
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)

        # Set Session settings
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # Set Saver
        var_to_save = [
            v for v in tf.global_variables() if 'Adam' not in v.name
        ]  # Don't save redundant Adam beta/gamma
        saver = tf.train.Saver(var_list=var_to_save, max_to_keep=cfg.epoch)

        # Display parameters
        total_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in var_to_save
        ]).astype(np.int32)
        train_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]).astype(np.int32)
        logger.info('Total Parameters: {}'.format(total_p))
        logger.info('Trainable Parameters: {}'.format(train_p))

        # Set summary op
        summary_op = tf.summary.merge_all()

        # Start coord & queue
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Set summary writer
        log_dir = cfg.logdir + '/caps/{}/train_log/'.format(dataset_name)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        summary_writer = tf.summary.FileWriter(
            log_dir, graph=sess.graph)  # graph = sess.graph, huge!

        # Main loop
        m_min = 0.2
        m_max = 0.9
        m = m_min
        for step in range(cfg.epoch * num_batches_per_epoch + 1):
            tic = time.time()
            # TF queue would pop batch until no file
            try:
                _, loss_value, summary_str = sess.run(
                    [train_op, loss, summary_op], feed_dict={m_op: m})
            except KeyboardInterrupt:
                sess.close()
                sys.exit()
            except tf.errors.OutOfRangeError:
                # The pipeline was built with epochs=cfg.epoch and may run dry
                # on the final (+1) step; stop cleanly instead of crashing.
                logger.warning('Input queue exhausted at iteration %d. Stop.' % step)
                break
            except tf.errors.InvalidArgumentError:
                logger.warning(
                    '%d iteration contains NaN gradients. Discard.' % step)
                trained = False
            else:
                trained = True
                logger.info('%d iteration finishs in ' % step + '%f second' %
                            (time.time() - tic) + ' loss=%f' % loss_value)
                if step % 5 == 0:
                    summary_writer.add_summary(summary_str, step)

            if (step % num_batches_per_epoch) == 0:
                # Epoch-wise linear annealing. Done even when this step was
                # discarded, so the margin keeps tracking the epoch count.
                if step > 0:
                    m = min(m + (m_max - m_min) / (cfg.epoch * cfg.m_schedule),
                            m_max)

                # Save model periodically (needs this step's loss value).
                if trained:
                    ckpt_path = os.path.join(
                        cfg.logdir + '/caps/{}/'.format(dataset_name),
                        'model-{:.4f}.ckpt'.format(loss_value))
                    saver.save(sess, ckpt_path, global_step=step)

        # Flush summaries, stop the queue runners and release the session.
        summary_writer.close()
        coord.request_stop()
        coord.join(threads)
        sess.close()
Exemplo n.º 13
0
def main(args):
    """Train the capsule network with reconstruction loss on one dataset.

    ``args[1]`` names the dataset. Runs up to
    ``cfg.epoch * num_batches_per_epoch + 1`` Adam steps on the loss returned
    by ``net.spread_loss``; steps failing ``tf.check_numerics`` are logged and
    discarded. The spread-loss margin ``m`` rises linearly from 0.2 (capped at
    0.9) at every epoch boundary. At each epoch boundary where the step
    succeeded, a checkpoint ``model-<loss>.ckpt-<step>`` is written to
    ``cfg.logdir + '/caps/<dataset>/'``.

    Args:
        args: argv-style list; ``args[1]`` is the dataset name.
    """
    assert len(args) == 2 and isinstance(args[1], str)

    # Get dataset name
    dataset_name = args[1]   # mnist
    logger.info(f'Using dataset: {dataset_name}')

    # Set reproducible random seed
    tf.set_random_seed(1234)

    coord_add = get_coord_add(dataset_name)             # (3, 3, 2)
    dataset_size = get_dataset_size_train(dataset_name) # 55,000
    num_classes = get_num_classes(dataset_name)         # 10
    create_inputs = get_create_inputs(dataset_name, is_train=True, epochs=cfg.epoch)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Get global_step
        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)

        num_batches_per_epoch = dataset_size // cfg.batch_size # 1100

        opt = tf.train.AdamOptimizer()

        # Get batch from data queue
        batch_x, batch_labels = create_inputs() # (50 28, 28, 1), (50,)

        # Spread-loss margin; fed on every run so it can be annealed.
        m_op = tf.placeholder(dtype=tf.float32, shape=())
        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                batch_squash = tf.divide(batch_x, 255.)
                batch_x = slim.batch_norm(batch_x, center=False, is_training=True, trainable=True)

                output, pose_out = net.build_arch(batch_x, coord_add, is_train=True, num_classes=num_classes) # (50, 10), (50, 10, 18)
                tf.logging.debug(pose_out.get_shape())

                # Define loss = spread_loss + reconstruction loss
                loss, spread_loss, mse, _ = net.spread_loss(output, pose_out, batch_squash, batch_labels, m_op)

                acc = net.test_accuracy(output, batch_labels)
                tf.summary.scalar('spread_loss', spread_loss)
                tf.summary.scalar('reconstruction_loss', mse)
                tf.summary.scalar('all_loss', loss)
                tf.summary.scalar('train_acc', acc)

            grad = opt.compute_gradients(loss)
            # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating
            grad_check = [tf.check_numerics(g, message='Gradient NaN Found!')
                          for g, _ in grad if g is not None] + [tf.check_numerics(loss, message='Loss NaN Found')]

        # Apply gradient only after the numeric checks and batch-norm updates.
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)

        # Set Session settings
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # Set Saver
        var_to_save = [v for v in tf.global_variables() if 'Adam' not in v.name]  # Don't save redundant Adam beta/gamma
        saver = tf.train.Saver(var_list=var_to_save, max_to_keep=cfg.epoch)

        # Display parameters
        total_p = np.sum([np.prod(v.get_shape().as_list()) for v in var_to_save]).astype(np.int32)
        train_p = np.sum([np.prod(v.get_shape().as_list())
                          for v in tf.trainable_variables()]).astype(np.int32)
        logger.info('Total Parameters: {}'.format(total_p))
        logger.info('Trainable Parameters: {}'.format(train_p))

        # Set summary op
        summary_op = tf.summary.merge_all()

        # Start coord & queue
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Set summary writer
        log_dir = cfg.logdir + '/caps/{}/train_log/'.format(dataset_name)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        summary_writer = tf.summary.FileWriter(log_dir, graph=sess.graph)  # graph = sess.graph, huge!

        # Main loop
        m_min = 0.2
        m_max = 0.9
        m = m_min
        for step in range(cfg.epoch * num_batches_per_epoch + 1):
            tic = time.time()
            # TF queue would pop batch until no file
            try:
                _, loss_value, summary_str = sess.run([train_op, loss, summary_op], feed_dict={m_op: m})
            except KeyboardInterrupt:
                sess.close()
                sys.exit()
            except tf.errors.OutOfRangeError:
                # Pipeline built with epochs=cfg.epoch may run dry on the final (+1) step.
                logger.warning('Input queue exhausted at iteration %d. Stop.' % step)
                break
            except tf.errors.InvalidArgumentError:
                logger.warning('%d iteration contains NaN gradients. Discard.' % step)
                trained = False
            else:
                trained = True
                logger.info('%d iteration finishs in ' % step + '%f second' %
                            (time.time() - tic) + ' loss=%f' % loss_value)
                if step % 5 == 0:
                    summary_writer.add_summary(summary_str, step)

            if (step % num_batches_per_epoch) == 0:
                # Epoch-wise linear annealing, applied even if this step was
                # discarded so the margin keeps tracking the epoch count.
                if step > 0:
                    m = min(m + (m_max - m_min) / (cfg.epoch * cfg.m_schedule), m_max)

                # Save model periodically (needs this step's loss value).
                if trained:
                    ckpt_path = os.path.join(cfg.logdir + '/caps/{}/'.format(dataset_name), 'model-{:.4f}.ckpt'.format(loss_value))
                    saver.save(sess, ckpt_path, global_step=step)

        # Flush summaries, stop the queue runners and release the session.
        summary_writer.close()
        coord.request_stop()
        coord.join(threads)
        sess.close()
Exemplo n.º 14
0
def main(args):
    """Train the capsule network with reconstruction loss on MNIST.

    The dataset is hard-coded to ``'mnist'``; ``args`` is not read. Runs up to
    ``cfg.epoch * num_batches_per_epoch + 1`` Adam steps on the loss returned
    by ``net.spread_loss``; steps failing ``tf.check_numerics`` are logged and
    discarded. The spread-loss margin ``m`` rises linearly from 0.2 (capped at
    0.9) at every epoch boundary. Every
    ``int(cfg.epoch_save * num_batches_per_epoch)`` steps (at least every
    step), a successful step writes a checkpoint ``model-<loss>.ckpt-<step>``
    to ``cfg.logdir + '/caps/mnist/'``.

    Args:
        args: unused (kept for the ``tf.app.run``-style entry point).
    """
    # Get dataset hyperparameters.
    dataset_name = 'mnist'
    logger.info('Using dataset: {}'.format(dataset_name))
    # Set reproducible random seed.
    tf.set_random_seed(1234)

    coord_add = get_coord_add(dataset_name)
    dataset_size = get_dataset_size_train(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=True,
                                      epochs=cfg.epoch)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        num_batches_per_epoch = int(dataset_size / cfg.batch_size)
        # NOTE: this decayed rate is only logged as a summary; the optimizer
        # below uses Adam's default learning rate, not lrn_rate.
        lrn_rate = tf.maximum(
            tf.train.exponential_decay(1e-3, global_step,
                                       num_batches_per_epoch, 0.8), 1e-5)
        tf.summary.scalar('learning_rate', lrn_rate)
        opt = tf.train.AdamOptimizer()  # lrn_rate
        # Get batch from data queue.
        batch_x, batch_labels = create_inputs()
        # Spread-loss margin; fed on every run so it can be annealed.
        m_op = tf.placeholder(dtype=tf.float32, shape=())

        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                batch_squash = tf.divide(batch_x, 255.)
                batch_x = slim.batch_norm(batch_x,
                                          center=False,
                                          is_training=True,
                                          trainable=True)
                output, pose_out = net.build_arch(batch_x,
                                                  coord_add,
                                                  is_train=True,
                                                  num_classes=num_classes)
                tf.logging.debug(pose_out.get_shape())
                loss, spread_loss, mse, _ = net.spread_loss(
                    output, pose_out, batch_squash, batch_labels, m_op)
                acc = net.test_accuracy(output, batch_labels)
                tf.summary.scalar('spread_loss', spread_loss)
                tf.summary.scalar('reconstruction_loss', mse)
                tf.summary.scalar('all_loss', loss)
                tf.summary.scalar('train_acc', acc)
            grad = opt.compute_gradients(loss)
            grad_check = [
                tf.check_numerics(g, message='Gradient NaN Found!')
                for g, _ in grad if g is not None
            ] + [tf.check_numerics(loss, message='Loss NaN Found')]
        # Apply gradient only after the numeric checks and batch-norm updates.
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        # Don't save Adam slot variables.
        var_to_save = [
            v for v in tf.global_variables() if 'Adam' not in v.name
        ]
        saver = tf.train.Saver(var_list=var_to_save, max_to_keep=cfg.epoch)
        total_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in var_to_save
        ]).astype(np.int32)
        train_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]).astype(np.int32)
        logger.info('Total Parameters: {}'.format(total_p))
        logger.info('Trainable Parameters: {}'.format(train_p))

        summary_op = tf.summary.merge_all()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        log_dir = cfg.logdir + '/caps/{}/train_log/'.format(dataset_name)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        summary_writer = tf.summary.FileWriter(
            log_dir, graph=sess.graph)  # graph = sess.graph, huge!

        m_min = 0.2
        m_max = 0.9
        m = m_min
        # Checkpoint interval in steps; clamp to >= 1 so a small
        # cfg.epoch_save cannot cause a modulo-by-zero.
        save_every = max(1, int(cfg.epoch_save * num_batches_per_epoch))
        for step in range(cfg.epoch * num_batches_per_epoch + 1):
            tic = time.time()
            # TF queue would pop batch until no file.
            try:
                _, loss_value, summary_str = sess.run(
                    [train_op, loss, summary_op], feed_dict={m_op: m})
            except KeyboardInterrupt:
                sess.close()
                sys.exit()
            except tf.errors.OutOfRangeError:
                # Pipeline built with epochs=cfg.epoch may run dry on the
                # final (+1) step; stop cleanly instead of crashing.
                logger.warning('Input queue exhausted at iteration %d. Stop.' % step)
                break
            except tf.errors.InvalidArgumentError:
                logger.warning(
                    '%d iteration contains NaN gradients. Discard.' % step)
                trained = False
            else:
                trained = True
                logger.info('%d iteration finished in ' % step + '%f second' %
                            (time.time() - tic) + ' loss=%f' % loss_value)
                if step % 5 == 0:
                    summary_writer.add_summary(summary_str, step)

            # Epoch-wise linear annealing, applied even if this step was
            # discarded so the margin keeps tracking the epoch count.
            if (step % num_batches_per_epoch) == 0 and step > 0:
                m = min(m + (m_max - m_min) / (cfg.epoch * cfg.m_schedule),
                        m_max)

            # Save intermediate model (needs this step's loss value).
            if trained and (step % save_every) == 0:
                ckpt_path = os.path.join(
                    cfg.logdir + '/caps/{}/'.format(dataset_name),
                    'model-{:.4f}.ckpt'.format(loss_value))
                saver.save(sess, ckpt_path, global_step=step)

        # Flush summaries, stop the queue runners and release the session.
        summary_writer.close()
        coord.request_stop()
        coord.join(threads)
        sess.close()
Exemplo n.º 15
0
def main(args):
    """Train the CNN baseline model on a single dataset.

    Builds a TF1 graph that pulls batches from the input pipeline returned by
    ``get_create_inputs``, batch-normalises them, runs
    ``net.build_arch_baseline`` and minimises ``net.cross_ent_loss`` with
    Adam. Training runs for ``cfg.epoch * num_batches_per_epoch + 1`` steps.
    Summaries are written every 5 steps, and a checkpoint is saved at every
    epoch boundary under ``cfg.logdir + '/cnn_baseline/<dataset_name>'``.

    Args:
        args: argv-style sequence; ``args[1]`` is the dataset name.
    """
    # Parse the dataset name the same way the sibling entry points do. These
    # lines used to be commented out, which left ``dataset_name`` undefined
    # and made the logger call below raise NameError.
    assert len(args) == 2 and isinstance(args[1], str)
    dataset_name = args[1]
    logger.info('Using dataset: {}'.format(dataset_name))

    dataset_size = get_dataset_size_train(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=True,
                                      epochs=cfg.epoch)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Set reproducible random seed. tf.set_random_seed only affects the
        # current default graph, so it must run after entering the graph that
        # is actually built; called earlier, it seeded an unused graph.
        tf.set_random_seed(1234)

        # Global step counter; advanced by apply_gradients below.
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        num_batches_per_epoch = int(dataset_size / cfg.batch_size)
        # Exponentially decayed learning rate. It is only logged as a summary:
        # the optimizer below is built with Adam's default rate, not lrn_rate.
        lrn_rate = tf.maximum(
            tf.train.exponential_decay(1e-3, global_step,
                                       num_batches_per_epoch, 0.8), 1e-5)
        tf.summary.scalar('learning_rate', lrn_rate)
        opt = tf.train.AdamOptimizer()  # lrn_rate
        # Pull one batch from the input pipeline.
        batch_x, batch_labels = create_inputs()
        # Dataflow graph: ops on GPU, slim variables pinned to CPU.
        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                # Inputs scaled by 1/255; used as the reconstruction target.
                batch_x_squash = tf.divide(batch_x, 255.)
                batch_x = slim.batch_norm(batch_x,
                                          center=False,
                                          is_training=True,
                                          trainable=True)
                output = net.build_arch_baseline(batch_x,
                                                 is_train=True,
                                                 num_classes=num_classes)
                loss, recon_loss, _ = net.cross_ent_loss(
                    output, batch_x_squash, batch_labels)
                acc = net.test_accuracy(output, batch_labels)
                tf.summary.scalar('train_acc', acc)
                tf.summary.scalar('recon_loss', recon_loss)
                tf.summary.scalar('all_loss', loss)
            grad = opt.compute_gradients(loss)
            # check_numerics raises InvalidArgumentError on NaN/Inf, which the
            # main loop catches to discard the step. See:
            # https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating
            grad_check = [
                tf.check_numerics(g, message='Gradient NaN Found!')
                for g, _ in grad if g is not None
            ] + [tf.check_numerics(loss, message='Loss NaN Found')]
        # Apply gradients only after the numeric checks and the UPDATE_OPS
        # collection (e.g. batch-norm moving averages) have run.
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)
        # Session settings.
        # NOTE(review): device_count={'GPU': 0} hides all GPUs, so the
        # '/gpu:0' placement above falls back to CPU via allow_soft_placement
        # and gpu_options have no effect -- confirm this is intended.
        session_config = tf.ConfigProto(
            device_count={'GPU': 0},
            gpu_options={
                'allow_growth': 1,
                # 'per_process_gpu_memory_fraction': 0.1,
                'visible_device_list': '0'
            },
            allow_soft_placement=True)
        sess = tf.Session(config=session_config)
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        # Don't save redundant Adam slot/beta variables.
        var_to_save = [
            v for v in tf.global_variables() if 'Adam' not in v.name
        ]
        saver = tf.train.Saver(var_list=var_to_save, max_to_keep=cfg.epoch)
        # Report parameter counts.
        total_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in var_to_save
        ]).astype(np.int32)
        train_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]).astype(np.int32)
        logger.info('Total Parameters: {}'.format(total_p))
        logger.info('Trainable Parameters: {}'.format(train_p))

        # read snapshot
        # latest = os.path.join(cfg.logdir, 'model.ckpt-4680')
        # saver.restore(sess, latest)
        summary_op = tf.summary.merge_all()
        # Start the input-pipeline queue runners under a coordinator.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Summary writer.
        train_log_dir = cfg.logdir + '/cnn_baseline/{}/train_log/'.format(
            dataset_name)
        os.makedirs(train_log_dir, exist_ok=True)
        summary_writer = tf.summary.FileWriter(train_log_dir,
                                               graph=sess.graph)
        # Main loop.
        for step in range(cfg.epoch * num_batches_per_epoch + 1):
            tic = time.time()
            try:
                _, loss_value, summary_str = sess.run(
                    [train_op, loss, summary_op])
                logger.info('%d iteration finishs in ' % step + '%f second' %
                            (time.time() - tic) + ' loss=%f' % loss_value)
            except KeyboardInterrupt:
                sess.close()
                sys.exit()
            except tf.errors.InvalidArgumentError:
                # Raised by check_numerics: drop this step, keep training.
                logger.warning(
                    '%d iteration contains NaN gradients. Discard.' % step)
                continue
            else:
                if step % 5 == 0:
                    summary_writer.add_summary(summary_str, step)
                # Save a checkpoint at every epoch boundary (the baseline has
                # no margin annealing).
                if (step % num_batches_per_epoch) == 0:
                    ckpt_path = os.path.join(
                        cfg.logdir + '/cnn_baseline/{}'.format(dataset_name),
                        'model-{:.4f}.ckpt'.format(loss_value))
                    saver.save(sess, ckpt_path, global_step=step)

        # Stop the queue runners before joining them (join alone can block on
        # threads waiting to enqueue), then release the session.
        coord.request_stop()
        coord.join(threads)
        sess.close()
Exemplo n.º 16
0
def main(args):
    """Train the CNN baseline model on the dataset named in ``args[1]``.

    Builds a TF1 graph that pulls batches from the input pipeline returned by
    ``get_create_inputs``, batch-normalises them, runs
    ``net.build_arch_baseline`` and minimises ``net.cross_ent_loss`` with
    Adam. Training runs for ``cfg.epoch * num_batches_per_epoch + 1`` steps.
    Summaries are written every 5 steps, and a checkpoint is saved at every
    epoch boundary under ``cfg.logdir + '/cnn_baseline/<dataset_name>'``.

    Args:
        args: argv-style sequence; ``args[1]`` is the dataset name.
    """
    assert len(args) == 2 and isinstance(args[1], str)
    dataset_name = args[1]
    logger.info('Using dataset: {}'.format(dataset_name))

    dataset_size = get_dataset_size_train(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name, is_train=True, epochs=cfg.epoch)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Set reproducible random seed. tf.set_random_seed only affects the
        # current default graph, so it must run after entering the graph that
        # is actually built; called earlier, it seeded an unused graph.
        tf.set_random_seed(1234)

        # Global step counter; advanced by apply_gradients below.
        global_step = tf.get_variable(
            'global_step', [], initializer=tf.constant_initializer(0), trainable=False)

        num_batches_per_epoch = int(dataset_size / cfg.batch_size)

        # Exponentially decayed learning rate. It is only logged as a summary:
        # the optimizer below is built with Adam's default rate, not lrn_rate.
        lrn_rate = tf.maximum(tf.train.exponential_decay(
            1e-3, global_step, num_batches_per_epoch, 0.8), 1e-5)
        tf.summary.scalar('learning_rate', lrn_rate)
        opt = tf.train.AdamOptimizer()  # lrn_rate

        # Pull one batch from the input pipeline.
        batch_x, batch_labels = create_inputs()

        # Dataflow graph: ops on GPU, slim variables pinned to CPU.
        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                # Inputs scaled by 1/255; used as the reconstruction target.
                batch_x_squash = tf.divide(batch_x, 255.)
                batch_x = slim.batch_norm(batch_x, center=False, is_training=True, trainable=True)
                output = net.build_arch_baseline(batch_x, is_train=True,
                                                 num_classes=num_classes)
                loss, recon_loss, _ = net.cross_ent_loss(output, batch_x_squash, batch_labels)
                acc = net.test_accuracy(output, batch_labels)
                tf.summary.scalar('train_acc', acc)
                tf.summary.scalar('recon_loss', recon_loss)
                tf.summary.scalar('all_loss', loss)

            grad = opt.compute_gradients(loss)
            # check_numerics raises InvalidArgumentError on NaN/Inf, which the
            # main loop catches to discard the step. See:
            # https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating
            grad_check = [tf.check_numerics(g, message='Gradient NaN Found!')
                          for g, _ in grad if g is not None] + [tf.check_numerics(loss, message='Loss NaN Found')]

        # Apply gradients only after the numeric checks and the UPDATE_OPS
        # collection (e.g. batch-norm moving averages) have run.
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)

        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False))
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # Don't save redundant Adam slot/beta variables.
        var_to_save = [v for v in tf.global_variables() if 'Adam' not in v.name]
        saver = tf.train.Saver(var_list=var_to_save, max_to_keep=cfg.epoch)

        # Report parameter counts.
        total_p = np.sum([np.prod(v.get_shape().as_list()) for v in var_to_save]).astype(np.int32)
        train_p = np.sum([np.prod(v.get_shape().as_list())
                          for v in tf.trainable_variables()]).astype(np.int32)
        logger.info('Total Parameters: {}'.format(total_p))
        logger.info('Trainable Parameters: {}'.format(train_p))

        # read snapshot
        # latest = os.path.join(cfg.logdir, 'model.ckpt-4680')
        # saver.restore(sess, latest)
        summary_op = tf.summary.merge_all()

        # Start the input-pipeline queue runners under a coordinator.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Summary writer.
        train_log_dir = cfg.logdir + '/cnn_baseline/{}/train_log/'.format(dataset_name)
        os.makedirs(train_log_dir, exist_ok=True)
        summary_writer = tf.summary.FileWriter(train_log_dir, graph=sess.graph)

        # Main loop.
        for step in range(cfg.epoch * num_batches_per_epoch + 1):
            tic = time.time()
            try:
                _, loss_value, summary_str = sess.run(
                    [train_op, loss, summary_op])
                logger.info('%d iteration finishs in ' % step + '%f second' %
                            (time.time() - tic) + ' loss=%f' % loss_value)
            except KeyboardInterrupt:
                sess.close()
                sys.exit()
            except tf.errors.InvalidArgumentError:
                # Raised by check_numerics: drop this step, keep training.
                logger.warning('%d iteration contains NaN gradients. Discard.' % step)
                continue
            else:
                if step % 5 == 0:
                    summary_writer.add_summary(summary_str, step)

                # Save a checkpoint at every epoch boundary (the baseline has
                # no margin annealing).
                if (step % num_batches_per_epoch) == 0:
                    ckpt_path = os.path.join(
                        cfg.logdir + '/cnn_baseline/{}'.format(dataset_name), 'model-{:.4f}.ckpt'.format(loss_value))
                    saver.save(sess, ckpt_path, global_step=step)

        # Ask the queue runners to stop before joining them. Joining alone can
        # block forever on threads waiting to enqueue into a full queue.
        coord.request_stop()
        coord.join(threads)
        sess.close()
Exemplo n.º 17
0
def main(args):
    """Train the capsule network with batches fed through placeholders.

    Training data comes from ``utils.load_mnist_excluded()``. Each step
    samples a batch with ``utils.get_random_mnist_batch`` and feeds it via the
    ``x_input``/``y_target`` placeholders. The spread-loss margin ``m`` starts
    at 0.2 and is increased linearly at each epoch boundary, capped at 0.9.

    Outputs during training:
      * summaries every 10 steps;
      * a PNG of inputs and reconstructions every 200 steps;
      * a checkpoint every epoch;
      * a final checkpoint after the loop.

    Args:
        args: argv-style sequence; ``args[1]`` is the dataset name.
    """
    assert len(args) == 2 and isinstance(args[1], str)
    dataset_name = args[1]
    logger.info('Using dataset: {}'.format(dataset_name))

    coord_add = get_coord_add(dataset_name)
    dataset_size = get_dataset_size_train(dataset_name)
    num_classes = get_num_classes(dataset_name)

    # Prepare Training Data
    # NOTE(review): the data loader takes no dataset argument, so these
    # arrays do not depend on dataset_name -- confirm this is intended.
    (x_train, y_train), (x_test, y_test) = utils.load_mnist_excluded()

    with tf.Graph().as_default():  #, tf.device('/cpu:0'):
        # Set reproducible random seed. tf.set_random_seed only affects the
        # current default graph, so it must run after entering the graph that
        # is actually built; called earlier, it seeded an unused graph.
        tf.set_random_seed(1234)

        # Placeholders for input data and the targets
        x_input = tf.placeholder(tf.float32, (None, *IMG_DIM), name='Input')
        y_target = tf.placeholder(tf.int32, [
            None,
        ], name='Target')
        # Global step counter; advanced by apply_gradients below.
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        num_batches_per_epoch = int(dataset_size / cfg.batch_size)
        # Exponentially decayed learning rate. It is only logged as a summary:
        # the optimizer below is built with Adam's default rate, not lrn_rate.
        lrn_rate = tf.maximum(
            tf.train.exponential_decay(1e-3, global_step,
                                       num_batches_per_epoch, 0.8), 1e-5)
        tf.summary.scalar('learning_rate', lrn_rate)
        opt = tf.train.AdamOptimizer()  # lrn_rate
        # Scalar spread-loss margin, fed on every run.
        m_op = tf.placeholder(dtype=tf.float32, shape=())
        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable]):  #, device='/cpu:0'):
                sample_batch = tf.identity(x_input)
                batch_labels = tf.identity(y_target)
                # Inputs scaled by 1/255; passed to the loss as the
                # reconstruction target.
                batch_squash = tf.divide(sample_batch, 255.)
                batch_x = slim.batch_norm(sample_batch,
                                          center=False,
                                          is_training=True,
                                          trainable=True)
                output, pose_out = net.build_arch(batch_x,
                                                  coord_add,
                                                  is_train=True,
                                                  num_classes=num_classes)

                tf.logging.debug(pose_out.get_shape())
                loss, spread_loss, mse, reconstruction = net.spread_loss(
                    output, pose_out, batch_squash, batch_labels, m_op)
                sample_batch = tf.squeeze(sample_batch)
                # Stack raw inputs and reconstructions (multiplied by 255 to
                # undo the 1/255 scaling of the target) along axis 0, for the
                # periodic image dump.
                decode_res_op = tf.concat([
                    sample_batch,
                    255 * tf.reshape(reconstruction,
                                     [cfg.batch_size, IMAGE_SIZE, IMAGE_SIZE])
                ],
                                          axis=0)
                acc = net.test_accuracy(output, batch_labels)
                tf.summary.scalar('spread_loss', spread_loss)
                tf.summary.scalar('reconstruction_loss', mse)
                tf.summary.scalar('all_loss', loss)
                tf.summary.scalar('train__batch_acc', acc)
            grad = opt.compute_gradients(loss)
            # check_numerics raises InvalidArgumentError on NaN/Inf, which the
            # main loop catches to discard the step. See:
            # https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating
            grad_check = [
                tf.check_numerics(g, message='Gradient NaN Found!')
                for g, _ in grad if g is not None
            ] + [tf.check_numerics(loss, message='Loss NaN Found')]
        # Apply gradients only after the numeric checks and the UPDATE_OPS
        # collection (e.g. batch-norm moving averages) have run.
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        # Don't save redundant Adam slot/beta variables.
        var_to_save = [
            v for v in tf.global_variables() if 'Adam' not in v.name
        ]
        saver = tf.train.Saver(var_list=var_to_save, max_to_keep=cfg.epoch)
        # Report parameter counts.
        total_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in var_to_save
        ]).astype(np.int32)
        train_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]).astype(np.int32)
        logger.info('Total Parameters: {}'.format(total_p))
        logger.info('Trainable Parameters: {}'.format(train_p))

        # read snapshot
        # latest = os.path.join(cfg.logdir, 'model.ckpt-4680')
        # saver.restore(sess, latest)
        summary_op = tf.summary.merge_all()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Output directories for summaries and sample images.
        train_log_dir = cfg.logdir + '/caps/{}/train_log/'.format(dataset_name)
        os.makedirs(train_log_dir, exist_ok=True)
        summary_writer = tf.summary.FileWriter(
            train_log_dir, graph=sess.graph)  # graph = sess.graph, huge!

        os.makedirs(cfg.logdir + '/caps/{}/images/'.format(dataset_name),
                    exist_ok=True)
        # Main loop.
        m_min = 0.2
        m_max = 0.9
        m = m_min
        max_iter = cfg.epoch * num_batches_per_epoch + 1
        # Stays None until a step succeeds; guards the final checkpoint.
        loss_value = None

        for step in range(max_iter):
            tic = time.time()

            # Sample a numpy batch. The name differs from the graph tensor
            # ``batch_x`` above so that tensor is not shadowed.
            feed_x, feed_y = utils.get_random_mnist_batch(
                x_train, y_train, cfg.batch_size)
            feed = {m_op: m, x_input: feed_x, y_target: feed_y}

            try:
                _, loss_value, train_acc_val, summary_str, mse_value = sess.run(
                    [train_op, loss, acc, summary_op, mse], feed_dict=feed)

                sys.stdout.write(ERASE_LINE)
                sys.stdout.write('\r\r%d/%d iteration finishes in ' %
                                 (step, max_iter) + '%f second' %
                                 (time.time() - tic) +
                                 ' training accuracy = %f' % train_acc_val +
                                 ' loss=%f' % loss_value +
                                 '\treconstruction_loss=%f' % mse_value)
                sys.stdout.flush()
                time.sleep(0.001)

            except KeyboardInterrupt:
                sess.close()
                sys.exit()
            except tf.errors.InvalidArgumentError:
                # Raised by check_numerics: drop this step, keep training.
                logger.warning(
                    '%d iteration contains NaN gradients. Discard.' % step)
                continue
            else:
                if step % 10 == 0:
                    summary_writer.add_summary(summary_str, step)

                # Periodically dump the inputs and reconstructions as a PNG.
                if step % 200 == 0:
                    images = sess.run(decode_res_op, feed_dict=feed)
                    image = combine_images(images)
                    img_name = cfg.logdir + '/caps/{}/images/'.format(
                        dataset_name) + "/step_{}.png".format(str(step))
                    Image.fromarray(image.astype(np.uint8)).save(img_name)
                # Epoch boundary: anneal the margin linearly, then checkpoint.
                if (step % num_batches_per_epoch) == 0:
                    if step > 0:
                        m += (m_max - m_min) / (cfg.epoch * cfg.m_schedule)
                        if m > m_max:
                            m = m_max
                    ckpt_path = os.path.join(
                        cfg.logdir + '/caps/{}/'.format(dataset_name),
                        'model-{:.4f}.ckpt'.format(loss_value))
                    saver.save(sess, ckpt_path, global_step=step)

        if loss_value is None:
            # Every step failed: there is no loss to put in the file name.
            # The original raised UnboundLocalError here.
            logger.warning('No training step succeeded; skipping final checkpoint.')
        else:
            ckpt_path = os.path.join(cfg.logdir + '/caps/{}/'.format(dataset_name),
                                     'finall-model-{:.4f}.ckpt'.format(loss_value))
            saver.save(sess, ckpt_path, global_step=step)

        # Stop any queue runners and release the session.
        coord.request_stop()
        coord.join(threads)
        sess.close()

        print('Training is finished!')