Exemplo n.º 1
0
def main(args):
    """Train the capsule network on the dataset named by ``args[1]``.

    Builds a TF1 graph (queue-based input pipeline, Adam optimizer with
    NaN/Inf checks on the loss and every gradient) and runs
    ``cfg.epoch * num_batches_per_epoch + 1`` training steps:

    * a step whose loss or gradients fail ``tf.check_numerics`` is logged and
      skipped;
    * summaries are written every 5 successful steps to
      ``<cfg.logdir>/caps/<dataset>/train_log/``;
    * a checkpoint ``model-<loss>.ckpt-<step>`` is saved under
      ``<cfg.logdir>/caps/<dataset>/`` at every successful step that is a
      multiple of ``num_batches_per_epoch`` (including step 0).

    Training stops early (instead of crashing) if the input pipeline runs out
    of data. The session, queue-runner threads and summary writer are always
    released on exit, including on Ctrl-C, which still ends the process via
    ``sys.exit()``.

    Args:
        args: command-line argument list; ``args[1]`` is the dataset name.

    Raises:
        ValueError: if the training set is smaller than one batch, which
            would make the checkpoint interval zero.
    """
    assert isinstance(args[1], str)
    dataset_name = args[1]
    logger.info('Using dataset: {}'.format(dataset_name))

    dataset_size = get_dataset_size_train(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name, is_train=True, epochs=cfg.epoch)

    # Full batches per epoch; also the checkpoint interval in steps.
    num_batches_per_epoch = int(dataset_size / cfg.batch_size)
    if num_batches_per_epoch < 1:
        raise ValueError(
            'Dataset {!r} has {} training examples, fewer than one batch of {}.'.format(
                dataset_name, dataset_size, cfg.batch_size))

    ckpt_dir = cfg.logdir + '/caps/{}/'.format(dataset_name)
    summary_dir = cfg.logdir + '/caps/{}/train_log/'.format(dataset_name)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # tf.set_random_seed seeds the *current default graph*, so it must be
        # called inside this block; calling it earlier seeds the global graph
        # and leaves the training graph built here unseeded.
        tf.set_random_seed(1234)

        global_step = tf.get_variable(
            'global_step', [], initializer=tf.constant_initializer(0), trainable=False)

        opt = tf.train.AdamOptimizer()  # default learning rate

        # Dequeue a batch from the input pipeline.
        batch_x, batch_labels = create_inputs()

        # Build the model on the GPU while pinning slim variables to the CPU.
        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                # Divide by 255 (presumably 8-bit pixel intensities -> [0, 1]).
                batch_squash = tf.divide(batch_x, 255.)
                output, output_len = net.build_arch(batch_squash, is_train=True, num_classes=num_classes)
                tf.logging.debug(output.get_shape())
                loss, margin_loss, mse, _ = net.loss(
                    output, output_len, batch_squash, batch_labels)
                acc = net.test_accuracy(output_len, batch_labels)
                tf.summary.scalar('margin_loss', margin_loss)
                tf.summary.scalar('reconstruction_loss', mse)
                tf.summary.scalar('all_loss', loss)
                tf.summary.scalar('train_acc', acc)

            grad = opt.compute_gradients(loss)
            # check_numerics raises InvalidArgumentError at run time when a
            # gradient or the loss contains NaN/Inf; the loop below catches it
            # and discards that step. See:
            # https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating
            grad_check = [tf.check_numerics(g, message='Gradient NaN Found!')
                          for g, _ in grad if g is not None] + [tf.check_numerics(loss, message='Loss NaN Found')]

        # Apply gradients only after the numeric checks and any UPDATE_OPS
        # registered by the model have run.
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            # Checkpoint every global variable except those whose name contains
            # 'Adam' (the optimizer's slot variables), to keep checkpoints small.
            var_to_save = [v for v in tf.global_variables() if 'Adam' not in v.name]
            saver = tf.train.Saver(var_list=var_to_save, max_to_keep=cfg.epoch)

            # Report model size.
            total_p = np.sum([np.prod(v.get_shape().as_list())
                              for v in var_to_save]).astype(np.int32)
            train_p = np.sum([np.prod(v.get_shape().as_list())
                              for v in tf.trainable_variables()]).astype(np.int32)
            logger.info('Total Parameters: {}'.format(total_p))
            logger.info('Trainable Parameters: {}'.format(train_p))

            # To resume from a snapshot, call saver.restore(sess, <ckpt prefix>) here.

            summary_op = tf.summary.merge_all()
            if not os.path.exists(summary_dir):
                os.makedirs(summary_dir)
            # Writing the graph makes the event file large but lets
            # TensorBoard display it.
            summary_writer = tf.summary.FileWriter(summary_dir, graph=sess.graph)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for step in range(cfg.epoch * num_batches_per_epoch + 1):
                    tic = time.time()
                    try:
                        _, loss_value, summary_str = sess.run(
                            [train_op, loss, summary_op])
                        logger.info('%d iteration finishs in %f second loss=%f'
                                    % (step, time.time() - tic, loss_value))
                    except KeyboardInterrupt:
                        # The `finally` clause and the session's `with` block
                        # release resources on the way out.
                        sys.exit()
                    except tf.errors.InvalidArgumentError:
                        # Raised by the check_numerics ops above. NOTE(review):
                        # any other InvalidArgumentError is also swallowed here.
                        logger.warning('%d iteration contains NaN gradients. Discard.' % step)
                        continue
                    except tf.errors.OutOfRangeError:
                        # The input queue ran out of data before the step budget.
                        logger.warning('Input data exhausted at iteration %d; stopping.' % step)
                        break

                    if step % 5 == 0:
                        summary_writer.add_summary(summary_str, step)

                    if step % num_batches_per_epoch == 0:
                        ckpt_path = os.path.join(
                            ckpt_dir, 'model-{:.4f}.ckpt'.format(loss_value))
                        saver.save(sess, ckpt_path, global_step=step)
            finally:
                # Stop and join queue-runner threads before the session closes,
                # and flush pending summaries to disk.
                coord.request_stop()
                coord.join(threads)
                summary_writer.close()
Exemplo n.º 2
0
def main(args):
    """Train the capsule network on the dataset named by ``args[1]``.

    Builds a TF1 graph (queue-based input pipeline, Adam optimizer with
    NaN/Inf checks on the loss and every gradient) and runs
    ``cfg.epoch * num_batches_per_epoch + 1`` training steps:

    * a step whose loss or gradients fail ``tf.check_numerics`` is logged and
      skipped;
    * summaries are written every 5 successful steps to
      ``<cfg.logdir>/caps/<dataset>/train_log/``;
    * a checkpoint ``model-<loss>.ckpt-<step>`` is saved under
      ``<cfg.logdir>/caps/<dataset>/`` at every successful step that is a
      multiple of ``num_batches_per_epoch`` (including step 0).

    Training stops early (instead of crashing) if the input pipeline runs out
    of data. The session, queue-runner threads and summary writer are always
    released on exit, including on Ctrl-C, which still ends the process via
    ``sys.exit()``.

    Args:
        args: command-line argument list; ``args[1]`` is the dataset name.

    Raises:
        ValueError: if the training set is smaller than one batch, which
            would make the checkpoint interval zero.
    """
    assert isinstance(args[1], str)
    dataset_name = args[1]
    logger.info('Using dataset: {}'.format(dataset_name))

    dataset_size = get_dataset_size_train(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=True,
                                      epochs=cfg.epoch)

    # Full batches per epoch; also the checkpoint interval in steps.
    num_batches_per_epoch = int(dataset_size / cfg.batch_size)
    if num_batches_per_epoch < 1:
        raise ValueError(
            'Dataset {!r} has {} training examples, fewer than one batch of {}.'
            .format(dataset_name, dataset_size, cfg.batch_size))

    ckpt_dir = cfg.logdir + '/caps/{}/'.format(dataset_name)
    summary_dir = cfg.logdir + '/caps/{}/train_log/'.format(dataset_name)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # tf.set_random_seed seeds the *current default graph*, so it must be
        # called inside this block; calling it earlier seeds the global graph
        # and leaves the training graph built here unseeded.
        tf.set_random_seed(1234)

        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        opt = tf.train.AdamOptimizer()  # default learning rate

        # Dequeue a batch from the input pipeline.
        batch_x, batch_labels = create_inputs()

        # Build the model on the GPU while pinning slim variables to the CPU.
        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                # Divide by 255 (presumably 8-bit pixel intensities -> [0, 1]).
                batch_squash = tf.divide(batch_x, 255.)
                output, output_len = net.build_arch(batch_squash,
                                                    is_train=True,
                                                    num_classes=num_classes)
                tf.logging.debug(output.get_shape())
                loss, margin_loss, mse, _ = net.loss(output, output_len,
                                                     batch_squash,
                                                     batch_labels)
                acc = net.test_accuracy(output_len, batch_labels)
                tf.summary.scalar('margin_loss', margin_loss)
                tf.summary.scalar('reconstruction_loss', mse)
                tf.summary.scalar('all_loss', loss)
                tf.summary.scalar('train_acc', acc)

            grad = opt.compute_gradients(loss)
            # check_numerics raises InvalidArgumentError at run time when a
            # gradient or the loss contains NaN/Inf; the loop below catches it
            # and discards that step. See:
            # https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating
            grad_check = [
                tf.check_numerics(g, message='Gradient NaN Found!')
                for g, _ in grad if g is not None
            ] + [tf.check_numerics(loss, message='Loss NaN Found')]

        # Apply gradients only after the numeric checks and any UPDATE_OPS
        # registered by the model have run.
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            # Checkpoint every global variable except those whose name contains
            # 'Adam' (the optimizer's slot variables), to keep checkpoints small.
            var_to_save = [
                v for v in tf.global_variables() if 'Adam' not in v.name
            ]
            saver = tf.train.Saver(var_list=var_to_save,
                                   max_to_keep=cfg.epoch)

            # Report model size.
            total_p = np.sum([
                np.prod(v.get_shape().as_list()) for v in var_to_save
            ]).astype(np.int32)
            train_p = np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ]).astype(np.int32)
            logger.info('Total Parameters: {}'.format(total_p))
            logger.info('Trainable Parameters: {}'.format(train_p))

            # To resume from a snapshot, call saver.restore(sess, <ckpt prefix>) here.

            summary_op = tf.summary.merge_all()
            if not os.path.exists(summary_dir):
                os.makedirs(summary_dir)
            # Writing the graph makes the event file large but lets
            # TensorBoard display it.
            summary_writer = tf.summary.FileWriter(summary_dir,
                                                   graph=sess.graph)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for step in range(cfg.epoch * num_batches_per_epoch + 1):
                    tic = time.time()
                    try:
                        _, loss_value, summary_str = sess.run(
                            [train_op, loss, summary_op])
                        logger.info(
                            '%d iteration finishs in %f second loss=%f' %
                            (step, time.time() - tic, loss_value))
                    except KeyboardInterrupt:
                        # The `finally` clause and the session's `with` block
                        # release resources on the way out.
                        sys.exit()
                    except tf.errors.InvalidArgumentError:
                        # Raised by the check_numerics ops above. NOTE(review):
                        # any other InvalidArgumentError is also swallowed here.
                        logger.warning(
                            '%d iteration contains NaN gradients. Discard.' %
                            step)
                        continue
                    except tf.errors.OutOfRangeError:
                        # The input queue ran out of data before the step budget.
                        logger.warning(
                            'Input data exhausted at iteration %d; stopping.' %
                            step)
                        break

                    if step % 5 == 0:
                        summary_writer.add_summary(summary_str, step)

                    if step % num_batches_per_epoch == 0:
                        ckpt_path = os.path.join(
                            ckpt_dir, 'model-{:.4f}.ckpt'.format(loss_value))
                        saver.save(sess, ckpt_path, global_step=step)
            finally:
                # Stop and join queue-runner threads before the session closes,
                # and flush pending summaries to disk.
                coord.request_stop()
                coord.join(threads)
                summary_writer.close()
Exemplo n.º 3
0
def main(args):
    """Evaluate a trained checkpoint on a few test batches.

    ``args`` must be ``[program, dataset_name, model_name]``. Checkpoints are
    read from ``<cfg.logdir>/<model_name>/<dataset_name>/``. Summaries (loss
    and accuracy scalars plus original and reconstructed images) are written
    to ``<cfg.test_logdir>/<model_name>/<dataset_name>/``.

    For each selected epoch, the checkpoint saved at step
    ``num_batches_per_epoch_train * epoch`` is restored. Then
    ``num_batches_test`` batches are evaluated, each also plotted via
    ``plot_imgs``, and the mean batch accuracy is printed. An epoch with no
    matching checkpoint is reported and skipped, rather than failing with a
    NameError or silently reusing a previous epoch's checkpoint.

    Args:
        args: command-line argument list (see above).
    """
    assert len(args) == 3 and isinstance(args[1], str) and isinstance(
        args[2], str)
    dataset_name = args[1]
    model_name = args[2]

    # NOTE(review): coord_add and dataset_size_test are not used below; the
    # lookups are kept so any validation they perform still runs — confirm
    # and drop if unneeded.
    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)

    ckpt_dir = cfg.logdir + '/{}/{}/'.format(model_name, dataset_name)
    test_log_dir = cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name)

    with tf.Graph().as_default():
        # tf.set_random_seed seeds the *current default graph*, so it must be
        # called inside this block to affect the graph built here.
        tf.set_random_seed(1234)

        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        # Only a couple of batches are evaluated; a larger sample would be
        # int(dataset_size_test / cfg.batch_size * 0.1).
        num_batches_test = 2

        batch_x, batch_labels = create_inputs()
        # Divide by 255 (presumably 8-bit pixel intensities -> [0, 1]).
        batch_squash = tf.divide(batch_x, 255.)
        output, output_len = net.build_arch(batch_squash,
                                            is_train=False,
                                            num_classes=num_classes)
        tf.logging.debug(output.get_shape())

        batch_acc = net.test_accuracy(output_len, batch_labels)
        loss, margin_loss, mse, recon_img_squash = net.loss(
            output, output_len, batch_squash, batch_labels)
        tf.summary.scalar('spread_loss', margin_loss)
        tf.summary.scalar('reconstruction_loss', mse)
        tf.summary.scalar('all_loss', loss)

        # The reshapes below assume square, single-channel images whose side
        # length is dimension 1 of batch_x.
        data_size = int(batch_x.get_shape()[1])
        recon_img = tf.multiply(
            tf.reshape(recon_img_squash,
                       shape=[cfg.batch_size, data_size, data_size, 1]), 255.)
        orig_img = tf.reshape(batch_x,
                              shape=[cfg.batch_size, data_size, data_size, 1])
        tf.summary.image('orig_image', orig_img)
        tf.summary.image('recon_image', recon_img)
        saver = tf.train.Saver()

        tf.summary.scalar('accuracy', batch_acc)
        summary_op = tf.summary.merge_all()

        step = 0  # running batch counter, used as the summary step

        with tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            if not os.path.exists(test_log_dir):
                os.makedirs(test_log_dir)
            summary_writer = tf.summary.FileWriter(
                test_log_dir, graph=sess.graph)  # graph=sess.graph, huge!

            # Sorted so the choice among several matching checkpoints is
            # deterministic (the last match in sorted order wins).
            files = sorted(os.listdir(ckpt_dir))

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for epoch in range(5, 6):
                    # The checkpoint prefix embeds a loss value, so match on
                    # the '.ckpt-<step>.index' suffix instead of the full name.
                    ckpt_suffix = '.ckpt-%d.index' % (
                        num_batches_per_epoch_train * epoch)
                    matches = [f for f in files if f.endswith(ckpt_suffix)]
                    if not matches:
                        print('No checkpoint ending in %s found in %s; '
                              'skipping epoch %d.' %
                              (ckpt_suffix, ckpt_dir, epoch))
                        continue
                    # Drop '.index' to get the prefix Saver.restore expects.
                    ckpt = os.path.join(ckpt_dir,
                                        matches[-1][:-len('.index')])
                    saver.restore(sess, ckpt)

                    accuracy_sum = 0
                    for i in range(num_batches_test):
                        batch_acc_v, summary_str, orig_image, recon_image = sess.run(
                            [batch_acc, summary_op, orig_img, recon_img])
                        print('%d batches are tested.' % step)
                        summary_writer.add_summary(summary_str, step)

                        accuracy_sum += batch_acc_v

                        step += 1
                        # Display original/reconstructed images.
                        plot_imgs(orig_image, i, 'ori')
                        plot_imgs(recon_image, i, 'rec')

                    ave_acc = accuracy_sum / num_batches_test
                    print('the average accuracy is %f' % ave_acc)
            finally:
                # Stop and join queue-runner threads before the session closes,
                # and flush pending summaries to disk.
                coord.request_stop()
                coord.join(threads)
                summary_writer.close()