Ejemplo n.º 1
0
    def image_norm(self, y, obj):
        """Build the weight-regularised reconstruction cost.

        The data term is the squared difference between `y` and
        `self.y_pred`, summed along axis 1 and then averaged.  The penalty
        adds ``trace(W.T . W)`` (which equals the squared Frobenius norm of
        ``W``) for each of the first `obj.n_layers` entries of
        `obj.dA_layers`, scaled by ``lambda_reg / 2``.

        :param y: target values compared against `self.y_pred`
        :param obj: model exposing `n_layers` and `dA_layers[i].W`
        :returns: the cost expression ``mean error + lambda_reg / 2 * penalty``
        """
        # Weight of the penalty term relative to the data term.
        lambda_reg = 0.00001

        # Data term: per-row sum of squared errors, averaged over rows.
        squared_error = ((y - self.y_pred) ** 2).sum(axis=1, keepdims=False)
        data_term = T.mean(squared_error)

        # Penalty term: accumulate trace(W^T W) layer by layer.
        penalty = 0
        for layer_idx in xrange(obj.n_layers):
            layer_w = obj.dA_layers[layer_idx].W
            penalty = penalty + nlinalg.trace(T.dot(layer_w.T, layer_w))
        penalty = T.sum(penalty, keepdims=False)

        return data_term + lambda_reg / 2 * penalty
Ejemplo n.º 2
0
def test_SdA(finetune_lr=0.1,
             pretraining_epochs=hp_pretraining_epochs,
             pretrain_lr=0.1,
             training_epochs=100000,
             batch_size=hp_batchsize,
             patch_size=patch_size):

    datasets = load_data(tr_dataset)
    train_set = datasets[0]
    valid_set = datasets[1]
    test_set = datasets[2]
    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]
    datasets = []

    print '... plotting clean images'
    image = PIL.Image.fromarray(
        tile_raster_images(X=test_set_y.get_value(),
                           img_shape=patch_size,
                           tile_shape=(50, 40),
                           tile_spacing=(0, 0),
                           scale_rows_to_unit_interval=False))
    image.save('outputs/LLnet_clean.png')

    print '... plotting noisy images'
    image = PIL.Image.fromarray(
        tile_raster_images(X=test_set_x.get_value(),
                           img_shape=patch_size,
                           tile_shape=(50, 40),
                           tile_spacing=(0, 0),
                           scale_rows_to_unit_interval=False))
    image.save('outputs/LLnet_noisy.png')

    n_train_samples = train_set_x.get_value(borrow=True).shape[0]
    n_train_batches = n_train_samples / batch_size

    numpy_rng = numpy.random.RandomState(89677)
    print '... building the model'

    sda = SdA(numpy_rng=numpy_rng,
              n_ins=patch_size[0] * patch_size[1],
              hidden_layers_sizes=hp_hlsize,
              n_outs=patch_size[0] * patch_size[1])

    print '... compiling functions'

    pretraining_fns = sda.pretraining_functions(train_set_x=train_set_y,
                                                batch_size=batch_size)

    print '... pre-training the model'
    start_time = time.clock()
    for i in xrange(sda.n_layers):
        if i <= sda.n_layers / 2:

            if i == (sda.n_layers - 1):
                currentlr = pretrain_lr
            else:
                currentlr = pretrain_lr * 0.1

            for epoch in xrange(pretraining_epochs):
                c = []
                for batch_index in xrange(n_train_batches):
                    current_c = pretraining_fns[i](
                        index=batch_index,
                        corruption=hp_corruption_levels[i],
                        lr=currentlr)
                    if (batch_index % (n_train_batches / 100 + 1) == 0):
                        print '    ... Layer %i Epoch %i Progress %i/%i, Cost: %.4f, AvgCost: %.4f' % (
                            i, epoch, batch_index, n_train_batches, current_c,
                            numpy.mean(c))
                    c.append(current_c)
                print 'Pre-trained layer %i, epoch %d, cost ' % (i, epoch),
                print numpy.mean(c)

                print '     model checkpoint for current epoch...'
                f = file('outputs/model_checkpoint.obj', 'wb')
                cPickle.dump(sda, f, protocol=cPickle.HIGHEST_PROTOCOL)
                f.close()

    end_time = time.clock()

    print('... pretrained bottom half of the SdA in %.2fm' %
          ((end_time - start_time) / 60.))

    layer_all = sda.n_layers + 1  #Number of hidden layers + 1
    print layer_all

    for i in xrange(layer_all / 2 - 1):

        #Reverse map 2 to 5
        layer = i + 2
        layer_applied = layer_all - layer + 1
        print '... applying weights from SdA layer', layer, 'to SdA layer', (
            layer_applied)
        ww, bb, bbp = [
            sda.dA_layers[layer - 1].W.get_value(),
            sda.dA_layers[layer - 1].b.get_value(),
            sda.dA_layers[layer - 1].b_prime.get_value()
        ]
        sda.dA_layers[layer_applied - 1].W.set_value(ww.T)
        sda.dA_layers[layer_applied - 1].b.set_value(bbp)
        sda.dA_layers[layer_applied - 1].b_prime.set_value(bb)

    #Reverse map 1 to loglayer
    layer = 1
    print '... applying weights from SdA layer', layer, 'to loglayer layer'
    ww, bb, bbp = [
        sda.dA_layers[layer - 1].W.get_value(),
        sda.dA_layers[layer - 1].b.get_value(),
        sda.dA_layers[layer - 1].b_prime.get_value()
    ]
    sda.logLayer.W.set_value(ww.T)
    sda.logLayer.b.set_value(bbp)
    '''#Set sigmoid layer weights equal to dA weights
    for i in xrange(sda.n_layers):
	sda.sigmoid_layers[i].W.set_value(sda.dA_layers[i].W.get_value())
	sda.sigmoid_layers[i].b.set_value(sda.dA_layers[i].b.get_value())'''

    print '... compiling functions'
    train_fn, validate_model, test_model = sda.build_finetune_functions(
        train_set=train_set,
        valid_set=valid_set,
        test_set=test_set,
        batch_size=batch_size,
        learning_rate=finetune_lr)

    reconstructed = theano.function([],
                                    sda.logLayer.y_pred,
                                    givens={sda.x: test_set_x},
                                    on_unused_input='ignore')

    w1 = theano.function([],
                         nlinalg.trace(
                             T.dot(sda.sigmoid_layers[0].W.T,
                                   sda.sigmoid_layers[0].W)),
                         givens={sda.x: test_set_x},
                         on_unused_input='ignore')

    w2 = theano.function([],
                         nlinalg.trace(
                             T.dot(sda.sigmoid_layers[1].W.T,
                                   sda.sigmoid_layers[1].W)),
                         givens={sda.x: test_set_x},
                         on_unused_input='ignore')

    w3 = theano.function([],
                         nlinalg.trace(
                             T.dot(sda.sigmoid_layers[2].W.T,
                                   sda.sigmoid_layers[2].W)),
                         givens={sda.x: test_set_x},
                         on_unused_input='ignore')

    w4 = theano.function([],
                         nlinalg.trace(
                             T.dot(sda.sigmoid_layers[3].W.T,
                                   sda.sigmoid_layers[3].W)),
                         givens={sda.x: test_set_x},
                         on_unused_input='ignore')

    w5 = theano.function([],
                         nlinalg.trace(
                             T.dot(sda.sigmoid_layers[4].W.T,
                                   sda.sigmoid_layers[4].W)),
                         givens={sda.x: test_set_x},
                         on_unused_input='ignore')

    wl = theano.function([],
                         nlinalg.trace(T.dot(sda.logLayer.W.T,
                                             sda.logLayer.W)),
                         givens={sda.x: test_set_x},
                         on_unused_input='ignore')
    '''print '     loading previous model...'
    f = file('outputs/model_bestpsnr.obj', 'rb')
    sda = cPickle.load(f)
    f.close()'''

    print '... finetuning the model'
    patience = 100000 * n_train_batches
    patience_increase = 2.
    improvement_threshold = 1
    validation_frequency = min(n_train_batches, patience / 2)
    best_validation_loss = numpy.inf
    test_score = 0.
    start_time = time.clock()
    done_looping = False
    epoch = 0

    plot_valid_error = []
    ww1 = []
    ww2 = []
    ww3 = []
    ww4 = []
    ww5 = []
    wwl = []
    psnrs = []
    best_psnr = []

    while (epoch < training_epochs) and (not done_looping):
        epoch = epoch + 1

        if 1 == 0:  ########################################################################## Switch for on-the-fly training data generation

            if epoch % 50 == 0:
                print '... calling matlab function!'
                call([
                    "/usr/local/MATLAB/R2015a/bin/matlab", "-nodesktop", "-r",
                    'end2end_datagen_256; exit'
                ])
                print '... data regeneration complete, loading new data'
                datasets = load_data('dataset/llnet_17x17_OTF.mat')
                train_set = datasets[0]
                valid_set = datasets[1]
                test_set = datasets[2]
                train_set_x, train_set_y = datasets[0]
                valid_set_x, valid_set_y = datasets[1]
                test_set_x, test_set_y = datasets[2]
                datasets = []

                reconstructed = theano.function([],
                                                sda.logLayer.y_pred,
                                                givens={sda.x: test_set_x},
                                                on_unused_input='warn')

                print '... plotting clean images'
                image = PIL.Image.fromarray(
                    tile_raster_images(X=test_set_y.get_value(),
                                       img_shape=patch_size,
                                       tile_shape=(50, 40),
                                       tile_spacing=(0, 0),
                                       scale_rows_to_unit_interval=False))
                image.save('outputs/LLnet_clean.png')

                print '... plotting noisy images'
                image = PIL.Image.fromarray(
                    tile_raster_images(X=test_set_x.get_value(),
                                       img_shape=patch_size,
                                       tile_shape=(50, 40),
                                       tile_spacing=(0, 0),
                                       scale_rows_to_unit_interval=False))
                image.save('outputs/LLnet_noisy.png')

        if 1 == 1:  ########################################################################## Switch for training rate schedule change

            if epoch % 200 == 0:
                tempval = finetune_lr * 0.1
                print '... switching learning rate to %.4f, recompiling function' % (
                    tempval)
                train_fn, validate_model, test_model = sda.build_finetune_functions(
                    train_set=train_set,
                    valid_set=valid_set,
                    test_set=test_set,
                    batch_size=batch_size,
                    learning_rate=tempval)

        for minibatch_index in xrange(n_train_batches):
            minibatch_avg_cost = train_fn(minibatch_index)
            if (minibatch_index % (n_train_batches / 100 + 1) == 0):
                print '    ... FT E%i, %i/%i/%i, aCost: %.4f' % (
                    epoch, minibatch_index, n_train_batches, hp_batchsize,
                    minibatch_avg_cost)
            iter = (epoch - 1) * n_train_batches + minibatch_index

            if (iter + 1) % validation_frequency == 0:
                validation_losses = validate_model()
                this_validation_loss = numpy.mean(validation_losses)
                print(
                    'epoch %i, minibatch %i/%i, validation loss %f (best: %f)'
                    % (epoch, minibatch_index + 1, n_train_batches,
                       this_validation_loss, best_validation_loss))

                plot_valid_error.append(this_validation_loss)

                # Training monitoring tools -----------------------------------------

                ww1.append(w1())
                ww2.append(w2())
                ww3.append(w3())
                ww4.append(w4())
                ww5.append(w5())
                wwl.append(wl())

                psnr = 10 * numpy.log10(255**2 / numpy.mean(
                    numpy.sqrt(
                        numpy.sum(
                            ((test_set_y.get_value() - reconstructed()) * 255)
                            **2,
                            axis=1,
                            keepdims=True))))
                psnrs.append(psnr)

                if psnr >= numpy.max(psnrs):
                    print '     saving trained model based on highest psnr...'
                    f = file('outputs/model_bestpsnr.obj', 'wb')
                    cPickle.dump(sda, f, protocol=cPickle.HIGHEST_PROTOCOL)
                    f.close()
                    print '     plotting reconstructed images based on highest psnr...'
                    image = PIL.Image.fromarray(
                        tile_raster_images(X=reconstructed(),
                                           img_shape=patch_size,
                                           tile_shape=(50, 40),
                                           tile_spacing=(0, 0),
                                           scale_rows_to_unit_interval=False))
                    image.save('outputs/LLnet_reconstructed_bestpsnr.png')

                plt.clf()
                plt.suptitle('Epoch %d' % (epoch))
                plt.subplot(121)
                plt.plot(plot_valid_error, '-xb')
                plt.title('Validation Error, best %.4f' %
                          (numpy.min(plot_valid_error)))
                plt.subplot(122)
                plt.plot(psnrs, '-xb')
                plt.title('PSNR, best %.4f dB' % (numpy.max(psnrs)))
                if len(psnrs) > 2:
                    plt.xlabel('Rate: %.4f dB/step' % (psnrs[-1] - psnrs[-2]))
                plt.savefig('outputs/validation_error.png')

                plt.clf()
                plt.suptitle('Weight Norms, epoch %d' % (epoch))
                plt.subplot(231)
                plt.plot(ww1, '-xr')
                plt.axis('tight')
                plt.title('Layer1')
                plt.subplot(232)
                plt.plot(ww2, '-xc')
                plt.axis('tight')
                plt.title('Layer2')
                plt.subplot(233)
                plt.plot(ww3, '-xy')
                plt.axis('tight')
                plt.title('Layer3')
                plt.subplot(234)
                plt.plot(ww4, '-xg')
                plt.axis('tight')
                plt.title('Layer4')
                plt.subplot(235)
                plt.plot(ww5, '-xb')
                plt.axis('tight')
                plt.title('Layer5')
                plt.subplot(236)
                plt.plot(wwl, '-xm')
                plt.axis('tight')
                plt.title('Sigmoid Layer')
                plt.savefig('outputs/weightnorms.png')

                # Training monitoring tools -----------------------------------------

                if this_validation_loss < best_validation_loss:
                    if (this_validation_loss <
                            best_validation_loss * improvement_threshold):
                        patience = max(patience, iter * patience_increase)

                    best_validation_loss = this_validation_loss
                    best_iter = iter
                    test_losses = test_model()
                    test_score = numpy.mean(test_losses)
                    print(('     epoch %i, minibatch %i/%i, test loss of '
                           'best model %f') % (epoch, minibatch_index + 1,
                                               n_train_batches, test_score))

                    print '     saving trained model based on lowest validation error...'
                    f = file('outputs/model.obj', 'wb')
                    cPickle.dump(sda, f, protocol=cPickle.HIGHEST_PROTOCOL)
                    f.close()

                    print '     plotting reconstructed images...'
                    image = PIL.Image.fromarray(
                        tile_raster_images(X=reconstructed(),
                                           img_shape=patch_size,
                                           tile_shape=(50, 40),
                                           tile_spacing=(0, 0),
                                           scale_rows_to_unit_interval=False))
                    image.save('outputs/LLnet_reconstructed.png')

                    print '     plotting complete. Training next epoch...'

            if patience <= iter:
                done_looping = True
                break

    end_time = time.clock()
    print(('Optimization complete with best validation loss of %f, '
           'on iteration %i, '
           'with test performance %f') %
          (best_validation_loss, best_iter + 1, test_score))
    print >> sys.stderr, ('The training code for file ' +
                          os.path.split(__file__)[1] + ' ran for %.2fm' %
                          ((end_time - start_time) / 60.))