# NOTE(review): scraped fragment — begins mid-way through a per-epoch training
# loop (the enclosing function/loop headers are above the visible cut) and is
# truncated mid-statement at the end. Code left byte-identical; comments only.
# Scale uint8 pixel values into [-1, 1].
img = (np.float32(img) / 255.0 - 0.5) * 2.0
        if args.learn_norm:
            # When norm learning is on, additionally fetch the feature and
            # weight norms (nf, nw) from the retrieval layer for logging.
            _, cost, pred, accu, weights, nf, nw = sess.run([train_op, loss_op, correct_pred, accuracy,\
                    retrieval_layer.out,\
                    retrieval_layer.norm_feature,  retrieval_layer.norm_weight],\
                    feed_dict={img_place_holder: img, label_place_holder: label, \
                    alpha_place_holder:alpha, lr_place_holder: learning_rate})
        else:
            _, cost, pred, accu, weights = sess.run([train_op, loss_op, correct_pred, accuracy,\
                retrieval_layer.out],\
                feed_dict={img_place_holder: img, label_place_holder: label,\
                alpha_place_holder:alpha, lr_place_holder: learning_rate})

        # Advance the data pointer by the actual batch size consumed.
        ptr += real_size

    # NOTE(review): cost/accu here are from the LAST batch of the epoch only,
    # not epoch averages — confirm this is intended.
    logger.log_value('loss', cost, step=epoch)
    logger.log_value('acc', accu, step=epoch)
    if args.learn_norm:
        logger.log_value('nf', nf, step=epoch)
        logger.log_value('nw', nw, step=epoch)

    if args.alpha_decay:
        # Every 10 epochs: decay learning rate (~1/sqrt(10)·2) and shrink the
        # loss temperature alpha by ~sqrt(2). The print runs every epoch.
        if (epoch + 1) % 10 == 0:
            learning_rate = learning_rate * 0.562
            alpha = alpha / 1.41
        print('Current alpha: {}'.format(alpha))

    #heat up
    # NOTE(review): the entire evaluation / checkpoint / drawing section below
    # is nested under `args.heat_up` — it only runs when heat-up is enabled.
    # Looks possibly unintended; confirm against the original script.
    if args.heat_up:
        # Drop the learning rate 10x when entering the final heat-up epochs.
        if epoch == args.nb_epoch - args.nb_hu_epoch:
            learning_rate = learning_rate * 0.1
        test_embedding, test_normed_embedding, test_logits, test_labels\
                = get_feature(mnist.test, n_classes)
        # Convert one-hot test labels back to integer class ids.
        test_class_id = np.argmax(test_labels, axis=1)
        model_weights = retrieval_layer.out.eval(sess)
        # L2-normalize each class-weight column (epsilon 1e-4 guards /0).
        normed_model_weights = model_weights.copy()
        for i in range(model_weights.shape[1]):
            normed_model_weights[:, i] = model_weights[:, i] / np.sqrt(
                np.sum(model_weights[:, i]**2) + 1e-4)

        test_nmi = utils.eval_nmi(test_embedding, test_class_id, n_classes)
        normed_test_nmi = utils.eval_nmi(test_normed_embedding, test_class_id,
                                         n_classes)
        test_acc = utils.eval_all_acc(test_logits, test_class_id)
        print('Test Accuracy: {}'.format(test_acc))
        print('Test NMI: {}'.format(test_nmi))
        logger.log_value('Test Accuracy', test_acc, step=epoch)
        logger.log_value('Test NMI', test_nmi, step=epoch)
        logger.log_value('Normed Test NMI', normed_test_nmi, step=epoch)
        # NOTE(review): avg_cost / avg_acc are not defined anywhere in the
        # visible fragment — presumably accumulated above the cut; verify.
        logger.log_value('loss', avg_cost, step=epoch)
        logger.log_value('avg acc', avg_acc, step=epoch)

        # Checkpoint the session once per epoch.
        saver.save(sess, '{}/{}/check_point'.format(args.model_dir, suffix),\
                global_step=epoch)

        if args.draw:
            # Scatter-plot normalized embeddings together with (slightly
            # shrunk) normalized class weights for visual inspection.
            save_png_name = '{}/{}/test_epoch{:03d}.png'.format(
                args.save_image_dir, suffix, epoch)
            draw_feature_and_weights(save_png_name, test_normed_embedding, \
                    normed_model_weights*0.9, test_class_id)
            # NOTE(review): fragment is truncated here mid-statement.
            if not args.l2_norm:
                save_png_name = '{}/{}/test_no_l2n_epoch{:03d}.png'.format(
Пример #3
0
    # NOTE(review): scraped fragment — the opening of this keyword-argument
    # construct (presumably something like `fit_args = dict(x=..., y=...,`)
    # was lost in extraction; only the tail of the argument list survives.
    batch_size = args.batch_size,
    epochs     = args.nb_epoch,
    validation_data  = (tst.X, tst.Y),
    callbacks  = cbs,
)

# Compile, train, log the per-epoch history, then evaluate the model by
# averaging class probabilities over multiple stochastic forward passes.
model.compile(loss='categorical_crossentropy', optimizer=args.optimizer, metrics=['accuracy'])
print(model.get_config())
hist = None
try:
    r = model.fit(**fit_args)
    hist = r.history
    # FIX: dict.iteritems() was removed in Python 3 (AttributeError at
    # runtime); items() behaves the same here and also works on Python 2.
    for key, value_list in hist.items():
        for idx, value in enumerate(value_list):
            logger.log_value(key, value, step=idx)

except KeyboardInterrupt:
    # Allow interrupting training and still run the evaluation below.
    print("KeyboardInterrupt called")


# Print and save results
probs = 0.
# NOTE(review): get_IB_layer_output is never used in the visible fragment;
# kept in case truncated downstream code relies on it.
get_IB_layer_output = keras.backend.function([model.layers[0].input], [model.layers[2].output])

# Monte-Carlo average over several predictions — presumably the model has
# stochastic layers (e.g. sampling/dropout active at predict time); confirm.
for _ in range(args.predict_samples):
    probs += model.predict(tst.X)

probs /= float(args.predict_samples)
preds = probs.argmax(axis=-1)
print('Accuracy (using %d samples): %0.5f' % (args.predict_samples, np.mean(preds == tst.y)))
Пример #4
0
            # NOTE(review): scraped fragment — starts mid-way inside a batched
            # training loop; `step`, `real_size`, `img`, `y_offset`,
            # `x_offset`, `flip_flag` are defined above the visible cut.
            # Random-crop each image to default_image_size using per-image
            # offsets, then optionally mirror it horizontally.
            for i in range(real_size):
                img[i,:,:,:] = img_data[ptr+i,y_offset[i]:(y_offset[i]+default_image_size),\
                            x_offset[i]:(x_offset[i]+default_image_size),:]
                if flip_flag[i]:
                    # Reverse the width axis to flip horizontally.
                    img[i, :, :, :] = img[i, :, ::-1, :]
        else:
            # No augmentation: take the batch slice as-is.
            img = img_data[ptr:ptr + real_size, :, :, :]

        label = one_hot_label[ptr:ptr + real_size, :]
        _, cost, accu = sess.run([train_op, loss_op, accuracy],\
                        feed_dict={google_net_model.img: img, label_place_holder: label})

        global_step += 1
        ptr += training_batch_size
        # Log loss/accuracy every display_step batches under its own counter.
        if (step + 1) % args.display_step == 0:
            logger.log_value('loss', cost, step=log_step)
            logger.log_value('acc', accu, step=log_step)
            log_step = log_step + 1

    # Per-epoch NMI clustering evaluation on training and validation sets.
    training_embedding = get_feature(img_data)
    training_nmi = eval_nmi(training_embedding, class_id,
                            num_training_category)
    logger.log_value('Training NMI', training_nmi, step=log_step)

    valid_embedding = get_feature(valid_img_data)
    validation_nmi = eval_nmi(valid_embedding, valid_class_id,
                              num_valid_category)
    logger.log_value('Test NMI', validation_nmi, step=log_step)

#saver = tf.train.Saver()
#saver.save(self.sess, '../tensorflow_model/model.ckpt')