Пример #1
0
def fit(data, filter_sizes, out_channels, strides, paddings,
        smooth_weights, sparse_weights, readout_sparse_weight,
        learning_rate=0.001, max_iter=10000, val_steps=50,
        early_stopping_steps=10, lr_decay_steps=3, lr_decay_factor=3):
    '''Fit CNN model.

    Builds a ConvNet on ``data`` and trains it in ``lr_decay_steps`` rounds.
    Between consecutive rounds the learning rate is divided by
    ``lr_decay_factor``. Progress is printed for every item yielded by
    ``ConvNet.train``.

    Parameters:
        data:                  Dataset object (see load_data())
        filter_sizes:          Filter sizes (list containing one number per conv layer)
        out_channels:          Number of output channels (list; one number per conv layer)
        strides:               Strides (list; one number per conv layer)
        paddings:              Paddings (list; one number per conv layer; VALID|SAME)
        smooth_weights:        Weights for smoothness regularizer (list; one number per conv layer)
        sparse_weights:        Weights for group sparsity regularizer (list; one number per conv layer)
        readout_sparse_weight: Sparsity of readout weights (scalar)
        learning_rate:         Initial learning rate (default: 0.001)
        max_iter:              Max. number of iterations per round (default: 10000)
        val_steps:             Validation interval (number of iterations; default: 50)
        early_stopping_steps:  Tolerance for early stopping. Will stop optimizing
            after this number of validation steps without decrease of loss.
        lr_decay_steps:        Number of training rounds (default: 3)
        lr_decay_factor:       Divisor applied to the learning rate between
            rounds (default: 3)

    Output:
        cnn:                   A fitted ConvNet object
    '''
    cnn = ConvNet(data, log_dir='cnn', log_hash='manual')
    cnn.build(filter_sizes=filter_sizes,
              out_channels=out_channels,
              strides=strides,
              paddings=paddings,
              smooth_weights=smooth_weights,
              sparse_weights=sparse_weights,
              readout_sparse_weight=readout_sparse_weight)
    for round_idx in range(lr_decay_steps):
        if round_idx > 0:
            # Decay only *between* rounds: reducing (and announcing it) after
            # the final round would be misleading, since no training follows.
            learning_rate /= lr_decay_factor
            print('Reducing learning rate to %f' % learning_rate)
        # cnn.train is iterated below; presumably it yields once per
        # validation step (every `val_steps` iterations) -- TODO confirm.
        training = cnn.train(max_iter=max_iter,
                             val_steps=val_steps,
                             early_stopping_steps=early_stopping_steps,
                             learning_rate=learning_rate)
        for i, (logl, readout_sparse, conv_sparse, smooth, total_loss, pred) in training:
            # Var(y): variance of predictions along axis 0, averaged over the
            # remaining axes.
            print('Step %d | Loss: %.2f | Poisson: %.2f | L1 readout: %.2f | Sparse: %.2f | Smooth: %.2f | Var(y): %.3f' %
                  (i, total_loss, logl, readout_sparse, conv_sparse, smooth, np.mean(np.var(pred, axis=0))))
    print('Done fitting')
    return cnn