Example #1
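All seven snippets below come from the same TensorFlow 1.x layer-wise relevance propagation (LRP) codebase and share a common set of imports. The following is a minimal sketch of those imports; the project-local modules (nn, discriminator, generator, Utils, plot_relevances, the data loaders, and FLAGS) are assumptions about the surrounding repository and are not shown here.

import os

import h5py
import numpy as np
import scipy.io as sio
import tensorflow as tf
from scipy.ndimage import rotate
from tensorflow.examples.tutorials.mnist import input_data

# Assumed project-local imports (exact module paths are hypothetical):
# from modules import nn, discriminator, generator
# from utils import Utils
# from render import plot_relevances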
def train():
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    with tf.Session() as sess:
        # Input placeholders
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [FLAGS.batch_size, 784],
                               name='x-input')

        # Model definition along with training and relevances
        with tf.variable_scope('model'):
            with tf.variable_scope('discriminator'):
                D = discriminator()
                # Run the discriminator on the real data distribution
                D1 = D.forward(x)
                D_params_num = len(tf.trainable_variables())
            with tf.variable_scope('generator'):
                G = generator()
                # Run the generator on random noise to produce fake data
                Gout = G.forward(
                    tf.random_normal([FLAGS.batch_size, FLAGS.input_size]))

            with tf.variable_scope('discriminator') as scope:
                scope.reuse_variables()
                # Run the discriminator on the fake data distribution
                D2 = D.forward(Gout)

            # Image summaries
            packed = tf.concat(
                [Gout, tf.reshape(x, Gout.get_shape().as_list())], 2)
            tf.summary.image('Generated-Original',
                             packed,
                             max_outputs=FLAGS.batch_size)
            #tf.summary.image('Original', tf.reshape(x, Gout.get_shape().as_list()))

        # Extract respective parameters
        total_params = tf.trainable_variables()
        D_params = total_params[:D_params_num]
        G_params = total_params[D_params_num:]

        with tf.variable_scope('Loss'):
            # Compute every loss
            D1_loss, D2_loss = compute_D_loss(D1, D2)
            D_loss = tf.reduce_mean(D1_loss + D2_loss)
            G_loss = compute_G_loss(D2)
            # Loss summaries
            tf.summary.scalar('D_real', tf.reduce_mean(D1_loss))
            tf.summary.scalar('D_fake', tf.reduce_mean(D2_loss))
            tf.summary.scalar('D_loss', tf.reduce_mean(D_loss))
            tf.summary.scalar('G_loss', tf.reduce_mean(G_loss))

        # Create one trainer (optimizer) per network, each with its own
        # loss and parameter list
        with tf.variable_scope('Trainer'):
            D_trainer = D.fit(loss=D_loss,
                              optimizer='adam',
                              opt_params=[FLAGS.D_learning_rate, D_params])
            G_trainer = G.fit(loss=G_loss,
                              optimizer='adam',
                              opt_params=[FLAGS.G_learning_rate, G_params])

        # Merge all summaries and create one FileWriter per network; both
        # writers record the summary variables defined above
        merged = tf.summary.merge_all()
        D_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/D',
                                         sess.graph)
        G_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/G',
                                         sess.graph)

        # Init all variables
        tf.global_variables_initializer().run()

        utils = Utils(sess, FLAGS.checkpoint_dir)
        if FLAGS.reload_model:
            utils.reload_model()

        for i in range(FLAGS.max_steps):
            d = feed_dict(mnist, True)
            inp = {x: d[0]}
            # Run D once and G twice
            D_summary, _, dloss, dd1, dd2 = sess.run(
                [merged, D_trainer.train, D_loss, D1_loss, D2_loss],
                feed_dict=inp)
            G_summary, _, gloss, gen_images = sess.run(
                [merged, G_trainer.train, G_loss, Gout], feed_dict=inp)
            G_summary, _, gloss, gen_images = sess.run(
                [merged, G_trainer.train, G_loss, Gout], feed_dict=inp)

            if i % 100 == 0:
                print(gloss.mean(), dloss.mean())

            # Add summaries
            D_writer.add_summary(D_summary, i)
            G_writer.add_summary(G_summary, i)

        # save model if required
        if FLAGS.save_model:
            utils.save_model()

        D_writer.close()
        G_writer.close()
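Examples #1-#3 call a feed_dict(mnist, train) helper that is not shown. A minimal sketch in the style of the TensorFlow MNIST tutorial, assuming a FLAGS.dropout flag; the project's real helper may differ:

def feed_dict(mnist, train):
    # Hypothetical helper: draw a batch and choose a dropout
    # keep-probability (1.0 at test time); returns (images, labels, keep_prob)
    if train:
        xs, ys = mnist.train.next_batch(FLAGS.batch_size)
        k = FLAGS.dropout
    else:
        xs, ys = mnist.test.next_batch(FLAGS.batch_size)
        k = 1.0
    return xs, ys, k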
Example #2
def train():
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    with tf.Session() as sess:
        # Input placeholders
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [FLAGS.batch_size, 784],
                               name='x-input')
            y_ = tf.placeholder(tf.float32, [FLAGS.batch_size, 10],
                                name='y-input')
            keep_prob = tf.placeholder(tf.float32)

        # Model definition along with training and relevances
        with tf.variable_scope('model'):
            net = nn()
            y = net.forward(x)

        with tf.variable_scope('relevance'):
            if FLAGS.relevance:
                LRP = net.lrp(y, FLAGS.relevance_method, 1e-8)

                # LRP layerwise
                relevance_layerwise = []
                # R = y
                # for layer in net.modules[::-1]:
                #     R = net.lrp_layerwise(layer, R, 'simple')
                #     relevance_layerwise.append(R)
            else:
                LRP = []
                relevance_layerwise = []
        # Accuracy computation
        with tf.name_scope('correct_prediction'):
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        # Merge all the summaries and write them out
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')

        tf.global_variables_initializer().run()

        utils = Utils(sess, FLAGS.checkpoint_dir)
        if FLAGS.reload_model:
            utils.reload_model()

        trainer = net.fit(output=y,
                          ground_truth=y_,
                          loss='softmax_crossentropy',
                          optimizer='adam',
                          opt_params=[FLAGS.learning_rate])

        # net.fit adds optimizer slot variables after the first global
        # initializer ran, so initialize the remaining non-trainable globals
        uninit_vars = set(tf.global_variables()) - set(
            tf.trainable_variables())
        tf.variables_initializer(uninit_vars).run()

        # iterate over train and test data
        for i in range(FLAGS.max_steps):
            if i % FLAGS.test_every == 0:
                d = feed_dict(mnist, False)
                test_inp = {x: d[0], y_: d[1], keep_prob: d[2]}
                summary, acc, relevance_test, op, rel_layer = sess.run(
                    [merged, accuracy, LRP, y, relevance_layerwise],
                    feed_dict=test_inp)
                test_writer.add_summary(summary, i)
                print('Accuracy at step %s: %f' % (i, acc))

            else:
                d = feed_dict(mnist, True)
                inp = {x: d[0], y_: d[1], keep_prob: d[2]}
                summary, _, relevance_train, op, rel_layer = sess.run(
                    [merged, trainer.train, LRP, y, relevance_layerwise],
                    feed_dict=inp)
                train_writer.add_summary(summary, i)

        # relevances plotted with visually pleasing color schemes
        if FLAGS.relevance:
            # plot test images with relevances overlaid
            # Index by the placeholder itself (dict.keys() is not
            # subscriptable in Python 3)
            images = test_inp[x].reshape([FLAGS.batch_size, 28, 28, 1])
            images = (images + 1) / 2.0
            plot_relevances(
                relevance_test.reshape([FLAGS.batch_size, 28, 28, 1]), images,
                test_writer)
            # plot train images with relevances overlaid
            # images = inp[inp.keys()[0]].reshape([FLAGS.batch_size,28,28,1])
            # images = (images + 1)/2.0
            # plot_relevances(relevance_train.reshape([FLAGS.batch_size,28,28,1]), images, train_writer )

        train_writer.close()
        test_writer.close()
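The commented-out np.sum prints in these examples check LRP's conservation property: the relevance arriving at the input layer should roughly sum, per sample, to the network output being explained. A sketch of that check, assuming relevance_test holds a batch of input relevances:

# Per-sample relevance sums; for a conserving LRP rule these should be
# close to the corresponding network outputs
per_sample_sums = relevance_test.reshape(FLAGS.batch_size, -1).sum(axis=1)
print(per_sample_sums)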
Example #3
def train():
    # Import data
    # train_file_path = str(FLAGS.image_dim)+"_train_y.csv"
    # test_file_path = str(FLAGS.image_dim)+"_test_y.csv"

    # mnist = TFLData( (train_file_path,test_file_path) )

    train_file_path = os.path.join("mnist_csvs", "mnist_train.csv")
    test_file_path = os.path.join("mnist_csvs", "mnist_test.csv")

    mnist = MnistData((train_file_path, test_file_path, (1000, 1000)))

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:

        # Input placeholders
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32,
                               [None, FLAGS.image_dim * FLAGS.image_dim],
                               name='x-input')
            y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
            keep_prob = tf.placeholder(tf.float32)

        with tf.variable_scope('model'):
            net = nn()
            inp = tf.pad(
                tf.reshape(
                    x,
                    [FLAGS.batch_size, FLAGS.image_dim, FLAGS.image_dim, 1]),
                [[0, 0], [2, 2], [2, 2], [0, 0]])
            op = net.forward(inp)
            y = tf.squeeze(op)

            trainer = net.fit(output=y,
                              ground_truth=y_,
                              loss='softmax_crossentropy',
                              optimizer='adam',
                              opt_params=[FLAGS.learning_rate])
        with tf.variable_scope('relevance'):
            if FLAGS.relevance:
                LRP = net.lrp(op, FLAGS.relevance_method, 1e-8)

                # LRP layerwise
                relevance_layerwise = []
                # R = y
                # for layer in net.modules[::-1]:
                #     R = net.lrp_layerwise(layer, R, 'simple')
                #     relevance_layerwise.append(R)

            else:
                LRP = []
                relevance_layerwise = []

        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(
                tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)),
                        tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        # Merge all the summaries and write them out to FLAGS.summaries_dir
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')

        tf.global_variables_initializer().run()

        utils = Utils(sess, FLAGS.checkpoint_reload_dir)
        if FLAGS.reload_model:
            utils.reload_model()

        for i in range(FLAGS.max_steps):
            if i % FLAGS.test_every == 0:  # test-set accuracy
                d = feed_dict(mnist, False)
                test_inp = {x: d[0], y_: d[1], keep_prob: d[2]}
                summary, acc, relevance_test, rel_layer, y_val = sess.run(
                    [merged, accuracy, LRP, relevance_layerwise, y],
                    feed_dict=test_inp)
                # Compute predicted labels in numpy instead of building a new
                # tf.argmax op on every loop iteration
                y_labels = np.argmax(y_val, axis=1)

                test_writer.add_summary(summary, i)
                print('Accuracy at step %s: %f' % (i, acc))
                # print([np.sum(rel) for rel in rel_layer])
                # print(np.sum(relevance_test))

                # save model if required
                if FLAGS.save_model:
                    utils.save_model()

            else:
                d = feed_dict(mnist, True)
                inp = {x: d[0], y_: d[1], keep_prob: d[2]}
                summary, _, relevance_train, op, rel_layer = sess.run(
                    [merged, trainer.train, LRP, y, relevance_layerwise],
                    feed_dict=inp)
                train_writer.add_summary(summary, i)

        # relevances plotted with visually pleasing color schemes
        if FLAGS.relevance:
            relevance_test = relevance_test[:, 2:FLAGS.image_dim + 2,
                                            2:FLAGS.image_dim + 2, :]
            # plot test images with relevances overlaid
            images = test_inp[x].reshape(
                [FLAGS.batch_size, FLAGS.image_dim, FLAGS.image_dim, 1])
            #images = (images + 1)/2.0
            plot_relevances(
                relevance_test.reshape(
                    [FLAGS.batch_size, FLAGS.image_dim, FLAGS.image_dim, 1]),
                images, test_writer, y_labels)

            # plot train images with relevances overlaid
            # relevance_train = relevance_train[:,2:30,2:30,:]
            # images = inp[inp.keys()[0]].reshape([FLAGS.batch_size,28,28,1])
            # plot_relevances(relevance_train.reshape([FLAGS.batch_size,28,28,1]), images, train_writer )

        train_writer.close()
        test_writer.close()
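Example #3 pads the 28x28 MNIST inputs to 32x32 before the forward pass and crops the relevances back with [:, 2:image_dim+2, 2:image_dim+2, :]. A small numpy check of that pad/crop round trip, assuming image_dim is 28:

import numpy as np

img = np.zeros([1, 28, 28, 1], np.float32)
padded = np.pad(img, [[0, 0], [2, 2], [2, 2], [0, 0]], 'constant')
assert padded.shape == (1, 32, 32, 1)                # padded network input
assert padded[:, 2:30, 2:30, :].shape == img.shape   # crop restores 28x28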
Example #4
def train(tag):
    # Import data: hold out subset<tag> for testing, train on the rest
    sub = 'subset' + str(tag)

    x_train_whole = []
    y_train_whole = []
    # Number of .h5 shards per subset
    if tag in (0, 1, 2):
        tot = 8
    elif tag == 6:
        tot = 14
    elif tag == 8:
        tot = 15
    else:
        tot = 16
    x_test_pos = []
    x_test_neg = []
    for num in range(tot):
        h5f = h5py.File('./src/data/3D_data/' + sub + '_' + str(num) + '.h5',
                        'r')
        y_tmp = np.asarray(h5f['Y'])
        x_tmp = np.asarray(h5f['X'])
        h5f.close()
        if y_tmp.max() != 0:  # shard contains at least one positive
            x_tmp_pos = x_tmp[np.where(y_tmp == 1)[0], :, :, :, :]
            # Comparing an ndarray to [] is unreliable; test length instead
            if len(x_test_pos) == 0:
                x_test_pos = x_tmp_pos
            else:
                x_test_pos = np.concatenate([x_test_pos, x_tmp_pos])
            # Sample three negatives per positive, without replacement
            negIndex = np.random.choice(np.where(y_tmp == 0)[0],
                                        len(x_tmp_pos) * 3,
                                        replace=False)
            x_tmp_neg = x_tmp[negIndex, :, :, :, :]
            if len(x_test_neg) == 0:
                x_test_neg = x_tmp_neg
            else:
                x_test_neg = np.concatenate([x_test_neg, x_tmp_neg])

            del x_tmp_pos
            del x_tmp_neg
            del negIndex
        del x_tmp
        del y_tmp
    y_test_pos = np.ones(len(x_test_pos))
    y_test_neg = np.zeros(len(x_test_neg))

    x_test_tmp = np.concatenate([x_test_pos, x_test_neg])
    y_test_tmp = np.concatenate([y_test_pos, y_test_neg])

    # Shuffle the combined test set
    idx = np.arange(len(y_test_tmp))
    np.random.shuffle(idx)
    x_test = x_test_tmp[idx]
    y_test = y_test_tmp[idx]
    del x_test_tmp
    del y_test_tmp
    del y_test_neg
    del x_test_neg
    del x_test_pos
    del y_test_pos
    print(len(x_test))
    print(len(y_test))
    sub = 'subset'
    for i in range(10):
        subset = sub + str(i)
        if i != tag:
            if i in (0, 1, 2):
                tot = 8
            elif i == 6:
                tot = 14
            elif i == 8:
                tot = 15
            else:
                tot = 16
            x_train_pos = []
            x_train_neg = []
            for num in range(tot):
                h5f2 = h5py.File(
                    './src/data/3D_data/' + subset + '_' + str(num) + '.h5',
                    'r')
                x_tmp = np.asarray(h5f2['X'])
                y_tmp = np.asarray(h5f2['Y'])
                h5f2.close()
                if y_tmp.max() != 0:
                    x_tmp_pos = x_tmp[np.where(y_tmp == 1)[0], :, :, :, :]
                    # Augment positives with rotated copies
                    # (scipy.ndimage.rotate in the first two spatial axes)
                    rotations = []
                    for angle in (90, 180, 270, 45, 135, 225, 315):
                        rot = np.zeros_like(x_tmp_pos)
                        for aug in range(len(x_tmp_pos)):
                            rot[aug] = rotate(x_tmp_pos[aug], angle,
                                              reshape=False)
                        rotations.append(rot)

                    tmp = np.concatenate([x_tmp_pos] + rotations)
                    idx2 = np.arange(len(tmp))
                    np.random.shuffle(idx2)
                    tmp2 = tmp[idx2]
                    del rotations
                    # Keep a subset of the shuffled augmented positives
                    if len(x_train_pos) == 0:
                        x_train_pos = tmp2[:int(len(tmp) / 4)]
                    else:
                        x_train_pos = np.concatenate(
                            [x_train_pos, tmp2[:int(len(tmp) / 5)]])

                    del tmp
                    # Sample five negatives per positive, without replacement
                    negIndex = np.random.choice(np.where(y_tmp == 0)[0],
                                                len(x_tmp_pos) * 5,
                                                replace=False)
                    x_tmp_neg = x_tmp[negIndex, :, :, :, :]
                    if len(x_train_neg) == 0:
                        x_train_neg = x_tmp_neg
                    else:
                        x_train_neg = np.concatenate([x_train_neg, x_tmp_neg])

                    del tmp2
                    del x_tmp_neg
                    del x_tmp_pos
                    del negIndex
                del x_tmp
                del y_tmp
            y_train_pos = np.ones(len(x_train_pos))
            y_train_neg = np.zeros(len(x_train_neg))
            x_train_tmp = np.concatenate([x_train_pos, x_train_neg])
            y_train_tmp = np.concatenate([y_train_pos, y_train_neg])
            del x_train_pos
            del x_train_neg
            del y_train_neg
            del y_train_pos
            idx = np.arange(len(y_train_tmp))
            np.random.shuffle(idx)
            x_train = x_train_tmp[idx]
            y_train = y_train_tmp[idx]
            del x_train_tmp
            del y_train_tmp
            if len(x_train_whole) == 0:
                x_train_whole = x_train
                y_train_whole = y_train
            else:
                x_train_whole = np.concatenate([x_train_whole, x_train])
                y_train_whole = np.concatenate([y_train_whole, y_train])
            print(len(x_train_whole))

            del x_train
            del y_train
    x_train = x_train_whole
    y_train = y_train_whole
    del x_train_whole
    del y_train_whole
    print(len(x_train))
    print(len(y_train))
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Input placeholders
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [None, 32, 32, 32, 1],
                               name='x-input')
            y_ = tf.placeholder(tf.float32, [None, 2], name='y-input')
            phase = tf.placeholder(tf.bool, name='phase')
        with tf.variable_scope('model'):
            net = nn(phase)
            # x_prep = prep_data_augment(x)
            # x_input = data_augment(x_prep)
            inp = tf.reshape(x, [FLAGS.batch_size, 32, 32, 32, 1])
            op = net.forward(inp)
            y = tf.reshape(op, [FLAGS.batch_size, 2])
            soft = tf.nn.softmax(y)
            trainer = net.fit(output=y,
                              ground_truth=y_,
                              loss='focal loss',
                              optimizer='adam',
                              opt_params=[FLAGS.learning_rate])
        with tf.variable_scope('relevance'):
            if FLAGS.relevance:

                LRP = net.lrp(y, FLAGS.relevance_method, 1)
                # LRP layerwise
                relevance_layerwise = []
                # R = input_rel2
                # for layer in net.modules[::-1]:
                #     R = net.lrp_layerwise(layer, R, FLAGS.relevance_method, 1e-8)
                #     relevance_layerwise.append(R)

            else:
                LRP = []
                relevance_layerwise = []

        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(
                tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)),
                        tf.float32))
            # accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(tf.where(tf.greater(y,0),tf.ones_like(y, dtype=tf.float32), tf.zeros_like(y, dtype=tf.float32)), 2), tf.argmax(y_, 2)), tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        # Merge all the summaries and write them out to ./conv_log
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(
            './conv_log/' + str(tag) + '_train', sess.graph)
        test_writer = tf.summary.FileWriter('./conv_log/' + str(tag) + '_test')

        tf.global_variables_initializer().run()

        utils = Utils(sess, './3D_model/subset' + str(tag))
        if FLAGS.reload_model:
            utils.reload_model()
        train_acc = []
        test_acc = []
        for i in range(FLAGS.max_steps):

            if i % FLAGS.test_every == 0:  # test-set accuracy
                x_test_batch, y_test_batch = next_batch(
                    FLAGS.batch_size, x_test, y_test)
                # Convert binary labels to one-hot [neg, pos] columns
                tmp_y_batch = np.zeros([FLAGS.batch_size, 2])
                tmp_y_batch[:, 0] = 1.0 - y_test_batch
                tmp_y_batch[:, 1] = y_test_batch
                y_test_batch = tmp_y_batch
                test_inp = {x: x_test_batch, y_: y_test_batch, phase: False}

                summary, acc, relevance_test, op2, soft_val, rel_layer = sess.run(
                    [merged, accuracy, LRP, y, soft, relevance_layerwise],
                    feed_dict=test_inp)
                test_writer.add_summary(summary, i)
                test_acc.append(acc)
                print('-----------')
                for m in range(FLAGS.batch_size):
                    print(np.argmax(y_test_batch[m, :]),
                          y_test_batch[m, :],
                          end=" ")
                    print(np.argmax(op2[m, :]), op2[m, :], end=" ")
                    print(soft_val[m, :])
                    print("|")
                print('Accuracy at step %s: %f' % (i, acc))
                print(tag)
                # print([np.sum(rel) for rel in rel_layer])
                # print(np.sum(relevance_test))

                # save model if required
                if FLAGS.save_model:
                    utils.save_model()

            else:
                x_train_batch, y_train_batch = next_batch(
                    FLAGS.batch_size, x_train, y_train)
                tmp_y_batch = np.zeros([FLAGS.batch_size, 2])
                tmp_y_batch[:, 0] = 1.0 - y_train_batch
                tmp_y_batch[:, 1] = y_train_batch
                y_train_batch = tmp_y_batch
                inp = {x: x_train_batch, y_: y_train_batch, phase: True}
                summary, acc2, _, relevance_train, op2, soft_val, rel_layer = sess.run(
                    [
                        merged, accuracy, trainer.train, LRP, y, soft,
                        relevance_layerwise
                    ],
                    feed_dict=inp)
                train_writer.add_summary(summary, i)
                #print(soft_val[0,:])
                train_acc.append(acc2)
        print(np.mean(train_acc), np.mean(test_acc))

        # relevances plotted with visually pleasing color schemes
        if FLAGS.relevance:
            # plot test images with relevances overlaid
            images = test_inp[x].reshape(
                [FLAGS.batch_size, 32, 32, 32, 1])
            # images = (images + 1)/2.0
            plot_relevances(
                relevance_test.reshape([FLAGS.batch_size, 32, 32, 32, 1]),
                images, test_writer)

        train_writer.close()
        test_writer.close()
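Example #4 calls a next_batch(batch_size, data, labels) helper that is not shown. A minimal sketch, assuming uniform random sampling; the project's version may instead iterate over epochs:

def next_batch(batch_size, data, labels):
    # Hypothetical helper: sample one random batch without replacement
    idx = np.random.choice(len(data), batch_size, replace=False)
    return data[idx], labels[idx]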
Example #5
def train(train_X, train_y):
    # Import data
    # mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    # config = tf.ConfigProto(
    #         device_count = {'GPU': 0}
    #     )
    with tf.Session() as sess:
        # Input placeholders
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [FLAGS.batch_size, 166],
                               name='x-input')
            y_ = tf.placeholder(tf.float32, [FLAGS.batch_size, 3],
                                name='y-input')
            keep_prob = tf.placeholder(tf.float32)

        # Model definition along with training and relevances
        with tf.variable_scope('model'):
            net = nn()
            y = net.forward(x)

        saver = tf.train.Saver()

        with tf.variable_scope('relevance'):
            if FLAGS.relevance:
                LRP = net.lrp(y, FLAGS.relevance_method, 1)

                # LRP layerwise
                relevance_layerwise = []
                R = y
                for layer in net.modules[::-1]:
                    print("layer here: ", layer)
                    R = net.lrp_layerwise(layer, R, 'epsilon', 1)
                    relevance_layerwise.append(R)
            else:
                LRP = []
                relevance_layerwise = []

        # Accuracy computation
        with tf.name_scope('correct_prediction'):
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        # Merge all the summaries and write them out
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')

        tf.global_variables_initializer().run()

        utils = Utils(sess, FLAGS.checkpoint_dir)
        if FLAGS.reload_model:
            utils.reload_model()

        print("y.shape: " + str(y.shape) + " y_.shape: " + str(y_.shape))
        trainer = net.fit(output=y,
                          ground_truth=y_,
                          loss='softmax_crossentropy',
                          optimizer='adam',
                          opt_params=[FLAGS.learning_rate])

        uninit_vars = set(tf.global_variables()) - set(
            tf.trainable_variables())
        tf.variables_initializer(uninit_vars).run()

        # iterate over train and test data
        for i in range(FLAGS.max_steps):
            if i % FLAGS.test_every == 0:
                d = feed_dict(train_X, train_y, False)
                # Feed numpy arrays directly; a tf.data.Dataset cannot be fed
                # to a placeholder through feed_dict
                test_inp = {x: d[0], y_: d[1], keep_prob: d[2]}
                summary, acc, relevance_test, op, rel_layer = sess.run(
                    [merged, accuracy, LRP, y, relevance_layerwise],
                    feed_dict=test_inp)
                test_writer.add_summary(summary, i)
                print('Accuracy at step %s: %f' % (i, acc))
                print([np.sum(rel) for rel in rel_layer])  # per-layer relevance sums
                print(np.sum(relevance_test))
                saver.save(sess, "%s/model_epoch_%d.ckpt" %
                           (FLAGS.checkpoint_dir, i))

            else:
                d = feed_dict(train_X, train_y, True)
                inp = {x: d[0], y_: d[1], keep_prob: d[2]}
                summary, _, relevance_train, op, rel_layer = sess.run(
                    [merged, trainer.train, LRP, y, relevance_layerwise],
                    feed_dict=inp)
                train_writer.add_summary(summary, i)

        # relevances plotted with visually pleasing color schemes
        # if FLAGS.relevance:
        #     # plot test images with relevances overlaid
        #     images = d[0].reshape([FLAGS.batch_size,28,28,1])
        #     plot_relevances(relevance_test.reshape([FLAGS.batch_size,28,28,1]), images, test_writer )
        # plot train images with relevances overlaid
        # images = inp[inp.keys()[0]].reshape([FLAGS.batch_size,28,28,1])
        # images = (images + 1)/2.0
        # plot_relevances(relevance_train.reshape([FLAGS.batch_size,28,28,1]), images, train_writer )

        train_writer.close()
        test_writer.close()
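Example #5's feed_dict(train_X, train_y, train) takes arrays rather than an MNIST dataset object. A minimal sketch under the same assumptions as the helper after Example #1 (random batching, FLAGS.dropout during training):

def feed_dict(data_X, data_y, train):
    # Hypothetical array-based variant: sample a batch and choose the
    # dropout keep-probability (1.0 at test time)
    idx = np.random.choice(len(data_X), FLAGS.batch_size, replace=False)
    return data_X[idx], data_y[idx], FLAGS.dropout if train else 1.0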
Example #6
def train(tag):
    tag = int(tag)
    x_test_batch = np.load('./Demo_img/test_demo_x.npy')
    y_test_batch = np.load('./Demo_img/test_demo_y.npy')
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:

        # Input placeholders
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [None, 32, 32, 32, 1], name='x-input')
            y_ = tf.placeholder(tf.float32, [None, 2], name='y-input')
            phase = tf.placeholder(tf.bool, name='phase')
        with tf.variable_scope('model'):
            net = nn(phase)
            # x_prep = prep_data_augment(x)
            # x_input = data_augment(x_prep)
            inp = tf.reshape(x, [FLAGS.batch_size, 32, 32, 32, 1])
            op = net.forward(inp)
            y = tf.reshape(op, [FLAGS.batch_size, 2])
            soft = tf.nn.softmax(y)
        with tf.variable_scope('relevance'):
            if FLAGS.relevance:

                LRP = net.lrp(soft, FLAGS.relevance_method, 2)

                # LRP layerwise
                relevance_layerwise = []
                #R = tf.expand_dims(soft[0, :], 0)
                R = soft
                for layer in net.modules[::-1]:
                    R = net.lrp_layerwise(layer, R, FLAGS.relevance_method, 2)
                    relevance_layerwise.append(R)

            else:
                LRP = []
                relevance_layerwise = []

        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))
            # accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(tf.where(tf.greater(y,0),tf.ones_like(y, dtype=tf.float32), tf.zeros_like(y, dtype=tf.float32)), 2), tf.argmax(y_, 2)), tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        # Merge all the summaries and write them out to ./conv_log/LRP
        merged = tf.summary.merge_all()
        test_writer = tf.summary.FileWriter('./conv_log/LRP')

        tf.global_variables_initializer().run()

        utils = Utils(sess, './3D_model/subset'+str(tag))
        if FLAGS.reload_model:
            utils.reload_model()

        test_inp = {x: x_test_batch, y_: y_test_batch, phase: False}

        relevance_test, op, soft_val, rel_layer = sess.run([LRP, y, soft, relevance_layerwise],
                                                            feed_dict=test_inp)
        for m in range(FLAGS.batch_size):
            print(soft_val[m, :])
        np.save('./Demo_img/soft.npy', soft_val)
        if FLAGS.relevance:
            # plot test images with relevances overlaid
            images = test_inp[x].reshape([FLAGS.batch_size, 32, 32, 32, 1])
            # images = (images + 1)/2.0
            plot_relevances(relevance_test.reshape([FLAGS.batch_size, 32, 32, 32, 1]),
                            images, test_writer)
        test_writer.close()
Example #7
def train():
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [None, 25, 25], name='x-input')
            y_ = tf.placeholder(tf.float32, [None, 2], name='y-input')
            keep_prob = tf.placeholder(tf.float32)

        with tf.variable_scope('model'):
            net = nn()
            inp = tf.reshape(x, [FLAGS.batch_size, 25, 25, 1])
            op = net.forward(inp)
            y = tf.squeeze(op)
            trainer = net.fit(output=y, ground_truth=y_, loss='softmax_crossentropy', optimizer='adam', opt_params=[FLAGS.learning_rate])
        with tf.variable_scope('relevance'):
            if FLAGS.relevance:
                LRP = net.lrp(op, FLAGS.relevance_method, 1e-8)
            else:
                LRP = []

        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')
        tf.global_variables_initializer().run()
        utils = Utils(sess, FLAGS.checkpoint_dir)

        if FLAGS.reload_model:
            utils.reload_model()

        # trX/trY and teX/teY (train and test splits) are assumed to be
        # provided by the surrounding module
        cost = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))
        train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(cost)
        predict_op = tf.argmax(y, 1)
        # Initialize the optimizer variables created above
        tf.global_variables_initializer().run()

        for i in range(1000):
            sess.run(train_op, feed_dict={x: trX, y_: trY})
            # Once training accuracy reaches 100%, report test accuracy
            if 100 * sess.run(accuracy, feed_dict={x: trX, y_: trY}) == 100:
                print(100 * np.mean(
                    np.argmax(teY, axis=1) == sess.run(
                        predict_op, feed_dict={x: teX, y_: teY})))

        # Evaluate LRP and predictions once on the full test set
        test_inp = {x: teX, y_: teY}
        relevance_test = sess.run(LRP, feed_dict=test_inp)
        relevance_category = sess.run(predict_op, feed_dict=test_inp)
        relevance_truth = np.argmax(teY, axis=1)
        test_accuracy = 100 * np.mean(relevance_truth == relevance_category)

        if FLAGS.relevance:
            sio.savemat(
                'LRP_result.mat', {
                    "relevance_test": relevance_test,
                    "relevance_category": relevance_category,
                    "relevance_truth": relevance_truth,
                    "accuracy": test_accuracy
                })

        train_writer.close()
        test_writer.close()
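Example #7 writes its results to LRP_result.mat; they can be read back with scipy.io.loadmat. A short usage sketch:

import scipy.io as sio

res = sio.loadmat('LRP_result.mat')
print(res['accuracy'])               # test accuracy (percent)
print(res['relevance_test'].shape)   # per-pixel relevance scores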