def retrieve_params(mlp_full, mlp_sub, dropout, exo_dropout=0.):
    """Copy a scaled slice of the full model's parameters into the sub model.

    For every layer, the weight and bias of ``mlp_full`` are indexed with
    the splits returned by ``dropout_indices`` and the result, multiplied by
    ``1 / (1 - dropout)``, is written into the matching parameter of
    ``mlp_sub`` through ``set_value``.

    Parameters
    ----------
    mlp_full
        Model exposing ``get_Params``; source of the parameter values.
    mlp_sub
        Model receiving the values.  When it is a ``CompositeSequence`` its
        parameters come from ``get_Params()``; otherwise they are collected
        from the computation graph of ``mlp_sub.apply`` and renamed
        ``layer_<i>_W`` / ``layer_<i>_B``.
    dropout
        Dropout rate, must satisfy ``0 <= dropout < 1``.
    exo_dropout
        Must satisfy ``0 <= exo_dropout < 1``; it is only validated here and
        not used otherwise.

    Returns
    -------
    None
        The function works by side effect on the parameters of ``mlp_sub``
        (and sorts the parameter lists it obtained in place).

    Raises
    ------
    AssertionError
        If ``dropout`` or ``exo_dropout`` is outside ``[0, 1)``.
    Exception
        If a full/sub parameter pair has different names, or if
        ``dropout_indices`` reports a kind other than ``FULLY_CONNECTED``,
        ``CONV_FILTER`` or ``BATCH_NORM``.
    """
    assert exo_dropout >= 0 and exo_dropout < 1
    assert dropout >= 0 and dropout < 1

    def layer_index(param):
        # Second field returned by analyze_param_name is used as the
        # layer number everywhere in this function.
        return analyze_param_name(param.name)[1]

    # get the list of parameters
    if isinstance(mlp_sub, CompositeSequence):
        L_W_full, L_B_full = mlp_full.get_Params()
        L_W_sub, L_B_sub = mlp_sub.get_Params()
    else:
        L_W_full, L_B_full = mlp_full.get_Params(option="")
        # Build a throw-away graph just to enumerate the parameters of
        # mlp_sub by role.
        x = T.matrix()
        cost = mlp_sub.apply(x).sum()
        cg = ComputationGraph(cost)
        W = VariableFilter(roles=[WEIGHT])(cg.variables)
        B = VariableFilter(roles=[BIAS])(cg.variables)
        # If the first weight's second dimension does not chain with the
        # second weight's first dimension, flip both lists (presumably the
        # filter returned the layers last-to-first -- TODO confirm).
        # A single-weight network has nothing to compare, so the check is
        # skipped instead of failing on W[1].
        if len(W) > 1 and W[0].shape.eval()[1] != W[1].shape.eval()[0]:
            W.reverse()
            B.reverse()
        for index, (w, b) in enumerate(zip(W, B)):
            w.name = "layer_" + str(index) + "_W"
            b.name = "layer_" + str(index) + "_B"
        L_W_sub = W
        L_B_sub = B

    # sort the lists here so that full and sub parameters line up by layer
    L_W_full.sort(key=layer_index)
    L_W_sub.sort(key=layer_index)
    L_B_full.sort(key=layer_index)
    L_B_sub.sort(key=layer_index)

    # get the splits
    splits_W, splits_B, rough_kinds = dropout_indices(L_W_full, L_B_full,
                                                      dropout)

    # WHEN APPLYING DROPOUT ONE MUST APPLY SCALING !
    # dropout < 1 is asserted above, so the denominator is non-zero.
    proba = 1. / (1 - dropout)

    for w_full, w_sub, b_full, b_sub in zip(L_W_full, L_W_sub,
                                            L_B_full, L_B_sub):
        if w_sub.name != w_full.name:
            raise Exception("names %s and %s do not match"
                            % (w_sub.name, w_full.name))
        if b_sub.name != b_full.name:
            raise Exception("names %s and %s do not match"
                            % (b_sub.name, b_full.name))

        # get the indices : the split dictionaries are keyed by
        # "<layer_name><layer_number>"
        (layer_name, layer_number, _) = analyze_param_name(w_full.name)
        layer_name = layer_name + str(layer_number)
        w_split = splits_W[layer_name]
        b_split = splits_B[layer_name]
        rough_kind = rough_kinds[layer_name]

        # Only the weight slicing depends on the kind of layer; the bias is
        # handled identically in every case.
        if rough_kind == "FULLY_CONNECTED":
            w_value = w_full.get_value()[w_split[0]][:, w_split[1]]
        elif rough_kind == "CONV_FILTER":
            # Note the swapped split order compared to FULLY_CONNECTED.
            w_value = w_full.get_value()[w_split[1]][:, w_split[0]]
        elif rough_kind == "BATCH_NORM":
            w_value = w_full.get_value()[w_split[0]]
        else:
            raise Exception("unknown type :%s" % (rough_kind,))

        w_sub.set_value(w_value * proba)
        b_sub.set_value(b_full.get_value()[b_split] * proba)