def inference(dataset, segm_net, learn_step=0.005, num_iter=500,
              dae_dict_updates={}, training_dict={}, data_augmentation=False,
              which_set='test', ae_h=False, full_im_ft=False,
              savepath=None, loadpath=None, test_from_0_255=False):

    #
    # Update DAE parameters
    #
    dae_dict = {'kind': 'fcn8',
                'dropout': 0.0,
                'skip': True,
                'unpool_type': 'standard',
                'n_filters': 64,
                'conv_before_pool': 1,
                'additional_pool': 0,
                'concat_h': ['input'],
                'noise': 0.0,
                'from_gt': True,
                'temperature': 1.0,
                'layer': 'probs_dimshuffle',
                'exp_name': '',
                'bn': 0}

    dae_dict.update(dae_dict_updates)

    #
    # Prepare load/save directories
    #
    exp_name = build_experiment_name(segm_net, data_aug=data_augmentation,
                                     ae_h=ae_h,
                                     **dict(dae_dict.items() +
                                            training_dict.items()))
    # exp_name += '_ftsmall' if full_im_ft else ''

    if savepath is None:
        raise ValueError('A saving directory must be specified')

    savepath = os.path.join(savepath, dataset, exp_name, 'img_plots',
                            which_set)
    loadpath = os.path.join(loadpath, dataset, exp_name)
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    else:
        print('\033[93m The following folder already exists {}. '
              'It will be overwritten in a few seconds...\033[0m'.format(
                  savepath))

    print('Saving directory : ' + savepath)
    with open(os.path.join(savepath, "config.txt"), "w") as f:
        for key, value in locals().items():
            f.write('{} = {}\n'.format(key, value))

    #
    # Define symbolic variables
    #
    input_x_var = T.tensor4('input_x_var')  # tensor for input image batch
    # One tensor per hidden representation fed to the DAE. A list
    # comprehension is needed here: `[T.tensor4()] * n` would repeat the
    # *same* symbolic variable n times, and theano.function rejects
    # duplicate inputs.
    input_concat_h_vars = [T.tensor4()
                           for _ in range(len(dae_dict['concat_h']))]
    y_hat_var = T.tensor4('pred_y_var')
    target_var = T.tensor4('target_var')  # tensor for target batch

    #
    # Build dataset iterator
    #
    data_iter = load_data(dataset, {}, one_hot=True, batch_size=[10, 5, 10],
                          return_0_255=test_from_0_255, which_set=which_set)

    colors = data_iter.cmap
    n_batches_test = data_iter.nbatches
    n_classes = data_iter.non_void_nclasses
    void_labels = data_iter.void_labels
    nb_in_channels = data_iter.data_shape[0]
    void = n_classes if any(void_labels) else n_classes + 1

    #
    # Build networks
    #

    # Build segmentation network
    print 'Building segmentation network'
    if segm_net == 'fcn8':
        fcn = buildFCN8(nb_in_channels, input_var=input_x_var,
                        n_classes=n_classes, void_labels=void_labels,
                        path_weights=WEIGHTS_PATH + dataset +
                        '/fcn8_model.npz',
                        trainable=False, load_weights=True,
                        layer=dae_dict['concat_h'] + [dae_dict['layer']])
        padding = 100
    elif segm_net == 'densenet':
        fcn = build_fcdensenet(input_x_var, nb_in_channels=nb_in_channels,
                               n_classes=n_classes,
                               layer=dae_dict['concat_h'])
        padding = 0
    elif segm_net == 'fcn_fcresnet':
        raise NotImplementedError
    else:
        raise ValueError('Unknown segmentation network')

    # Build DAE with pre-trained weights
    print 'Building DAE network'
    if dae_dict['kind'] == 'standard':
        nb_features_to_concat = fcn[0].output_shape[1]
        dae = buildDAE(input_concat_h_vars, y_hat_var, n_classes,
                       nb_features_to_concat=nb_features_to_concat,
                       padding=padding, trainable=True,
                       void_labels=void_labels, load_weights=True,
                       path_weights=loadpath,
                       model_name='dae_model_best.npz',
                       out_nonlin=softmax, concat_h=dae_dict['concat_h'],
                       noise=dae_dict['noise'],
                       n_filters=dae_dict['n_filters'],
                       conv_before_pool=dae_dict['conv_before_pool'],
                       additional_pool=dae_dict['additional_pool'],
                       dropout=dae_dict['dropout'], skip=dae_dict['skip'],
                       unpool_type=dae_dict['unpool_type'],
                       bn=dae_dict['bn'])
    elif dae_dict['kind'] == 'fcn8':
        dae = buildFCN8_DAE(input_concat_h_vars, y_hat_var, n_classes,
                            nb_in_channels=n_classes, path_weights=loadpath,
                            model_name='dae_model_best.npz', trainable=True,
                            load_weights=True,
                            pretrained=True, pascal=False,
                            concat_h=dae_dict['concat_h'],
                            noise=dae_dict['noise'])
    elif dae_dict['kind'] == 'contextmod':
        dae = buildDAE_contextmod(input_concat_h_vars, y_hat_var, n_classes,
                                  path_weights=loadpath,
                                  model_name='dae_model_best.npz',
                                  trainable=True, load_weights=True,
                                  out_nonlin=softmax,
                                  noise=dae_dict['noise'],
                                  concat_h=dae_dict['concat_h'])
    else:
        raise ValueError('Unknown dae kind')

    #
    # Define and compile theano functions
    #
    print "Defining and compiling theano functions"

    # predictions and theano functions
    pred_fcn = lasagne.layers.get_output(fcn, deterministic=True,
                                         batch_norm_use_averages=False)
    pred_fcn_fn = theano.function([input_x_var], pred_fcn)
    pred_dae = lasagne.layers.get_output(dae, deterministic=True)
    pred_dae_fn = theano.function(input_concat_h_vars + [y_hat_var],
                                  pred_dae)

    # Reshape iterative inference output to b01,c
    y_hat_dimshuffle = y_hat_var.dimshuffle((0, 2, 3, 1))
    sh = y_hat_dimshuffle.shape
    y_hat_2D = y_hat_dimshuffle.reshape((T.prod(sh[:3]), sh[3]))

    # Reshape target to b01,c
    target_var_dimshuffle = target_var.dimshuffle((0, 2, 3, 1))
    sh2 = target_var_dimshuffle.shape
    target_var_2D = target_var_dimshuffle.reshape((T.prod(sh2[:3]), sh2[3]))

    # derivative of energy wrt input and theano function: the DAE residual
    # r(y) - y approximates the negative energy gradient, so dE/dy is
    # estimated as -(r(y) - y)
    de = - (pred_dae - y_hat_var)
    de_fn = theano.function(input_concat_h_vars + [y_hat_var], de)

    # metrics and theano functions
    test_loss = squared_error(y_hat_var, target_var, void)
    test_acc = accuracy(y_hat_2D, target_var_2D, void_labels, one_hot=True)
    test_jacc = jaccard(y_hat_2D, target_var_2D, n_classes, one_hot=True)
    val_fn = theano.function([y_hat_var, target_var],
                             [test_acc, test_jacc, test_loss])

    #
    # Infer
    #
    print 'Start inferring'
    rec_tot = 0
    rec_tot_fcn = 0
    rec_tot_dae = 0
    acc_tot = 0
    acc_tot_fcn = 0
    jacc_tot = 0
    jacc_tot_fcn = 0
    acc_tot_dae = 0
    jacc_tot_dae = 0

    print 'Inference step: ' + str(learn_step) + ', num iter: ' + \
        str(num_iter)
    for i in range(n_batches_test):
        info_str = "Batch %d out of %d" % (i + 1, n_batches_test)
        print '-' * 30
        print '*' * 5 + info_str + '*' * 5
        print '-' * 30

        # Get minibatch
        X_test_batch, L_test_batch = data_iter.next()
        L_test_batch = L_test_batch.astype(_FLOATX)

        # Compute fcn prediction y and h
        pred_test_batch = pred_fcn_fn(X_test_batch)
        Y_test_batch = pred_test_batch[-1]
        H_test_batch = pred_test_batch[:-1]

        # Compute metrics before iterative inference
        acc_fcn, jacc_fcn, rec_fcn = val_fn(Y_test_batch, L_test_batch)
        acc_tot_fcn += acc_fcn
        jacc_tot_fcn += jacc_fcn
        rec_tot_fcn += rec_fcn
        Y_test_batch_fcn = Y_test_batch
        print_results('>>>>> FCN:', rec_tot_fcn, acc_tot_fcn, jacc_tot_fcn,
                      i + 1)

        # Compute dae output and metrics after dae
        Y_test_batch_dae = pred_dae_fn(*(H_test_batch + [Y_test_batch]))
        acc_dae, jacc_dae, rec_dae = val_fn(Y_test_batch_dae, L_test_batch)
        acc_tot_dae += acc_dae
        jacc_tot_dae += jacc_dae
        rec_tot_dae += rec_dae
        print_results('>>>>> FCN+DAE:', rec_tot_dae, acc_tot_dae,
                      jacc_tot_dae, i + 1)

        Y_test_batch_ii = []
        for im in range(X_test_batch.shape[0]):
            print('-----------------------')
            h_im = [el[np.newaxis, im] for el in H_test_batch]
            y_im = Y_test_batch[np.newaxis, im]
            t_im = L_test_batch[np.newaxis, im]

            # Iterative inference: gradient descent on the DAE energy
            for it in range(num_iter):
                # Compute gradient
                grad = de_fn(*(h_im + [y_im]))

                # Update prediction
                y_im = y_im - learn_step * grad

                # Clip prediction to keep a valid probability map
                y_im = np.clip(y_im, 0.0, 1.0)

                # Stop when the gradient norm vanishes
                norm = np.linalg.norm(grad, axis=1).mean()
                if norm < _EPSILON:
                    break

                acc_iter, jacc_iter, rec_iter = val_fn(y_im, t_im)
                print rec_iter, acc_iter, np.nanmean(jacc_iter[0, :] /
                                                     jacc_iter[1, :])

            Y_test_batch_ii += [y_im]
        Y_test_batch_ii = np.concatenate(Y_test_batch_ii, axis=0)

        # Compute metrics
        acc, jacc, rec = val_fn(Y_test_batch_ii, L_test_batch)
        acc_tot += acc
        jacc_tot += jacc
        rec_tot += rec
        print_results('>>>>> ITERATIVE INFERENCE:', rec_tot, acc_tot,
                      jacc_tot, i + 1)

        # os.path.join is needed here: plain concatenation would write
        # '<which_set>batchN.npz' next to the save directory
        np.savez(os.path.join(savepath, 'batch' + str(i) + '.npz'),
                 X=X_test_batch, L=L_test_batch, Y_ii=Y_test_batch_ii,
                 Y_fcn=Y_test_batch_fcn)

        # Save images
        # save_img(X_test_batch,
        #          L_test_batch,
        #          Y_test_batch_ii,
        #          Y_test_batch_fcn,
        #          savepath, 'batch' + str(i),
        #          void_labels, colors)

    # Print summary of how things went
    print('-------------------------------------------------------------------')
    print('------------------------------SUMMARY------------------------------')
    print('-------------------------------------------------------------------')
    print_results('>>>>> FCN:', rec_tot_fcn, acc_tot_fcn, jacc_tot_fcn,
                  i + 1)
    print_results('>>>>> FCN+DAE:', rec_tot_dae, acc_tot_dae, jacc_tot_dae,
                  i + 1)
    print_results('>>>>> ITERATIVE INFERENCE:', rec_tot, acc_tot, jacc_tot,
                  i + 1)

    # Compute per class jaccard
    jacc_perclass_fcn = jacc_tot_fcn[0, :] / jacc_tot_fcn[1, :]
    jacc_perclass = jacc_tot[0, :] / jacc_tot[1, :]

    print ">>>>> Per class jaccard:"
    labs = data_iter.mask_labels
    for i in range(len(labs) - len(void_labels)):
        class_str = '    ' + labs[i] + ' : fcn -> %f, ii -> %f'
        class_str = class_str % (jacc_perclass_fcn[i], jacc_perclass[i])
        print class_str

    # Move segmentations
    if savepath != loadpath:
        print('Copying images to {}'.format(loadpath))
        copy_tree(savepath, os.path.join(loadpath, 'img_plots', which_set))
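
# A minimal usage sketch for the DAE-based `inference` above, assuming
# CamVid and an FCN-8 segmentation network as handled by the code. The
# paths and the `dae_dict_updates` values are placeholders, not the
# published configuration. Note that the CRF script further down reuses the
# name `inference` (the two presumably live in separate files), so this
# helper should be called before that redefinition.
def _example_run_dae_inference():
    inference('camvid', 'fcn8',
              learn_step=0.005,
              num_iter=500,
              dae_dict_updates={'kind': 'standard',
                                'concat_h': ['pool4']},  # hypothetical choice
              which_set='test',
              savepath='/tmp/save',   # placeholder
              loadpath='/tmp/load')   # placeholder: trained DAE weights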
def test(dataset, segm_net, which_set='val', data_aug=False,
         savepath=None, loadpath=None, test_from_0_255=False):

    #
    # Define symbolic variables
    #
    input_x_var = T.tensor4('input_var')
    target_var = T.tensor4('target_var')

    #
    # Build dataset iterator
    #
    data_iter = load_data(dataset, {}, one_hot=True,
                          batch_size=[10, 10, 10],
                          return_0_255=test_from_0_255, which_set=which_set)

    colors = data_iter.cmap
    n_batches_test = data_iter.nbatches
    n_classes = data_iter.non_void_nclasses
    void_labels = data_iter.void_labels
    nb_in_channels = data_iter.data_shape[0]
    void = n_classes if any(void_labels) else n_classes + 1

    #
    # Build segmentation network
    #
    print 'Building segmentation network'
    if segm_net == 'fcn8':
        fcn = buildFCN8(nb_in_channels, input_var=input_x_var,
                        n_classes=n_classes, void_labels=void_labels,
                        path_weights=WEIGHTS_PATH + dataset +
                        '/fcn8_model.npz',
                        trainable=False, load_weights=True,
                        layer=['probs_dimshuffle'])
    elif segm_net == 'densenet':
        fcn = build_fcdensenet(input_x_var, nb_in_channels=nb_in_channels,
                               n_classes=n_classes, layer=[], output_d='4d')
    elif segm_net == 'fcn_fcresnet':
        raise NotImplementedError
    else:
        raise ValueError('Unknown segmentation network')

    #
    # Define and compile theano functions
    #
    print "Defining and compiling test functions"

    test_prediction = lasagne.layers.get_output(
        fcn, deterministic=True, batch_norm_use_averages=False)[0]

    # Reshape prediction to b01,c
    test_prediction_dimshuffle = test_prediction.dimshuffle((0, 2, 3, 1))
    sh = test_prediction_dimshuffle.shape
    test_prediction_2D = test_prediction_dimshuffle.reshape(
        (T.prod(sh[:3]), sh[3]))

    # Reshape target to b01,c
    target_var_dimshuffle = target_var.dimshuffle((0, 2, 3, 1))
    sh2 = target_var_dimshuffle.shape
    target_var_2D = target_var_dimshuffle.reshape((T.prod(sh2[:3]), sh2[3]))

    test_loss = squared_error(test_prediction, target_var, void)
    test_acc = accuracy(test_prediction_2D, target_var_2D, void_labels,
                        one_hot=True)
    test_jacc = jaccard(test_prediction_2D, target_var_2D, n_classes,
                        one_hot=True)

    val_fn = theano.function([input_x_var, target_var],
                             [test_acc, test_jacc, test_loss])
    pred_fcn_fn = theano.function([input_x_var], test_prediction)

    # Iterate over test and compute metrics
    print "Testing"
    acc_test_tot = 0
    mse_test_tot = 0
    jacc_num_test_tot = np.zeros((1, n_classes))
    jacc_denom_test_tot = np.zeros((1, n_classes))
    for i in range(n_batches_test):
        # Get minibatch
        X_test_batch, L_test_batch = data_iter.next()
        Y_test_batch = pred_fcn_fn(X_test_batch)
        L_test_batch = L_test_batch.astype(_FLOATX)
        # L_test_batch = np.reshape(L_test_batch,
        #                           np.prod(L_test_batch.shape))

        # Test step
        acc_test, jacc_test, mse_test = val_fn(X_test_batch, L_test_batch)
        jacc_num_test, jacc_denom_test = jacc_test

        acc_test_tot += acc_test
        mse_test_tot += mse_test
        jacc_num_test_tot += jacc_num_test
        jacc_denom_test_tot += jacc_denom_test

        # Save images
        # save_img(X_test_batch, L_test_batch, Y_test_batch,
        #          savepath, n_classes, 'batch' + str(i),
        #          void_labels, colors)

    acc_test = acc_test_tot / n_batches_test
    mse_test = mse_test_tot / n_batches_test
    jacc_per_class = jacc_num_test_tot / jacc_denom_test_tot
    jacc_per_class = jacc_per_class[0]
    jacc_test = np.mean(jacc_per_class)

    out_str = "FINAL MODEL: acc test %f, jacc test %f, mse test %f"
    out_str = out_str % (acc_test, jacc_test, mse_test)
    print out_str

    print ">>> Per class jaccard:"
    labs = data_iter.mask_labels
    for i in range(len(labs) - len(void_labels)):
        class_str = '    ' + labs[i] + ' : %f'
        class_str = class_str % (jacc_per_class[i])
        print class_str
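
# A minimal usage sketch for `test`, assuming a pre-trained FCN-8 on CamVid
# (the combination the code supports); the paths are placeholders.
def _example_run_test():
    test('camvid', 'fcn8', which_set='val',
         savepath='/tmp/save',   # placeholder
         loadpath='/tmp/load')   # placeholder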
def inference(dataset, segm_net, which_set='val', num_iter=5,
              Bilateral=True, savepath=None, loadpath=None,
              test_from_0_255=False):

    #
    # Define symbolic variables
    #
    input_x_var = T.tensor4('input_x_var')
    y_hat_var = T.tensor4('pred_y_var')
    target_var = T.tensor4('target_var')

    #
    # Build dataset iterator
    #
    data_iter = load_data(dataset, {}, one_hot=True,
                          batch_size=[10, 10, 10],
                          return_0_255=test_from_0_255, which_set=which_set)

    colors = data_iter.cmap
    n_batches_test = data_iter.nbatches
    n_classes = data_iter.non_void_nclasses
    void_labels = data_iter.void_labels

    #
    # Prepare saving directory
    #
    savepath = os.path.join(savepath, dataset, segm_net, 'img_plots', 'crf',
                            str(num_iter), which_set)
    loadpath = os.path.join(loadpath, dataset, segm_net, 'img_plots', 'crf',
                            str(num_iter), which_set)
    if not os.path.exists(savepath):
        os.makedirs(savepath)

    #
    # Build network
    #
    print 'Building segmentation network'
    if segm_net == 'fcn8':
        fcn = buildFCN8(3, input_var=input_x_var, n_classes=n_classes,
                        void_labels=void_labels,
                        path_weights=WEIGHTS_PATH + dataset +
                        '/fcn8_model.npz',
                        trainable=False, load_weights=True,
                        layer=['probs_dimshuffle'])
        padding = 100
    elif segm_net == 'densenet':
        fcn = build_fcdensenet(input_x_var, nb_in_channels=3,
                               n_classes=n_classes, layer=[])
        padding = 0
    elif segm_net == 'fcn_fcresnet':
        raise NotImplementedError
    else:
        raise ValueError('Unknown segmentation network')

    #
    # Define and compile theano functions
    #
    print "Defining and compiling theano functions"

    # predictions of fcn
    pred_fcn = lasagne.layers.get_output(fcn, deterministic=True,
                                         batch_norm_use_averages=False)[0]

    # function to compute output of fcn
    pred_fcn_fn = theano.function([input_x_var], pred_fcn)

    # reshape fcn output to b01,c
    y_hat_dimshuffle = y_hat_var.dimshuffle((0, 2, 3, 1))
    sh = y_hat_dimshuffle.shape
    y_hat_2D = y_hat_dimshuffle.reshape((T.prod(sh[:3]), sh[3]))

    # reshape target to b01,c
    target_var_dimshuffle = target_var.dimshuffle((0, 2, 3, 1))
    sh2 = target_var_dimshuffle.shape
    target_var_2D = target_var_dimshuffle.reshape((T.prod(sh2[:3]), sh2[3]))

    # metrics to evaluate iterative inference
    test_acc = accuracy(y_hat_2D, target_var_2D, void_labels, one_hot=True)
    test_jacc = jaccard(y_hat_2D, target_var_2D, n_classes, one_hot=True)

    # functions to compute metrics
    val_fn = theano.function([y_hat_var, target_var], [test_acc, test_jacc])

    #
    # Infer
    #
    print 'Start inferring'
    acc_tot_crf = 0
    acc_tot_fcn = 0
    jacc_tot_crf = 0
    jacc_tot_fcn = 0
    for i in range(n_batches_test):
        info_str = "Batch %d out of %d" % (i + 1, n_batches_test)
        print info_str

        # Get minibatch
        X_test_batch, L_test_batch = data_iter.next()
        L_test_batch = L_test_batch.astype(_FLOATX)

        # Compute fcn prediction
        Y_test_batch = pred_fcn_fn(X_test_batch)

        # Compute metrics before CRF
        acc_fcn, jacc_fcn = val_fn(Y_test_batch, L_test_batch)
        acc_tot_fcn += acc_fcn
        jacc_tot_fcn += jacc_fcn
        Y_test_batch_fcn = Y_test_batch

        Y_test_batch_crf = []
        for im in range(X_test_batch.shape[0]):
            # CRF
            d = dcrf.DenseCRF2D(Y_test_batch.shape[3],
                                Y_test_batch.shape[2], n_classes)

            sm = Y_test_batch[im, 0:n_classes, :, :]
            sm = sm.reshape((n_classes, -1))
            img = X_test_batch[im]
            img = np.transpose(img, (1, 2, 0))
            img = (255 * img).astype('uint8')
            img2 = np.asarray(img, order='C')

            # set unary potentials (neg log probability)
            U = unary_from_softmax(sm)
            d.setUnaryEnergy(U)

            # set pairwise potentials
            # This adds the color-independent term; features are the
            # locations only. Smoothness kernel.
            # sxy: gaussian x, y std
            # compat: ways to weight contributions; a number for Potts
            #     compatibility, a vector for diagonal compatibility, an
            #     array for matrix compatibility
            # kernel: kernel used, CONST_KERNEL, FULL_KERNEL, DIAG_KERNEL
            # normalization: NORMALIZE_AFTER, NORMALIZE_BEFORE,
            #     NO_NORMALIZATION, NORMALIZE_SYMMETRIC
            d.addPairwiseGaussian(sxy=(3, 3), compat=3,
                                  kernel=dcrf.DIAG_KERNEL,
                                  normalization=dcrf.NORMALIZE_SYMMETRIC)

            if Bilateral:
                # Appearance kernel. This adds the color-dependent term,
                # i.e. features are (x, y, r, g, b).
                # img2 is an image array, e.g. img2.dtype == np.uint8 and
                # img2.shape == (640, 480, 3).
                # To set sxy and srgb, perform a grid search on the
                # validation set (see the sketch after this function).
                d.addPairwiseBilateral(sxy=(3, 3), srgb=(13, 13, 13),
                                       rgbim=img2, compat=10,
                                       kernel=dcrf.DIAG_KERNEL,
                                       normalization=dcrf.NORMALIZE_SYMMETRIC)

            # inference
            Q = d.inference(num_iter)
            Q = np.reshape(Q, (n_classes, Y_test_batch.shape[2],
                               Y_test_batch.shape[3]))
            Y_test_batch_crf += [np.expand_dims(Q, axis=0)]

        Y_test_batch = np.concatenate(Y_test_batch_crf, axis=0)

        # Compute metrics after CRF
        acc_crf, jacc_crf = val_fn(Y_test_batch, L_test_batch)
        acc_tot_crf += acc_crf
        jacc_tot_crf += jacc_crf

        # Save images
        # save_img(X_test_batch.astype(_FLOATX), L_test_batch, Y_test_batch,
        #          Y_test_batch_fcn, savepath, 'batch' + str(i),
        #          void_labels, colors)

        np.savez(os.path.join(savepath, 'batch' + str(i) + '.npz'),
                 X=X_test_batch, L=L_test_batch, Y_crf=Y_test_batch,
                 Y_fcn=Y_test_batch_fcn)

    acc_test_crf = acc_tot_crf / n_batches_test
    jacc_test_perclass_crf = jacc_tot_crf[0, :] / jacc_tot_crf[1, :]
    jacc_test_crf = np.nanmean(jacc_test_perclass_crf)
    acc_test_fcn = acc_tot_fcn / n_batches_test
    jacc_test_perclass_fcn = jacc_tot_fcn[0, :] / jacc_tot_fcn[1, :]
    jacc_test_fcn = np.nanmean(jacc_test_perclass_fcn)

    out_str = "TEST: acc crf %f, jacc crf %f, acc fcn %f, jacc fcn %f"
    out_str = out_str % (acc_test_crf, jacc_test_crf, acc_test_fcn,
                         jacc_test_fcn)

    print ">>>>> Per class jaccard:"
    labs = data_iter.mask_labels
    for i in range(len(labs) - len(void_labels)):
        class_str = '    ' + labs[i] + ' : fcn -> %f, crf -> %f'
        class_str = class_str % (jacc_test_perclass_fcn[i],
                                 jacc_test_perclass_crf[i])
        print class_str
    print out_str

    # Move segmentations
    if savepath != loadpath:
        print('Copying images to {}'.format(loadpath))
        copy_tree(savepath, loadpath)

    return jacc_test_perclass_crf
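
# The bilateral-kernel comment above suggests grid-searching sxy/srgb on the
# validation set. Below is a sketch of that search for a single
# (softmax, image, label) triple; `mean_iou` is a hypothetical scoring
# helper and the candidate grids are illustrative, not tuned values.
def _grid_search_crf_params(sm, img, label, n_classes, num_iter=5):
    # sm: (n_classes, H*W) softmax; img: (H, W, 3) uint8 C-order image
    best = (None, None, -np.inf)
    for sxy in [1, 3, 5]:
        for srgb in [3, 13, 23]:
            d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], n_classes)
            d.setUnaryEnergy(unary_from_softmax(sm))
            d.addPairwiseGaussian(sxy=(3, 3), compat=3,
                                  kernel=dcrf.DIAG_KERNEL,
                                  normalization=dcrf.NORMALIZE_SYMMETRIC)
            d.addPairwiseBilateral(sxy=(sxy, sxy), srgb=(srgb, srgb, srgb),
                                   rgbim=img, compat=10,
                                   kernel=dcrf.DIAG_KERNEL,
                                   normalization=dcrf.NORMALIZE_SYMMETRIC)
            Q = np.array(d.inference(num_iter)).reshape(
                (n_classes,) + img.shape[:2])
            # score the hard segmentation against the label
            score = mean_iou(Q.argmax(axis=0), label)  # hypothetical helper
            if score > best[2]:
                best = (sxy, srgb, score)
    return best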
def train(dataset, segm_net, learning_rate=0.005, lr_anneal=1.0,
          weight_decay=1e-4, num_epochs=500, max_patience=100,
          optimizer='rmsprop', training_loss=['squared_error'],
          batch_size=[10, 1, 1], ae_h=False, dae_dict_updates={},
          data_augmentation={}, savepath=None, loadpath=None, resume=False,
          train_from_0_255=False, lmb=1, full_im_ft=False):

    #
    # Update DAE parameters
    #
    dae_dict = {'kind': 'fcn8',
                'dropout': 0.0,
                'skip': True,
                'unpool_type': 'standard',
                'n_filters': 64,
                'conv_before_pool': 1,
                'additional_pool': 0,
                'concat_h': ['input'],
                'noise': 0.0,
                'from_gt': True,
                'temperature': 1.0,
                'path_weights': '',
                'layer': 'probs_dimshuffle',
                'exp_name': '',
                'bn': 0}

    dae_dict.update(dae_dict_updates)

    #
    # Prepare load/save directories
    #
    exp_name = build_experiment_name(segm_net,
                                     training_loss=training_loss,
                                     data_aug=bool(data_augmentation),
                                     learning_rate=learning_rate,
                                     lr_anneal=lr_anneal,
                                     weight_decay=weight_decay,
                                     optimizer=optimizer, ae_h=ae_h,
                                     **dae_dict)
    if savepath is None:
        raise ValueError('A saving directory must be specified')

    loadpath_init = os.path.join(loadpath, dataset, exp_name)
    exp_name += '_ft' if full_im_ft else ''
    loadpath = os.path.join(loadpath, dataset, exp_name)
    savepath = os.path.join(savepath, dataset, exp_name)
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    else:
        print('\033[93m The following folder already exists {}. '
              'It will be overwritten in a few seconds...\033[0m'.format(
                  savepath))

    print('Saving directory : ' + savepath)
    with open(os.path.join(savepath, "config.txt"), "w") as f:
        for key, value in locals().items():
            f.write('{} = {}\n'.format(key, value))

    #
    # Define symbolic variables
    #
    input_x_var = T.tensor4('input_x_var')  # tensor for input image batch
    input_mask_var = T.tensor4('input_mask_var')  # tensor for segmentation batch (input dae)
    # One tensor per hidden representation fed to the DAE (a list
    # comprehension, so each entry is a distinct symbolic variable)
    input_concat_h_vars = [T.tensor4()
                           for _ in range(len(dae_dict['concat_h']))]
    target_var = T.tensor4('target_var')  # tensor for target batch

    # learning_rate = learning_rate*0.1 if full_im_ft else learning_rate
    # learning_rate = 0.01
    print learning_rate
    lr = theano.shared(np.float32(learning_rate), 'learning_rate')

    #
    # Build dataset iterator
    #
    train_iter, val_iter, _ = load_data(dataset, data_augmentation,
                                        one_hot=True, batch_size=batch_size,
                                        return_0_255=train_from_0_255)

    n_batches_train = train_iter.nbatches
    n_batches_val = val_iter.nbatches
    n_classes = train_iter.non_void_nclasses
    void_labels = train_iter.void_labels
    nb_in_channels = train_iter.data_shape[0]
    void = n_classes if any(void_labels) else n_classes + 1

    #
    # Build networks
    #

    # Check that model and dataset get along
    print 'Checking options'
    assert dataset == 'camvid' and segm_net in ('fcn8', 'densenet')
    # full-image fine-tuning requires no cropping (short-circuit also avoids
    # a KeyError when data_augmentation is empty)
    assert not full_im_ft or data_augmentation['crop_size'] is None

    # Build segmentation network
    print 'Building segmentation network'
    if segm_net == 'fcn8':
        layer_out = copy.copy(dae_dict['concat_h'])
        layer_out += [copy.copy(dae_dict['layer'])] \
            if not dae_dict['from_gt'] else []
        fcn = buildFCN8(nb_in_channels, input_x_var, n_classes=n_classes,
                        void_labels=void_labels,
                        path_weights=WEIGHTS_PATH + dataset +
                        '/fcn8_model.npz',
                        load_weights=True, layer=layer_out)
        padding = 100
    elif segm_net == 'densenet':
        fcn = build_fcdensenet(input_x_var, nb_in_channels=nb_in_channels,
                               n_classes=n_classes,
                               layer=dae_dict['concat_h'],
                               from_gt=dae_dict['from_gt'])
        padding = 0
    elif segm_net == 'fcn_fcresnet':
        raise NotImplementedError
    else:
        raise ValueError('Unknown segmentation network')

    # Build DAE network
    print 'Building DAE network'
    if ae_h and dae_dict['kind'] != 'standard':
        raise ValueError('Plug&Play not implemented for ' + dae_dict['kind'])
    if ae_h and 'pool' not in dae_dict['concat_h'][-1]:
        raise ValueError('Plug&Play version needs concat_h to be '
                         'different from input')
    ae_h = ae_h and 'pool' in dae_dict['concat_h'][-1]

    if dae_dict['kind'] == 'standard':
        nb_features_to_concat = fcn[0].output_shape[1]
        dae = buildDAE(input_concat_h_vars, input_mask_var, n_classes,
                       nb_features_to_concat=nb_features_to_concat,
                       padding=padding, trainable=True,
                       void_labels=void_labels,
                       load_weights=resume or full_im_ft,
                       path_weights=loadpath_init,
                       model_name='dae_model_best.npz',
                       out_nonlin=softmax, concat_h=dae_dict['concat_h'],
                       noise=dae_dict['noise'],
                       n_filters=dae_dict['n_filters'],
                       conv_before_pool=dae_dict['conv_before_pool'],
                       additional_pool=dae_dict['additional_pool'],
                       dropout=dae_dict['dropout'], skip=dae_dict['skip'],
                       unpool_type=dae_dict['unpool_type'],
                       bn=dae_dict['bn'], ae_h=ae_h)
    elif dae_dict['kind'] == 'fcn8':
        dae = buildFCN8_DAE(input_concat_h_vars, input_mask_var, n_classes,
                            nb_in_channels=n_classes, trainable=True,
                            load_weights=resume, pretrained=True,
                            pascal=True, concat_h=dae_dict['concat_h'],
                            noise=dae_dict['noise'],
                            dropout=dae_dict['dropout'],
                            path_weights=os.path.join(
                                '/'.join(loadpath_init.split('/')[:-1]),
                                dae_dict['path_weights']),
                            model_name='dae_model_best.npz')
    elif dae_dict['kind'] == 'contextmod':
        dae = buildDAE_contextmod(input_concat_h_vars, input_mask_var,
                                  n_classes, path_weights=loadpath_init,
                                  model_name='dae_model.npz',
                                  trainable=True, load_weights=resume,
                                  out_nonlin=softmax,
                                  noise=dae_dict['noise'],
                                  concat_h=dae_dict['concat_h'])
    else:
        raise ValueError('Unknown dae kind')

    #
    # Define and compile theano functions
    #

    # training functions
    print "Defining and compiling training functions"

    # fcn prediction
    fcn_prediction = lasagne.layers.get_output(
        fcn, deterministic=True, batch_norm_use_averages=False)

    # select prediction layers (pooling and upsampling layers)
    dae_all_lays = lasagne.layers.get_all_layers(dae)
    if dae_dict['kind'] != 'contextmod':
        dae_lays = [l for l in dae_all_lays
                    if isinstance(l, Pool2DLayer) or
                    isinstance(l, CroppingLayer) or
                    isinstance(l, ElemwiseSumLayer) or
                    l == dae_all_lays[-1]]
        # dae_lays = dae_lays[::2]
    else:
        dae_lays = [l for l in dae_all_lays
                    if isinstance(l, DilatedConv2DLayer) or
                    l == dae_all_lays[-1]]

    if ae_h:
        h_ae_idx = [i for i, el in enumerate(dae_lays)
                    if el.name == 'h_to_recon'][0]
        h_hat_idx = [i for i, el in enumerate(dae_lays)
                     if el.name == 'h_hat'][0]

    # predictions
    dae_prediction_all = lasagne.layers.get_output(
        dae_lays, batch_norm_use_averages=False)
    dae_prediction = dae_prediction_all[-1]
    dae_prediction_h = dae_prediction_all[:-1]

    test_dae_prediction_all = lasagne.layers.get_output(
        dae_lays, deterministic=True, batch_norm_use_averages=False)
    test_dae_prediction = test_dae_prediction_all[-1]
    test_dae_prediction_h = test_dae_prediction_all[:-1]

    # fetch h and h_hat if needed
    if ae_h:
        h = dae_prediction_all[h_ae_idx]
        h_hat = dae_prediction_all[h_hat_idx]
        h_test = test_dae_prediction_all[h_ae_idx]
        h_hat_test = test_dae_prediction_all[h_hat_idx]

    # loss
    loss = 0
    test_loss = 0

    # Convert DAE prediction to 2D
    dae_prediction_2D = dae_prediction.dimshuffle((0, 2, 3, 1))
    sh = dae_prediction_2D.shape
    dae_prediction_2D = dae_prediction_2D.reshape((T.prod(sh[:3]), sh[3]))

    test_dae_prediction_2D = test_dae_prediction.dimshuffle((0, 2, 3, 1))
    sh = test_dae_prediction_2D.shape
    test_dae_prediction_2D = test_dae_prediction_2D.reshape(
        (T.prod(sh[:3]), sh[3]))

    # Convert target to 2D
    target_var_2D = target_var.dimshuffle((0, 2, 3, 1))
    sh = target_var_2D.shape
    target_var_2D = target_var_2D.reshape((T.prod(sh[:3]), sh[3]))

    if 'crossentropy' in training_loss:
        # Compute loss
        loss += crossentropy(dae_prediction_2D, target_var_2D, void_labels,
                             one_hot=True)
        test_loss += crossentropy(test_dae_prediction_2D, target_var_2D,
                                  void_labels, one_hot=True)

    if 'dice' in training_loss:
        loss += dice_loss(dae_prediction, target_var, void_labels)
        test_loss += dice_loss(test_dae_prediction, target_var, void_labels)

    test_mse_loss = squared_error(test_dae_prediction, target_var, void)
    if 'squared_error' in training_loss:
        mse_loss = squared_error(dae_prediction, target_var, void)
        loss += lmb * mse_loss
        test_loss += lmb * test_mse_loss

    # Add intermediate losses
    if 'squared_error_h' in training_loss:
        # extract input layers and create dictionary
        dae_input_lays = [l for l in dae_all_lays
                          if isinstance(l, InputLayer)]
        inputs = {dae_input_lays[0]: target_var[:, :void, :, :],
                  dae_input_lays[-1]: target_var[:, :void, :, :]}
        for idx, val in enumerate(input_concat_h_vars):
            inputs[dae_input_lays[idx + 1]] = val

        test_dae_prediction_all_gt = lasagne.layers.get_output(
            dae_lays, inputs=inputs, deterministic=True,
            batch_norm_use_averages=False)
        test_dae_prediction_h_gt = test_dae_prediction_all_gt[:-1]

        loss += squared_error_h(dae_prediction_h, test_dae_prediction_h_gt)
        test_loss += squared_error_h(test_dae_prediction_h,
                                     test_dae_prediction_h_gt)

    # compute jaccard
    jacc = jaccard(dae_prediction_2D, target_var_2D, n_classes,
                   one_hot=True)
    test_jacc = jaccard(test_dae_prediction_2D, target_var_2D, n_classes,
                        one_hot=True)

    # if reconstructing h, add the corresponding loss terms
    if ae_h:
        loss += squared_error_L(h, h_hat).mean()
        test_loss += squared_error_L(h_test, h_hat_test).mean()

    # network parameters
    params = lasagne.layers.get_all_params(dae, trainable=True)

    # optimizer
    if optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss, params, learning_rate=lr)
    elif optimizer == 'adam':
        updates = lasagne.updates.adam(loss, params, learning_rate=lr)
    else:
        raise ValueError('Unknown optimizer')

    # functions
    train_fn = theano.function(
        input_concat_h_vars + [input_mask_var, target_var], loss,
        updates=updates)
    fcn_fn = theano.function([input_x_var], fcn_prediction)
    val_fn = theano.function(
        input_concat_h_vars + [input_mask_var, target_var],
        [test_loss, test_jacc, test_mse_loss])

    err_train = []
    err_valid = []
    jacc_val_arr = []
    mse_val_arr = []
    patience = 0

    #
    # Train
    #

    # Training main loop
    print "Start training"
    for epoch in range(num_epochs):
        # Single epoch training and validation
        start_time = time.time()
        cost_train_tot = 0

        # Train
        for i in range(n_batches_train):
            # Get minibatch
            X_train_batch, L_train_batch = train_iter.next()
            L_train_batch = L_train_batch.astype(_FLOATX)

            # uncomment if you want to control the feasibility of pooling
            # max_n_possible_pool = np.floor(
            #     np.log2(np.array(X_train_batch.shape[2:]).min()))
            # # check that we don't ask for more poolings than possible
            # assert n_pool + additional_pool < max_n_possible_pool

            # h prediction; when training from ground truth, the fcn outputs
            # only the concat_h maps, otherwise its last output is the
            # prediction and must be split off
            H_pred_batch = fcn_fn(X_train_batch)
            if dae_dict['from_gt']:
                Y_pred_batch = L_train_batch[:, :void, :, :]
            else:
                Y_pred_batch = H_pred_batch[-1]
                H_pred_batch = H_pred_batch[:-1]

            # Training step
            cost_train = train_fn(*(H_pred_batch +
                                    [Y_pred_batch, L_train_batch]))
            cost_train_tot += cost_train

        err_train += [cost_train_tot / n_batches_train]

        # Validation
        cost_val_tot = 0
        jacc_val_tot = 0
        mse_val_tot = 0
        for i in range(n_batches_val):
            # Get minibatch
            X_val_batch, L_val_batch = val_iter.next()
            L_val_batch = L_val_batch.astype(_FLOATX)

            # h prediction (same split as in the training loop)
            H_pred_batch = fcn_fn(X_val_batch)
            if dae_dict['from_gt']:
                Y_pred_batch = L_val_batch[:, :void, :, :]
            else:
                Y_pred_batch = H_pred_batch[-1]
                H_pred_batch = H_pred_batch[:-1]

            # Validation step
            cost_val, jacc_val, mse_val = val_fn(*(H_pred_batch +
                                                   [Y_pred_batch,
                                                    L_val_batch]))
            cost_val_tot += cost_val
            jacc_val_tot += jacc_val
            mse_val_tot += mse_val

        err_valid += [cost_val_tot / n_batches_val]
        jacc_val_arr += [np.mean(jacc_val_tot[0, :] / jacc_val_tot[1, :])]
        mse_val_arr += [mse_val_tot / n_batches_val]

        out_str = "EPOCH %i: Avg epoch training cost train %f, " + \
                  "cost val %f, jacc val %f, mse val %f took %f s"
        out_str = out_str % (epoch, err_train[epoch], err_valid[epoch],
                             jacc_val_arr[epoch], mse_val_arr[epoch],
                             time.time() - start_time)
        print out_str

        with open(os.path.join(savepath, "output.log"), "a") as f:
            f.write(out_str + "\n")

        # update learning rate
        lr.set_value(float(lr.get_value() * lr_anneal))

        # Early stopping and saving stuff
        if epoch == 0:
            best_err_val = err_valid[epoch]
            best_jacc_val = jacc_val_arr[epoch]
            best_mse_val = mse_val_arr[epoch]
        elif epoch > 0 and err_valid[epoch] < best_err_val:
            best_err_val = err_valid[epoch]
            best_jacc_val = jacc_val_arr[epoch]
            best_mse_val = mse_val_arr[epoch]
            patience = 0
            np.savez(os.path.join(savepath, 'dae_model_best.npz'),
                     *lasagne.layers.get_all_param_values(dae))
            np.savez(os.path.join(savepath, 'dae_errors_best.npz'),
                     err_train, err_valid, jacc_val_arr, mse_val_arr)
        else:
            patience += 1

        np.savez(os.path.join(savepath, 'dae_model_last.npz'),
                 *lasagne.layers.get_all_param_values(dae))
        np.savez(os.path.join(savepath, 'dae_errors_last.npz'),
                 err_train, err_valid, jacc_val_arr, mse_val_arr)

        # Finish training if patience has expired or max number of epochs
        # reached
        if patience == max_patience or epoch == num_epochs - 1:
            # Copy files to loadpath
            if savepath != loadpath:
                print('Copying model and other training files to {}'.format(
                    loadpath))
                copy_tree(savepath, loadpath)
            # End
            print(' Training Done !')
            return
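
# A minimal usage sketch for `train`, mirroring the defaults above. The
# paths are placeholders, and the DAE overrides, loss combination, and crop
# size are illustrative choices, not the published configuration.
def _example_run_train():
    train('camvid', 'fcn8',
          learning_rate=0.005,
          num_epochs=500,
          max_patience=100,
          training_loss=['crossentropy', 'squared_error'],
          dae_dict_updates={'kind': 'standard',
                            'concat_h': ['pool4']},      # hypothetical
          data_augmentation={'crop_size': (224, 224)},   # hypothetical
          savepath='/tmp/save',   # placeholder
          loadpath='/tmp/load')   # placeholder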