Example #1
def train(data_path, adaptor, classifier, summ):

    input_ = Input(FLAGS.train_batch_size, FLAGS.num_points)
    waves, labels = input_(data_path)

    # Calculate the loss of the model.
    if FLAGS.adp:
        logits = tf.stop_gradient(adaptor(waves))
        # logits = adaptor(waves)
        logits = classifier(logits)
    else:
        logits = classifier(waves, expand_dims=True)

    loss = LossClassification(FLAGS.num_classes)(logits, labels)

    opt = Adam(FLAGS.learning_rate,
               lr_decay=True,
               lr_decay_steps=FLAGS.lr_decay_steps,
               lr_decay_factor=FLAGS.lr_decay_factor)

    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    total_regularization_loss = tf.reduce_sum(graph_regularizers)

    train_op = opt(loss + total_regularization_loss)

    summ.register('train', 'train_loss', loss)

    train_summ_op = summ('train')

    return loss, train_op, train_summ_op
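
The loss, train op, and summary op returned by train() are meant to be driven by an outer session loop that is not part of this snippet. A minimal driver sketch, assuming a plain TF1 session; the function name, step count, logging interval, and the use of FLAGS.output_dir for the summary writer are illustrative assumptions, not values taken from the project:

def run_training(data_path, adaptor, classifier, summ, num_steps=1000):
    # Hypothetical driver; names and constants here are assumptions.
    loss, train_op, train_summ_op = train(data_path, adaptor, classifier, summ)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(FLAGS.output_dir, sess.graph)

        for step in range(num_steps):
            loss_val, _, summ_val = sess.run([loss, train_op, train_summ_op])
            if step % 100 == 0:
                writer.add_summary(summ_val, step)
                print("step {}: loss {:.4f}".format(step, loss_val))

        writer.close()

Depending on how Input builds its pipeline, queue runners or iterator initializers may also need to be started before the loop.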
Example #2
File: inspect_dist.py Project: mingyr/san
def main(unused_argv):

    if FLAGS.data_file == '' or not os.path.isfile(FLAGS.data_file):
        raise ValueError('invalid data file')

    if FLAGS.output_dir == '' or not os.path.exists(FLAGS.output_dir):
        raise ValueError('invalid output directory {}'.format(
            FLAGS.output_dir))

    checkpoint_dir = os.path.join(FLAGS.output_dir, '')

    print('reconstructing models and inputs.')
    input_ = Input(1, [FLAGS.img_height, FLAGS.img_width])

    inputs, labels = input_(FLAGS.data_file)

    adaptor = Adaptor()
    mapper = Mapper()

    logits = adaptor(inputs)
    logits = mapper(logits)

    variables = snt.get_variables_in_module(
        adaptor) + snt.get_variables_in_module(mapper)
    saver_adaptor = tf.train.Saver(snt.get_variables_in_module(adaptor))
    saver = tf.train.Saver(variables)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming model_checkpoint_path looks something like:
            #   /my-favorite-path/cifar10_train/model.ckpt-0,
            # extract global_step from it.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
        else:
            print('No checkpoint file found')
            return

        assert (FLAGS.inspect_size > 0), "invalid test samples"

        logits_lst = []
        labels_lst = []

        for i in range(FLAGS.inspect_size):
            logit_val, label_val = sess.run([logits, labels])

            logits_lst.append(logit_val)
            labels_lst.append(label_val)

        io.savemat("stats.mat", {"logits": logits_lst, 'classes': labels_lst})
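
The exported stats.mat can then be inspected offline with SciPy. A short sketch of reading it back; the key names mirror the savemat call above, everything else is illustrative:

import numpy as np
from scipy import io

stats = io.loadmat("stats.mat")
logits = np.asarray(stats["logits"])    # shape depends on the mapper output
classes = np.asarray(stats["classes"])
print(logits.shape, classes.shape)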
Example #3
File: train_adaptor.py Project: mingyr/san
def train(data_path, summ):

    input_target = PseudoInput(FLAGS.batch_size)
    input_source = Input(FLAGS.batch_size, [FLAGS.img_height, FLAGS.img_width])

    adaptor = Adaptor()
    mapper = Mapper(summ=summ)
    discriminator = Discriminator()

    inputs, _ = input_source(data_path)

    # Calculate the loss of the model.

    feat = adaptor(inputs)
    feat = mapper(feat)

    target = input_target()
    summ.register('train', 'target_dist', target)

    logits_source = discriminator(feat)
    logits_target = discriminator(target)

    # WGAN Loss
    d_loss_real = -tf.reduce_mean(logits_target)
    d_loss_fake = tf.reduce_mean(logits_source)
    d_loss = d_loss_real + d_loss_fake

    # Total generator loss.
    g_loss = -d_loss_fake

    summ.register("train", "Discriminator_loss", d_loss)
    summ.register("train", "Generator_loss", g_loss)

    generator_vars = snt.get_variables_in_module(
        adaptor) + snt.get_variables_in_module(mapper)
    discriminator_vars = snt.get_variables_in_module(discriminator)

    g_optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.adp_learning_rate,
                                         name='g_opt',
                                         beta1=FLAGS.beta1,
                                         beta2=FLAGS.beta2).minimize(
                                             g_loss, var_list=generator_vars)

    d_optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.adp_learning_rate,
                                         name='d_opt',
                                         beta1=FLAGS.beta1,
                                         beta2=FLAGS.beta2).minimize(
                                             d_loss,
                                             var_list=discriminator_vars)

    summ_op = summ('train')

    return g_optimizer, d_optimizer, summ_op
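
The two minimize ops above implement the usual WGAN scheme: the critic (discriminator) loss contrasts real (target) against generated (source) features, and the generator loss is the negated fake term. A Lipschitz constraint (weight clipping or a gradient penalty) is not visible in this snippet. A sketch of the customary alternating update loop; the number of critic steps per generator step and the total step count are assumptions, not values taken from the project:

# Alternating WGAN updates for the ops returned by train().
g_optimizer, d_optimizer, summ_op = train(data_path, summ)

n_critic = 5  # illustrative value
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for step in range(10000):
        for _ in range(n_critic):
            sess.run(d_optimizer)  # critic / discriminator update
        sess.run(g_optimizer)      # generator (adaptor + mapper) update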
Example #4
File: model.py Project: mingyr/san
def test_adaptor():
    print("test adaptor")

    from config import FLAGS
    from input_2 import Input
    input_ = Input(32, [FLAGS.img_height, FLAGS.img_width])
    adaptor = Adaptor(FLAGS.num_filters)

    inputs, labels = input_('mnist/mnist.tfr') 
    outputs = adaptor(inputs)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        v = sess.run(outputs)

        print(v.shape)
Example #5
File: model_wgan.py Project: mingyr/san
def test_adaptor():
    print("test adaptor")

    from config import FLAGS
    from input_2 import Input
    input_ = Input(32, FLAGS.num_points)
    adaptor = Adaptor(FLAGS.num_filters)

    waves, labels = input_(
        '/data/yuming/eeg-processed-data/vep/san-1d/eeg.tfr')
    outputs = adaptor(waves)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        v = sess.run(outputs)

        print(v.shape)
Example #6
File: model_wgan.py Project: mingyr/san
def test_classifier():
    print("test classifier")

    from config import FLAGS

    from input_2 import Input
    input_ = Input(32, FLAGS.num_points)

    waves, labels = input_(
        '/data/yuming/eeg-processed-data/vep/san-1d/eeg.tfr')

    classifier = Classifier(FLAGS.num_points, FLAGS.sampling_rate)

    outputs = classifier(waves)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        v = sess.run(outputs)

        print(v.shape)
Example #7
def xval(data_path, adaptor, classifier, summ):

    input_ = Input(FLAGS.xval_batch_size, FLAGS.num_points)
    waves, labels = input_(data_path)

    # Calculate the loss of the model.
    if FLAGS.adp:
        logits = adaptor(waves)
        logits = classifier(logits)
    else:
        logits = classifier(waves, expand_dims=True)

    logits = tf.argmax(logits, axis=-1)

    metrics = Metrics("accuracy")
    with tf.control_dependencies(
        [tf.assert_equal(tf.rank(labels), tf.rank(logits))]):
        score, xval_accu_op = metrics(labels, logits)

    assert summ, "invalid summary helper object"
    summ.register('xval', 'accuracy', score)
    xval_summ_op = summ('xval')

    return xval_accu_op, xval_summ_op
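
Metrics("accuracy") is a project helper that is not included in these snippets. Judging from how it is used (it returns a value op plus an update op, and its state is cleared by initializing local variables), it most likely wraps tf.metrics.accuracy; the sketch below is an assumption about its shape, not the project's actual code:

class Metrics(object):
    # Hedged sketch; the real helper in the project may differ.
    def __init__(self, name):
        assert name == "accuracy", "only accuracy is sketched here"
        self._name = name

    def __call__(self, labels, predictions):
        # tf.metrics.accuracy returns (value_op, update_op) backed by local
        # variables, which is why these scripts call
        # tf.local_variables_initializer() before evaluating.
        return tf.metrics.accuracy(labels, predictions, name=self._name)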
Example #8
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.ERROR)

    # tf.logging.set_verbosity(3)  # Print INFO log messages.

    if FLAGS.data_dir == '':
        raise ValueError('invalid file name {}'.format(FLAGS.data_dir))
    else:
        data_filenames = FLAGS.data_dir.split(",")
        exists = [os.path.isfile(filename) for filename in data_filenames]
        indices = [i for i, b in enumerate(exists) if not b]
        if len(indices) > 0:
            raise ValueError('invalid file name {}'.format(
                data_filenames[indices[0]]))

    checkpoint_dir = os.path.join(FLAGS.output_dir, '')

    test_input = Input(1,
                       [FLAGS.img_height, FLAGS.img_width, FLAGS.num_channels])

    filename_tensor = tf.placeholder(tf.string, shape=[None])
    inputs, labels = test_input(filename_tensor)

    model = Model(act=FLAGS.activation,
                  pool=FLAGS.pooling,
                  with_memory=FLAGS.with_memory,
                  log=True)

    logits = model(inputs, training=False)

    logit_indices = tf.argmax(logits, axis=-1)

    # Define the metric and update operations
    metrics = Metrics("accuracy")
    metric_op, metric_update_op = metrics(labels, logit_indices)

    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as sess:
        sess.run([tf.local_variables_initializer()])

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint.
            saver.restore(sess, ckpt.model_checkpoint_path)

            # Assuming model_checkpoint_path looks something like:
            #   /my-favorite-path/imagenet_train/model.ckpt-0,
            # extract global_step from it.
            # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            # print('Successfully loaded model from %s at step = %s.' %
            #       (ckpt.model_checkpoint_path, global_step))

            print('Successfully loaded model from %s.' %
                  ckpt.model_checkpoint_path)

        else:
            print('No checkpoint file found')
            return

        for filename in data_filenames:
            sess.run([metric_update_op, logits],
                     feed_dict={filename_tensor: [filename]})

            accu = sess.run(metric_op)
            print("accu -> {}".format(accu))

            reset_metrics(sess)
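
reset_metrics(sess) is another helper that does not appear in the snippet. To report a separate accuracy per file, it has to clear the running totals of the streaming metric between files. A minimal sketch, assuming the metric state lives in the local-variable collection as with tf.metrics:

def reset_metrics(sess, scope="accuracy"):
    # Re-initialize the local variables under the metric's scope so the counts
    # start from zero for the next file; the scope name is an assumption.
    metric_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=scope)
    sess.run(tf.variables_initializer(metric_vars))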
Example #9
File: test.py Project: mingyr/san
def main(unused_argv):

    if FLAGS.data_dir == '' or not os.path.exists(FLAGS.data_dir):
        raise ValueError('invalid data directory')

    if FLAGS.evaluate:
        print("evaluate the model")
        data_path = os.path.join(FLAGS.data_dir, 'eeg-xval.tfr')
    else:
        print("model inference")
        data_path = os.path.join(FLAGS.data_dir, 'eeg-test.tfr')

    if FLAGS.output_dir == '' or not os.path.exists(FLAGS.output_dir):
        raise ValueError('invalid output directory {}'.format(FLAGS.output_dir))

    checkpoint_dir = os.path.join(FLAGS.output_dir, '')

    print('reconstructing models and inputs.')
    input_ = Input(1, FLAGS.num_points)

    waves, labels = input_(data_path)

    if FLAGS.adp:
        adaptor = Adaptor()
        classifier = ReducedClassifier()

        logits = adaptor(waves)
        logits = classifier(logits)
    else:
        classifier = Classifier(FLAGS.num_points, FLAGS.sampling_rate)
        logits = classifier(waves, expand_dims=True)

    # Convert logits to predicted class indices.
    logits = tf.argmax(logits, axis=-1)

    metrics = Metrics("accuracy")
    with tf.control_dependencies(
        [tf.assert_equal(tf.rank(labels), tf.rank(logits))]):
        metric_op, metric_update_op = metrics(labels, logits)

    if FLAGS.adp:
        variables = snt.get_variables_in_module(
            adaptor) + snt.get_variables_in_module(classifier)
        saver_adaptor = tf.train.Saver(snt.get_variables_in_module(adaptor))
        saver = tf.train.Saver(variables)
    else:
        saver = tf.train.Saver(snt.get_variables_in_module(classifier))

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming model_checkpoint_path looks something like:
            #   /my-favorite-path/cifar10_train/model.ckpt-0,
            # extract global_step from it.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        else:
            print('No checkpoint file found')
            return

        assert (FLAGS.test_size > 0), "invalid test samples"
        for i in range(FLAGS.test_size):
            sess.run(metric_update_op)

        metric = sess.run(metric_op)
        print("metric -> {}".format(metric))