Example #1
def inference(dataset, segm_net, learn_step=0.005, num_iter=500,
              dae_dict_updates={}, training_dict={}, data_augmentation=False,
              which_set='test', ae_h=False, full_im_ft=False,
              savepath=None, loadpath=None, test_from_0_255=False):

    #
    # Update DAE parameters
    #
    dae_dict = {'kind': 'fcn8',
                'dropout': 0.0,
                'skip': True,
                'unpool_type': 'standard',
                'n_filters': 64,
                'conv_before_pool': 1,
                'additional_pool': 0,
                'concat_h': ['input'],
                'noise': 0.0,
                'from_gt': True,
                'temperature': 1.0,
                'layer': 'probs_dimshuffle',
                'exp_name': '',
                'bn': 0}

    dae_dict.update(dae_dict_updates)

    #
    # Prepare load/save directories
    #
    exp_name = build_experiment_name(segm_net, data_aug=data_augmentation, ae_h=ae_h,
                                     **dict(dae_dict.items() + training_dict.items()))
    # exp_name += '_ftsmall' if full_im_ft else ''

    if savepath is None:
        raise ValueError('A saving directory must be specified')

    savepath = os.path.join(savepath, dataset, exp_name, 'img_plots',
                            which_set)
    loadpath = os.path.join(loadpath, dataset, exp_name)
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    else:
        print('\033[93m The following folder already exists {}. '
              'It will be overwritten in a few seconds...\033[0m'.format(
                  savepath))

    print('Saving directory : ' + savepath)
    with open(os.path.join(savepath, "config.txt"), "w") as f:
        for key, value in locals().items():
            f.write('{} = {}\n'.format(key, value))

    #
    # Define symbolic variables
    #
    input_x_var = T.tensor4('input_x_var')  # tensor for input image batch
    input_concat_h_vars = [T.tensor4() for _ in range(len(dae_dict['concat_h']))]  # distinct tensors for hidden repr batches (input dae)
    y_hat_var = T.tensor4('pred_y_var')
    target_var = T.tensor4('target_var')  # tensor for target batch

    #
    # Build dataset iterator
    #
    data_iter = load_data(dataset, {}, one_hot=True, batch_size=[10, 5, 10],
                          return_0_255=test_from_0_255, which_set=which_set)

    colors = data_iter.cmap
    n_batches_test = data_iter.nbatches
    n_classes = data_iter.non_void_nclasses
    void_labels = data_iter.void_labels
    nb_in_channels = data_iter.data_shape[0]
    void = n_classes if any(void_labels) else n_classes+1

    #
    # Build networks
    #

    # Build segmentation network
    print 'Building segmentation network'
    if segm_net == 'fcn8':
        fcn = buildFCN8(nb_in_channels, input_var=input_x_var,
                        n_classes=n_classes, void_labels=void_labels,
                        path_weights=WEIGHTS_PATH+dataset+'/fcn8_model.npz',
                        trainable=False, load_weights=True,
                        layer=dae_dict['concat_h']+[dae_dict['layer']])
        padding = 100
    elif segm_net == 'densenet':
        fcn = build_fcdensenet(input_x_var, nb_in_channels=nb_in_channels,
                               n_classes=n_classes, layer=dae_dict['concat_h'])
        padding = 0
    elif segm_net == 'fcn_fcresnet':
        raise NotImplementedError
    else:
        raise ValueError

    # Build DAE with pre-trained weights
    print 'Building DAE network'
    if dae_dict['kind'] == 'standard':
        nb_features_to_concat = fcn[0].output_shape[1]
        dae = buildDAE(input_concat_h_vars, y_hat_var, n_classes,
                       nb_features_to_concat=nb_features_to_concat,
                       padding=padding, trainable=True,
                       void_labels=void_labels, load_weights=True,
                       path_weights=loadpath, model_name='dae_model_best.npz',
                       out_nonlin=softmax, concat_h=dae_dict['concat_h'],
                       noise=dae_dict['noise'], n_filters=dae_dict['n_filters'],
                       conv_before_pool=dae_dict['conv_before_pool'],
                       additional_pool=dae_dict['additional_pool'],
                       dropout=dae_dict['dropout'], skip=dae_dict['skip'],
                       unpool_type=dae_dict['unpool_type'],
                       bn=dae_dict['bn'])
    elif dae_dict['kind'] == 'fcn8':
        dae = buildFCN8_DAE(input_concat_h_vars, y_hat_var, n_classes,
                            nb_in_channels=n_classes, path_weights=loadpath,
                            model_name='dae_model_best.npz', trainable=True,
                            load_weights=True, pretrained=True, pascal=False,
                            concat_h=dae_dict['concat_h'], noise=dae_dict['noise'])
    elif dae_dict['kind'] == 'contextmod':
        dae = buildDAE_contextmod(input_concat_h_vars, y_hat_var, n_classes,
                                  path_weights=loadpath,
                                  model_name='dae_model_best.npz',
                                  trainable=True, load_weights=True,
                                  out_nonlin=softmax, noise=dae_dict['noise'],
                                  concat_h=dae_dict['concat_h'])
    else:
        raise ValueError('Unknown dae kind')

    #
    # Define and compile theano functions
    #
    print "Defining and compiling theano functions"

    # predictions and theano functions
    pred_fcn = lasagne.layers.get_output(fcn, deterministic=True, batch_norm_use_averages=False)
    pred_fcn_fn = theano.function([input_x_var], pred_fcn)
    pred_dae = lasagne.layers.get_output(dae, deterministic=True)
    pred_dae_fn = theano.function(input_concat_h_vars+[y_hat_var], pred_dae)

    # Reshape iterative inference output to b01,c
    y_hat_dimshuffle = y_hat_var.dimshuffle((0, 2, 3, 1))
    sh = y_hat_dimshuffle.shape
    y_hat_2D = y_hat_dimshuffle.reshape((T.prod(sh[:3]), sh[3]))

    # Reshape target to b01,c
    target_var_dimshuffle = target_var.dimshuffle((0, 2, 3, 1))
    sh2 = target_var_dimshuffle.shape
    target_var_2D = target_var_dimshuffle.reshape((T.prod(sh2[:3]), sh2[3]))

    # derivative of energy wrt input and theano function
    de = - (pred_dae - y_hat_var)
    de_fn = theano.function(input_concat_h_vars+[y_hat_var], de)

    # metrics and theano functions
    test_loss = squared_error(y_hat_var, target_var, void)
    test_acc = accuracy(y_hat_2D, target_var_2D, void_labels, one_hot=True)
    test_jacc = jaccard(y_hat_2D, target_var_2D, n_classes, one_hot=True)
    val_fn = theano.function([y_hat_var, target_var], [test_acc, test_jacc, test_loss])

    #
    # Infer
    #
    print 'Start inferring'
    rec_tot = 0
    rec_tot_fcn = 0
    rec_tot_dae = 0
    acc_tot = 0
    acc_tot_fcn = 0
    jacc_tot = 0
    jacc_tot_fcn = 0
    acc_tot_dae = 0
    jacc_tot_dae = 0
    print 'Inference step: ' + str(learn_step) + ', num iter: ' + str(num_iter)
    for i in range(n_batches_test):
        info_str = "Batch %d out of %d" % (i+1, n_batches_test)
        print '-'*30
        print '*'*5 + info_str + '*'*5
        print '-'*30

        # Get minibatch
        X_test_batch, L_test_batch = data_iter.next()
        L_test_batch = L_test_batch.astype(_FLOATX)

        # Compute fcn prediction y and h
        pred_test_batch = pred_fcn_fn(X_test_batch)
        Y_test_batch = pred_test_batch[-1]
        H_test_batch = pred_test_batch[:-1]

        # Compute metrics before iterative inference
        acc_fcn, jacc_fcn, rec_fcn = val_fn(Y_test_batch, L_test_batch)
        acc_tot_fcn += acc_fcn
        jacc_tot_fcn += jacc_fcn
        rec_tot_fcn += rec_fcn
        Y_test_batch_fcn = Y_test_batch
        print_results('>>>>> FCN:', rec_tot_fcn, acc_tot_fcn, jacc_tot_fcn, i+1)

        # Compute dae output and metrics after dae
        Y_test_batch_dae = pred_dae_fn(*(H_test_batch+[Y_test_batch]))
        acc_dae, jacc_dae, rec_dae = val_fn(Y_test_batch_dae, L_test_batch)
        acc_tot_dae += acc_dae
        jacc_tot_dae += jacc_dae
        rec_tot_dae += rec_dae
        print_results('>>>>> FCN+DAE:', rec_tot_dae, acc_tot_dae, jacc_tot_dae, i+1)

        Y_test_batch_ii = []
        for im in range(X_test_batch.shape[0]):
            print('-----------------------')
            h_im = [el[np.newaxis, im] for el in H_test_batch]
            y_im = Y_test_batch[np.newaxis, im]
            t_im = L_test_batch[np.newaxis, im]

            # Iterative inference
            for it in range(num_iter):
                # Compute gradient
                grad = de_fn(*(h_im+[y_im]))

                # Update prediction
                y_im = y_im - learn_step * grad

                # Clip prediction
                y_im = np.clip(y_im, 0.0, 1.0)

                norm = np.linalg.norm(grad, axis=1).mean()
                if norm < _EPSILON:
                    break

                acc_iter, jacc_iter, rec_iter = val_fn(y_im, t_im)
                print rec_iter, acc_iter, np.nanmean(jacc_iter[0, :]/jacc_iter[1, :])

            Y_test_batch_ii += [y_im]

        Y_test_batch_ii = np.concatenate(Y_test_batch_ii, axis=0)

        # Compute metrics
        acc, jacc, rec = val_fn(Y_test_batch_ii, L_test_batch)
        acc_tot += acc
        jacc_tot += jacc
        rec_tot += rec
        print_results('>>>>> ITERATIVE INFERENCE:', rec_tot, acc_tot, jacc_tot, i+1)

        np.savez(os.path.join(savepath, 'batch' + str(i) + '.npz'),
                 X=X_test_batch, L=L_test_batch,
                 Y_ii=Y_test_batch_ii, Y_fcn=Y_test_batch_fcn)
        # Save images
        # save_img(X_test_batch,
        #         L_test_batch,
        #         Y_test_batch_ii,
        #         Y_test_batch_fcn,
        #         savepath, 'batch' + str(i),
        #         void_labels, colors)

    # Print summary of how things went
    print('-------------------------------------------------------------------')
    print('------------------------------SUMMARY------------------------------')
    print('-------------------------------------------------------------------')
    print_results('>>>>> FCN:', rec_tot_fcn, acc_tot_fcn, jacc_tot_fcn, i+1)
    print_results('>>>>> FCN+DAE:', rec_tot_dae, acc_tot_dae, jacc_tot_dae, i+1)
    print_results('>>>>> ITERATIVE INFERENCE:', rec_tot, acc_tot, jacc_tot, i+1)

    # Compute per class jaccard
    jacc_perclass_fcn = jacc_tot_fcn[0, :]/jacc_tot_fcn[1, :]
    jacc_perclass = jacc_tot[0, :]/jacc_tot[1, :]

    print ">>>>> Per class jaccard:"
    labs = data_iter.mask_labels

    for i in range(len(labs)-len(void_labels)):
        class_str = '    ' + labs[i] + ' : fcn ->  %f, ii ->  %f'
        class_str = class_str % (jacc_perclass_fcn[i], jacc_perclass[i])
        print class_str

    # Move segmentations
    if savepath != loadpath:
        print('Copying images to {}'.format(loadpath))
        copy_tree(savepath, os.path.join(loadpath, 'img_plots', which_set))
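
# A minimal, self-contained sketch (not from the original example) of the
# iterative-inference update above: gradient descent on the DAE
# reconstruction energy, y <- clip(y - learn_step * dE/dy, 0, 1) with
# dE/dy = -(dae(y) - y), plus the same gradient-norm stopping rule.
# toy_dae is a hypothetical stand-in for pred_dae_fn; all names and values
# here are assumptions.
import numpy as np

def toy_dae(y, target):
    # hypothetical denoiser: pulls the prediction halfway towards a target
    return y + 0.5 * (target - y)

def iterative_inference_sketch(y, target, learn_step=0.1, num_iter=100,
                               eps=1e-6):
    for _ in range(num_iter):
        grad = -(toy_dae(y, target) - y)  # same sign convention as de above
        y = np.clip(y - learn_step * grad, 0.0, 1.0)
        if np.linalg.norm(grad, axis=1).mean() < eps:
            break
    return y

y0 = np.random.rand(1, 3, 4, 4).astype('float32')  # (b, c, 0, 1) probabilities
t0 = np.zeros_like(y0)
t0[:, 0] = 1.0  # toy one-hot target
print(np.abs(iterative_inference_sketch(y0, t0) - t0).max())  # tends to 0
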
Example #2
def test(dataset, segm_net, which_set='val', data_aug=False,
         savepath=None, loadpath=None, test_from_0_255=False):

    #
    # Define symbolic variables
    #
    input_x_var = T.tensor4('input_var')
    target_var = T.tensor4('target_var')

    #
    # Build dataset iterator
    #
    data_iter = load_data(dataset, {}, one_hot=True, batch_size=[10, 10, 10],
                          return_0_255=test_from_0_255, which_set=which_set)

    colors = data_iter.cmap
    n_batches_test = data_iter.nbatches
    n_classes = data_iter.non_void_nclasses
    void_labels = data_iter.void_labels
    nb_in_channels = data_iter.data_shape[0]
    void = n_classes if any(void_labels) else n_classes+1

    #
    # Build segmentation network
    #
    print 'Building segmentation network'
    if segm_net == 'fcn8':
        fcn = buildFCN8(nb_in_channels, input_var=input_x_var,
                        n_classes=n_classes, void_labels=void_labels,
                        path_weights=WEIGHTS_PATH+dataset+'/fcn8_model.npz',
                        trainable=False, load_weights=True,
                        layer=['probs_dimshuffle'])
    elif segm_net == 'densenet':
        fcn = build_fcdensenet(input_x_var, nb_in_channels=nb_in_channels,
                               n_classes=n_classes, layer=[], output_d='4d')
    elif segm_net == 'fcn_fcresnet':
        raise NotImplementedError
    else:
        raise ValueError

    #
    # Define and compile theano functions
    #
    print "Defining and compiling test functions"
    test_prediction = lasagne.layers.get_output(fcn, deterministic=True, batch_norm_use_averages=False)[0]

    # Reshape fcn prediction to b01,c
    test_prediction_dimshuffle = test_prediction.dimshuffle((0, 2, 3, 1))
    sh = test_prediction_dimshuffle.shape
    test_prediction_2D = test_prediction_dimshuffle.reshape((T.prod(sh[:3]), sh[3]))

    # Reshape target to b01,c
    target_var_dimshuffle = target_var.dimshuffle((0, 2, 3, 1))
    sh2 = target_var_dimshuffle.shape
    target_var_2D = target_var_dimshuffle.reshape((T.prod(sh2[:3]), sh2[3]))

    test_loss = squared_error(test_prediction, target_var, void)
    test_acc = accuracy(test_prediction_2D, target_var_2D, void_labels, one_hot=True)
    test_jacc = jaccard(test_prediction_2D, target_var_2D, n_classes, one_hot=True)

    val_fn = theano.function([input_x_var, target_var], [test_acc, test_jacc, test_loss])
    pred_fcn_fn = theano.function([input_x_var], test_prediction)

    # Iterate over test and compute metrics
    print "Testing"
    acc_test_tot = 0
    mse_test_tot = 0
    jacc_num_test_tot = np.zeros((1, n_classes))
    jacc_denom_test_tot = np.zeros((1, n_classes))
    for i in range(n_batches_test):
        # Get minibatch
        X_test_batch, L_test_batch = data_iter.next()
        Y_test_batch = pred_fcn_fn(X_test_batch)
        L_test_batch = L_test_batch.astype(_FLOATX)
        # L_test_batch = np.reshape(L_test_batch, np.prod(L_test_batch.shape))

        # Test step
        acc_test, jacc_test, mse_test = val_fn(X_test_batch, L_test_batch)
        jacc_num_test, jacc_denom_test = jacc_test

        acc_test_tot += acc_test
        mse_test_tot += mse_test
        jacc_num_test_tot += jacc_num_test
        jacc_denom_test_tot += jacc_denom_test

        # Save images
        # save_img(X_test_batch, L_test_batch, Y_test_batch,
        #          savepath, n_classes, 'batch' + str(i),
        #          void_labels, colors)

    acc_test = acc_test_tot/n_batches_test
    mse_test = mse_test_tot/n_batches_test
    jacc_per_class = jacc_num_test_tot / jacc_denom_test_tot
    jacc_per_class = jacc_per_class[0]
    jacc_test = np.mean(jacc_per_class)

    out_str = "FINAL MODEL: acc test %f, jacc test %f, mse test %f"
    out_str = out_str % (acc_test, jacc_test, mse_test)
    print out_str

    print ">>> Per class jaccard:"
    labs = data_iter.mask_labels

    for i in range(len(labs)-len(void_labels)):
        class_str = '    ' + labs[i] + ' : %f'
        class_str = class_str % (jacc_per_class[i])
        print class_str
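
# A small illustration (not from the original example) of the "b01,c" reshape
# that accuracy/jaccard expect above: 0 and 1 denote the two spatial axes, so
# the result has one row per pixel and one column per class. Shapes here are
# arbitrary assumptions; transpose(0, 2, 3, 1) plays the role of dimshuffle.
import numpy as np

pred = np.random.rand(2, 5, 4, 4)  # (b, c, 0, 1)
pred_2d = pred.transpose(0, 2, 3, 1).reshape(-1, pred.shape[1])  # (b01, c)
assert pred_2d.shape == (2 * 4 * 4, 5)
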
Example #3
def inference(dataset,
              segm_net,
              which_set='val',
              num_iter=5,
              Bilateral=True,
              savepath=None,
              loadpath=None,
              test_from_0_255=False):

    #
    # Define symbolic variables
    #
    input_x_var = T.tensor4('input_x_var')
    y_hat_var = T.tensor4('pred_y_var')
    target_var = T.tensor4('target_var')

    #
    # Build dataset iterator
    #
    data_iter = load_data(dataset, {},
                          one_hot=True,
                          batch_size=[10, 10, 10],
                          return_0_255=test_from_0_255,
                          which_set=which_set)

    colors = data_iter.cmap
    n_batches_test = data_iter.nbatches
    n_classes = data_iter.non_void_nclasses
    void_labels = data_iter.void_labels

    #
    # Prepare saving directory
    #
    savepath = os.path.join(savepath, dataset, segm_net, 'img_plots', 'crf',
                            str(num_iter), which_set)
    loadpath = os.path.join(loadpath, dataset, segm_net, 'img_plots', 'crf',
                            str(num_iter), which_set)
    if not os.path.exists(savepath):
        os.makedirs(savepath)

    #
    # Build network
    #
    print 'Building segmentation network'
    if segm_net == 'fcn8':
        fcn = buildFCN8(3,
                        input_var=input_x_var,
                        n_classes=n_classes,
                        void_labels=void_labels,
                        path_weights=WEIGHTS_PATH + dataset +
                        '/fcn8_model.npz',
                        trainable=False,
                        load_weights=True,
                        layer=['probs_dimshuffle'])
        padding = 100
    elif segm_net == 'densenet':
        fcn = build_fcdensenet(input_x_var,
                               nb_in_channels=3,
                               n_classes=n_classes,
                               layer=[])
        padding = 0
    elif segm_net == 'fcn_fcresnet':
        raise NotImplementedError
    else:
        raise ValueError

    #
    # Define and compile theano functions
    #
    print "Defining and compiling theano functions"

    # predictions of fcn
    pred_fcn = lasagne.layers.get_output(fcn,
                                         deterministic=True,
                                         batch_norm_use_averages=False)[0]

    # function to compute output of fcn
    pred_fcn_fn = theano.function([input_x_var], pred_fcn)

    # reshape fcn output to b01,c
    y_hat_dimshuffle = y_hat_var.dimshuffle((0, 2, 3, 1))
    sh = y_hat_dimshuffle.shape
    y_hat_2D = y_hat_dimshuffle.reshape((T.prod(sh[:3]), sh[3]))

    # reshape target to b01,c
    target_var_dimshuffle = target_var.dimshuffle((0, 2, 3, 1))
    sh2 = target_var_dimshuffle.shape
    target_var_2D = target_var_dimshuffle.reshape((T.prod(sh2[:3]), sh2[3]))

    # metrics to evaluate iterative inference
    test_acc = accuracy(y_hat_2D, target_var_2D, void_labels, one_hot=True)
    test_jacc = jaccard(y_hat_2D, target_var_2D, n_classes, one_hot=True)

    # functions to compute metrics
    val_fn = theano.function([y_hat_var, target_var], [test_acc, test_jacc])

    #
    # Infer
    #
    print 'Start inferring'
    acc_tot_crf = 0
    acc_tot_fcn = 0
    jacc_tot_crf = 0
    jacc_tot_fcn = 0
    for i in range(n_batches_test):
        info_str = "Batch %d out of %d" % (i + 1, n_batches_test)
        print info_str

        # Get minibatch
        X_test_batch, L_test_batch = data_iter.next()
        L_test_batch = L_test_batch.astype(_FLOATX)

        # Compute fcn prediction
        Y_test_batch = pred_fcn_fn(X_test_batch)

        # Compute metrics before CRF
        acc_fcn, jacc_fcn = val_fn(Y_test_batch, L_test_batch)
        acc_tot_fcn += acc_fcn
        jacc_tot_fcn += jacc_fcn
        Y_test_batch_fcn = Y_test_batch
        Y_test_batch_crf = []

        for im in range(X_test_batch.shape[0]):
            # CRF
            d = dcrf.DenseCRF2D(Y_test_batch.shape[3], Y_test_batch.shape[2],
                                n_classes)
            sm = Y_test_batch[im, 0:n_classes, :, :]
            sm = sm.reshape((n_classes, -1))
            img = X_test_batch[im]
            img = np.transpose(img, (1, 2, 0))
            img = (255 * img).astype('uint8')
            img2 = np.asarray(img, order='C')

            # set unary potentials (neg log probability)
            U = unary_from_softmax(sm)
            d.setUnaryEnergy(U)

            # set pairwise potentials

            # This adds the color-independent term, features are the
            # locations only. Smoothness kernel.
            # sxy: gaussian x, y std
            # compat: ways to weight contributions, a number for potts compatibility,
            #     vector for diagonal compatibility, an array for matrix compatibility
            # kernel: kernel used, CONST_KERNEL, FULL_KERNEL, DIAG_KERNEL
            # normalization: NORMALIZE_AFTER, NORMALIZE_BEFORE,
            #     NO_NORMALIZATION, NORMALIZE_SYMMETRIC
            d.addPairwiseGaussian(sxy=(3, 3),
                                  compat=3,
                                  kernel=dcrf.DIAG_KERNEL,
                                  normalization=dcrf.NORMALIZE_SYMMETRIC)

            if Bilateral:
                # Appearance kernel. This adds the color-dependent term,
                # i.e. features are (x,y,r,g,b).
                # rgbim must be a C-contiguous uint8 RGB image array,
                # e.g. img2.dtype == np.uint8 and img2.shape == (640, 480, 3).
                # To set sxy and srgb, perform a grid search on the validation set.
                d.addPairwiseBilateral(sxy=(3, 3),
                                       srgb=(13, 13, 13),
                                       rgbim=img2,
                                       compat=10,
                                       kernel=dcrf.DIAG_KERNEL,
                                       normalization=dcrf.NORMALIZE_SYMMETRIC)

            # inference
            Q = d.inference(num_iter)
            Q = np.reshape(
                Q, (n_classes, Y_test_batch.shape[2], Y_test_batch.shape[3]))
            Y_test_batch_crf += [np.expand_dims(Q, axis=0)]

        # Concatenate per-image CRF outputs back into a batch
        Y_test_batch = np.concatenate(Y_test_batch_crf, axis=0)

        # Compute metrics after CRF
        acc_crf, jacc_crf = val_fn(Y_test_batch, L_test_batch)
        acc_tot_crf += acc_crf
        jacc_tot_crf += jacc_crf

        # save_img(X_test_batch.astype(_FLOATX), L_test_batch, Y_test_batch,
        #         Y_test_batch_fcn, savepath, 'batch' + str(i), void_labels, colors)
        np.savez(os.path.join(savepath, 'batch' + str(i) + '.npz'),
                 X=X_test_batch,
                 L=L_test_batch,
                 Y_crf=Y_test_batch,
                 Y_fcn=Y_test_batch_fcn)

    acc_test_crf = acc_tot_crf / n_batches_test
    jacc_test_perclass_crf = jacc_tot_crf[0, :] / jacc_tot_crf[1, :]
    jacc_test_crf = np.nanmean(jacc_test_perclass_crf)

    acc_test_fcn = acc_tot_fcn / n_batches_test
    jacc_test_perclass_fcn = jacc_tot_fcn[0, :] / jacc_tot_fcn[1, :]
    jacc_test_fcn = np.nanmean(jacc_test_perclass_fcn)

    out_str = "TEST: acc crf %f, jacc crf %f, acc fcn %f, jacc fcn %f"
    out_str = out_str % (acc_test_crf, jacc_test_crf, acc_test_fcn,
                         jacc_test_fcn)

    print ">>>>> Per class jaccard:"
    labs = data_iter.mask_labels

    for i in range(len(labs) - len(void_labels)):
        class_str = '    ' + labs[i] + ' : fcn ->  %f, crf ->  %f'
        class_str = class_str % (jacc_test_perclass_fcn[i],
                                 jacc_test_perclass_crf[i])
        print class_str

    print out_str

    # Move segmentations
    if savepath != loadpath:
        print('Copying images to {}'.format(loadpath))
        copy_tree(savepath, loadpath)

    return jacc_test_perclass_crf
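
# A self-contained sketch (not from the original example) of the dense-CRF
# post-processing performed in the loop above, run on random data. Shapes and
# kernel parameters are assumptions; the pydensecrf calls mirror the loop body.
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

h, w, n_cl = 48, 64, 5
sm = np.random.dirichlet(np.ones(n_cl), size=h * w).T.astype('float32')  # (n_cl, h*w) probs
img = np.ascontiguousarray(
    np.random.randint(0, 255, (h, w, 3)).astype('uint8'))  # rgbim: C-contiguous uint8

d = dcrf.DenseCRF2D(w, h, n_cl)  # note: width first, then height
d.setUnaryEnergy(unary_from_softmax(sm))  # unary = negative log probabilities
d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL,
                      normalization=dcrf.NORMALIZE_SYMMETRIC)  # smoothness
d.addPairwiseBilateral(sxy=(3, 3), srgb=(13, 13, 13), rgbim=img, compat=10,
                       kernel=dcrf.DIAG_KERNEL,
                       normalization=dcrf.NORMALIZE_SYMMETRIC)  # appearance
Q = np.array(d.inference(5)).reshape(n_cl, h, w)  # refined per-class marginals
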
Example #4
def train(dataset, segm_net, learning_rate=0.005, lr_anneal=1.0,
          weight_decay=1e-4, num_epochs=500, max_patience=100,
          optimizer='rmsprop', training_loss=['squared_error'],
          batch_size=[10, 1, 1], ae_h=False,
          dae_dict_updates={}, data_augmentation={},
          savepath=None, loadpath=None, resume=False, train_from_0_255=False,
          lmb=1, full_im_ft=False):

    #
    # Update DAE parameters
    #
    dae_dict = {'kind': 'fcn8',
                'dropout': 0.0,
                'skip': True,
                'unpool_type': 'standard',
                'n_filters': 64,
                'conv_before_pool': 1,
                'additional_pool': 0,
                'concat_h': ['input'],
                'noise': 0.0,
                'from_gt': True,
                'temperature': 1.0,
                'path_weights': '',
                'layer': 'probs_dimshuffle',
                'exp_name': '',
                'bn': 0}

    dae_dict.update(dae_dict_updates)

    #
    # Prepare load/save directories
    #
    exp_name = build_experiment_name(segm_net,
                                     training_loss=training_loss,
                                     data_aug=bool(data_augmentation),
                                     learning_rate=learning_rate,
                                     lr_anneal=lr_anneal,
                                     weight_decay=weight_decay,
                                     optimizer=optimizer, ae_h=ae_h,
                                     **dae_dict)
    if savepath is None:
        raise ValueError('A saving directory must be specified')

    loadpath_init = os.path.join(loadpath, dataset, exp_name)
    exp_name += '_ft' if full_im_ft else ''
    loadpath = os.path.join(loadpath, dataset, exp_name)
    savepath = os.path.join(savepath, dataset, exp_name)
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    else:
        print('\033[93m The following folder already exists {}. '
              'It will be overwritten in a few seconds...\033[0m'.format(
                  savepath))

    print('Saving directory : ' + savepath)
    with open(os.path.join(savepath, "config.txt"), "w") as f:
        for key, value in locals().items():
            f.write('{} = {}\n'.format(key, value))

    #
    # Define symbolic variables
    #
    input_x_var = T.tensor4('input_x_var')  # tensor for input image batch
    input_mask_var = T.tensor4('input_mask_var')  # tensor for segmentation batch (input dae)
    input_concat_h_vars = [T.tensor4() for _ in range(len(dae_dict['concat_h']))]  # distinct tensors for hidden repr batches (input dae)
    target_var = T.tensor4('target_var')  # tensor for target batch
    # learning_rate = learning_rate*0.1 if full_im_ft else learning_rate
    # learning_rate = 0.01
    print 'Learning rate: ' + str(learning_rate)
    lr = theano.shared(np.float32(learning_rate), 'learning_rate')

    #
    # Build dataset iterator
    #
    train_iter, val_iter, _ = load_data(dataset,
                                        data_augmentation,
                                        one_hot=True,
                                        batch_size=batch_size,
                                        return_0_255=train_from_0_255,
                                        )

    n_batches_train = train_iter.nbatches
    n_batches_val = val_iter.nbatches
    n_classes = train_iter.non_void_nclasses
    void_labels = train_iter.void_labels
    nb_in_channels = train_iter.data_shape[0]
    void = n_classes if any(void_labels) else n_classes+1

    #
    # Build networks
    #

    # Check that model and dataset get along
    print 'Checking options'
    assert dataset == 'camvid' and segm_net in ('fcn8', 'densenet')
    assert not full_im_ft or data_augmentation.get('crop_size') is None

    # Build segmentation network
    print 'Building segmentation network'
    if segm_net == 'fcn8':
        layer_out = copy.copy(dae_dict['concat_h'])
        layer_out += [copy.copy(dae_dict['layer'])] if not dae_dict['from_gt'] else []
        fcn = buildFCN8(nb_in_channels, input_x_var, n_classes=n_classes,
                        void_labels=void_labels,
                        path_weights=WEIGHTS_PATH+dataset+'/fcn8_model.npz',
                        load_weights=True,
                        layer=layer_out)
        padding = 100
    elif segm_net == 'densenet':
        fcn = build_fcdensenet(input_x_var, nb_in_channels=nb_in_channels,
                               n_classes=n_classes,
                               layer=dae_dict['concat_h'],
                               from_gt=dae_dict['from_gt'])
        padding = 0
    elif segm_net == 'fcn_fcresnet':
        raise NotImplementedError
    else:
        raise ValueError

    # Build DAE network
    print 'Building DAE network'

    if ae_h and dae_dict['kind'] != 'standard':
        raise ValueError('Plug&Play not implemented for ' + dae_dict['kind'])
    if ae_h and 'pool' not in dae_dict['concat_h'][-1]:
        raise ValueError('Plug&Play version needs concat_h to be different from input')
    ae_h = ae_h and 'pool' in dae_dict['concat_h'][-1]

    if dae_dict['kind'] == 'standard':
        nb_features_to_concat = fcn[0].output_shape[1]
        dae = buildDAE(input_concat_h_vars, input_mask_var, n_classes,
                       nb_features_to_concat=nb_features_to_concat, padding=padding,
                       trainable=True,
                       void_labels=void_labels, load_weights=resume or full_im_ft,
                       path_weights=loadpath_init, model_name='dae_model_best.npz',
                       out_nonlin=softmax, concat_h=dae_dict['concat_h'],
                       noise=dae_dict['noise'], n_filters=dae_dict['n_filters'],
                       conv_before_pool=dae_dict['conv_before_pool'],
                       additional_pool=dae_dict['additional_pool'],
                       dropout=dae_dict['dropout'], skip=dae_dict['skip'],
                       unpool_type=dae_dict['unpool_type'],
                       bn=dae_dict['bn'], ae_h=ae_h)
    elif dae_dict['kind'] == 'fcn8':
        dae = buildFCN8_DAE(input_concat_h_vars, input_mask_var, n_classes,
                            nb_in_channels=n_classes, trainable=True,
                            load_weights=resume, pretrained=True, pascal=True,
                            concat_h=dae_dict['concat_h'], noise=dae_dict['noise'],
                            dropout=dae_dict['dropout'],
                            path_weights=os.path.join('/'.join(loadpath_init.split('/')[:-1]),
                            dae_dict['path_weights']),
                            model_name='dae_model_best.npz')
    elif dae_dict['kind'] == 'contextmod':
        dae = buildDAE_contextmod(input_concat_h_vars, input_mask_var, n_classes,
                                  path_weights=loadpath_init,
                                  model_name='dae_model.npz',
                                  trainable=True, load_weights=resume,
                                  out_nonlin=softmax, noise=dae_dict['noise'],
                                  concat_h=dae_dict['concat_h'])
    else:
        raise ValueError('Unknown dae kind')

    #
    # Define and compile theano functions
    #

    # training functions
    print "Defining and compiling training functions"

    # fcn prediction
    fcn_prediction = lasagne.layers.get_output(fcn, deterministic=True, batch_norm_use_averages=False)

    # select prediction layers (pooling and upsampling layers)
    dae_all_lays = lasagne.layers.get_all_layers(dae)
    if dae_dict['kind'] != 'contextmod':
        dae_lays = [l for l in dae_all_lays
                    if isinstance(l, Pool2DLayer) or
                    isinstance(l, CroppingLayer) or
                    isinstance(l, ElemwiseSumLayer) or
                    l == dae_all_lays[-1]]
        # dae_lays = dae_lays[::2]
    else:
        dae_lays = [l for l in dae_all_lays if isinstance(l, DilatedConv2DLayer) or l == dae_all_lays[-1]]

    if ae_h:
        h_ae_idx = [i for i, el in enumerate(dae_lays) if el.name == 'h_to_recon'][0]
        h_hat_idx = [i for i, el in enumerate(dae_lays) if el.name == 'h_hat'][0]

    # predictions
    dae_prediction_all = lasagne.layers.get_output(dae_lays,
                                                   batch_norm_use_averages=False)
    dae_prediction = dae_prediction_all[-1]
    dae_prediction_h = dae_prediction_all[:-1]

    test_dae_prediction_all = lasagne.layers.get_output(dae_lays,
                                                        deterministic=True,
                                                        batch_norm_use_averages=False)
    test_dae_prediction = test_dae_prediction_all[-1]
    test_dae_prediction_h = test_dae_prediction_all[:-1]

    # fetch h and h_hat if needed
    if ae_h:
        h = dae_prediction_all[h_ae_idx]
        h_hat = dae_prediction_all[h_hat_idx]
        h_test = test_dae_prediction_all[h_ae_idx]
        h_hat_test = test_dae_prediction_all[h_hat_idx]

    # loss
    loss = 0
    test_loss = 0

    # Convert DAE prediction to 2D
    dae_prediction_2D = dae_prediction.dimshuffle((0, 2, 3, 1))
    sh = dae_prediction_2D.shape
    dae_prediction_2D = dae_prediction_2D.reshape((T.prod(sh[:3]), sh[3]))

    test_dae_prediction_2D = test_dae_prediction.dimshuffle((0, 2, 3, 1))
    sh = test_dae_prediction_2D.shape
    test_dae_prediction_2D = test_dae_prediction_2D.reshape((T.prod(sh[:3]),
                                                            sh[3]))
    # Convert target to 2D
    target_var_2D = target_var.dimshuffle((0, 2, 3, 1))
    sh = target_var_2D.shape
    target_var_2D = target_var_2D.reshape((T.prod(sh[:3]), sh[3]))

    if 'crossentropy' in training_loss:
        # Compute loss
        loss += crossentropy(dae_prediction_2D, target_var_2D, void_labels,
                             one_hot=True)
        test_loss += crossentropy(test_dae_prediction_2D, target_var_2D,
                                  void_labels, one_hot=True)
    if 'dice' in training_loss:
        loss += dice_loss(dae_prediction, target_var, void_labels)
        test_loss += dice_loss(test_dae_prediction, target_var, void_labels)

    test_mse_loss = squared_error(test_dae_prediction, target_var, void)
    if 'squared_error' in training_loss:
        mse_loss = squared_error(dae_prediction, target_var, void)
        loss += lmb*mse_loss
        test_loss += lmb*test_mse_loss

    # Add intermediate losses
    if 'squared_error_h' in training_loss:
        # extract input layers and create dictionary
        dae_input_lays = [l for l in dae_all_lays if isinstance(l, InputLayer)]
        inputs = {dae_input_lays[0]: target_var[:, :void, :, :],
                  dae_input_lays[-1]: target_var[:, :void, :, :]}
        for idx, val in enumerate(input_concat_h_vars):
            inputs[dae_input_lays[idx+1]] = val

        test_dae_prediction_all_gt = lasagne.layers.get_output(dae_lays,
                                                               inputs=inputs,
                                                               deterministic=True,
                                                               batch_norm_use_averages=False)
        test_dae_prediction_h_gt = test_dae_prediction_all_gt[:-1]

        loss += squared_error_h(dae_prediction_h, test_dae_prediction_h_gt)
        test_loss += squared_error_h(test_dae_prediction_h, test_dae_prediction_h_gt)

    # compute jaccard
    jacc = jaccard(dae_prediction_2D, target_var_2D, n_classes, one_hot=True)
    test_jacc = jaccard(test_dae_prediction_2D, target_var_2D, n_classes, one_hot=True)

    # if reconstructing h add the corresponding loss terms
    if ae_h:
        loss += squared_error_L(h, h_hat).mean()
        test_loss += squared_error_L(h_test, h_hat_test).mean()


    # network parameters
    params = lasagne.layers.get_all_params(dae, trainable=True)

    # optimizer
    if optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss, params, learning_rate=lr)
    elif optimizer == 'adam':
        updates = lasagne.updates.adam(loss, params, learning_rate=lr)
    else:
        raise ValueError('Unknown optimizer')

    # functions
    train_fn = theano.function(input_concat_h_vars + [input_mask_var, target_var],
                               loss, updates=updates)

    fcn_fn = theano.function([input_x_var], fcn_prediction)
    val_fn = theano.function(input_concat_h_vars + [input_mask_var, target_var], [test_loss, test_jacc, test_mse_loss])

    err_train = []
    err_valid = []
    jacc_val_arr = []
    mse_val_arr = []
    patience = 0

    #
    # Train
    #
    # Training main loop
    print "Start training"
    for epoch in range(num_epochs):
        # Single epoch training and validation
        start_time = time.time()

        cost_train_tot = 0
        # Train
        for i in range(n_batches_train):
            # Get minibatch
            X_train_batch, L_train_batch = train_iter.next()
            L_train_batch = L_train_batch.astype(_FLOATX)

            ##### uncomment to check the feasibility of pooling #####
            # max_n_possible_pool = np.floor(np.log2(np.array(X_train_batch.shape[2:]).min()))
            # # check that we don't ask for more poolings than possible
            # assert n_pool + additional_pool < max_n_possible_pool
            ##########################################################

            # h prediction
            H_pred_batch = fcn_fn(X_train_batch)

            if dae_dict['from_gt']:
                Y_pred_batch = L_train_batch[:, :void, :, :]
            else:
                Y_pred_batch = H_pred_batch[-1]
                H_pred_batch = H_pred_batch[:-1]

            # Training step
            cost_train = train_fn(*(H_pred_batch + [Y_pred_batch, L_train_batch]))
            cost_train_tot += cost_train

        err_train += [cost_train_tot / n_batches_train]

        # Validation
        cost_val_tot = 0
        jacc_val_tot = 0
        mse_val_tot = 0
        for i in range(n_batches_val):
            # Get minibatch
            X_val_batch, L_val_batch = val_iter.next()
            L_val_batch = L_val_batch.astype(_FLOATX)

            # h prediction
            H_pred_batch = fcn_fn(X_val_batch)

            if dae_dict['from_gt']:
                Y_pred_batch = L_val_batch[:, :void, :, :]
            else:
                Y_pred_batch = H_pred_batch[-1]
                H_pred_batch = H_pred_batch[:-1]

            # Validation step
            cost_val, jacc_val, mse_val = val_fn(*(H_pred_batch + [Y_pred_batch, L_val_batch]))
            cost_val_tot += cost_val
            jacc_val_tot += jacc_val
            mse_val_tot += mse_val

        err_valid += [cost_val_tot / n_batches_val]
        jacc_val_arr += [np.mean(jacc_val_tot[0, :] / jacc_val_tot[1, :])]
        mse_val_arr += [mse_val_tot / n_batches_val]

        out_str = "EPOCH %i: Avg epoch training cost train %f, cost val %f," + \
                  " jacc val %f, mse val % f took %f s"
        out_str = out_str % (epoch, err_train[epoch],
                             err_valid[epoch],
                             jacc_val_arr[epoch],
                             mse_val_arr[epoch],
                             time.time() - start_time)
        print out_str

        with open(os.path.join(savepath, "output.log"), "a") as f:
            f.write(out_str + "\n")

        # update learning rate
        lr.set_value(float(lr.get_value() * lr_anneal))

        # Early stopping and saving stuff
        if epoch == 0:
            best_err_val = err_valid[epoch]
            best_jacc_val = jacc_val_arr[epoch]
            best_mse_val = mse_val_arr[epoch]
        elif epoch > 0 and err_valid[epoch] < best_err_val:
            best_err_val = err_valid[epoch]
            best_jacc_val = jacc_val_arr[epoch]
            best_mse_val = mse_val_arr[epoch]
            patience = 0
            np.savez(os.path.join(savepath, 'dae_model_best.npz'),
                     *lasagne.layers.get_all_param_values(dae))
            np.savez(os.path.join(savepath, 'dae_errors_best.npz'),
                     err_train, err_valid, jacc_val_arr, mse_val_arr)
        else:
            patience += 1
            np.savez(os.path.join(savepath, 'dae_model_last.npz'),
                     *lasagne.layers.get_all_param_values(dae))
            np.savez(os.path.join(savepath, 'dae_errors_last.npz'),
                     err_train, err_valid, jacc_val_arr, mse_val_arr)

        # Finish training if patience has expired or max number of epochs
        # reached
        if patience == max_patience or epoch == num_epochs - 1:
            # Copy files to loadpath
            if savepath != loadpath:
                print('Copying model and other training files to {}'.format(
                    loadpath))
                copy_tree(savepath, loadpath)
            # End
            print('Training done!')
            return
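
# An isolated sketch (not from the original example) of the patience-based
# early stopping used at the end of train(); the validation errors below are
# synthetic and the thresholds are assumptions.
max_patience, patience, best_err = 3, 0, float('inf')
for epoch, err_val in enumerate([1.0, 0.8, 0.7, 0.72, 0.71, 0.9]):
    if err_val < best_err:
        best_err, patience = err_val, 0  # improvement: would save *_best.npz
    else:
        patience += 1                    # no improvement: would save *_last.npz
    if patience == max_patience:
        print('early stop at epoch %d, best val error %.2f' % (epoch, best_err))
        break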