def genModel(name, out_dim, depth, width, dense_activation="relu", dropout=0.0, sphereCoords=True):
    """Build a Keras (1.x functional API) classifier with one input per object profile.

    NOTE(review): this definition reads the module-level names
    ``object_profiles`` and ``vecsize`` (they are not parameters).  It is also
    rebound by the ``genModel`` definition that immediately follows it in this
    file, so as the file stands this version is never the one callers reach.

    For each profile in ``object_profiles`` an ``Input`` of shape
    ``(profile.max_size, vecsize)`` is created and split into two flattened
    branches:

    * when ``name == 'lorentz'`` the whole input goes through a ``Lorentz``
      layer; for any other name the ``Slice('[:,0:4]')`` sub-tensor is used
      instead (presumably the four-vector components -- confirm against the
      ``Slice`` layer's indexing semantics);
    * the ``Slice('[:,4:9]')`` sub-tensor.

    All branches are concatenated, passed through ``depth`` hidden ``Dense``
    layers of ``width`` units (each followed by ``Dropout`` when
    ``dropout > 0.0``), and finished with a softmax ``Dense`` of ``out_dim``
    units.

    Args:
        name: Model variant selector ('lorentz' or anything else) and the
            name given to the returned Keras ``Model``.
        out_dim: Number of units in the softmax output layer.
        depth: Number of hidden ``Dense`` layers.
        width: Units per hidden ``Dense`` layer.
        dense_activation: Activation for the hidden ``Dense`` layers.
        dropout: Dropout rate after each hidden layer; disabled when 0.0.
        sphereCoords: Forwarded to the ``Lorentz`` layer.

    Returns:
        A Keras ``Model`` with one input per profile and one output.
    """
    model_inputs = []
    branches = []
    for idx, profile in enumerate(object_profiles):
        suffix = str(idx)
        x = Input(shape=(profile.max_size, vecsize), name="input_" + suffix)
        model_inputs.append(x)

        # First branch: Lorentz layer for the 'lorentz' variant, plain slice otherwise.
        if name != 'lorentz':
            first = Slice('[:,0:4]', name='slice_1_' + suffix)(x)
        else:
            first = Lorentz(sphereCoords=sphereCoords, name="lorentz_" + suffix)(x)
        first = Flatten(name="flatten1_" + suffix)(first)

        # Second branch: the remaining fixed feature slice.
        rest = Slice('[:,4:9]', name='slice_2_' + suffix)(x)
        rest = Flatten(name="flatten_2_" + suffix)(rest)

        branches.extend([first, rest])

    hidden = merge(branches, mode='concat', name="merge")
    for layer_idx in range(depth):
        layer_suffix = str(layer_idx)
        hidden = Dense(width, activation=dense_activation, name="dense_" + layer_suffix)(hidden)
        if dropout > 0.0:
            hidden = Dropout(dropout, name="dropout_" + layer_suffix)(hidden)

    predictions = Dense(out_dim, activation='softmax', name='main_output')(hidden)
    return Model(input=model_inputs, output=predictions, name=name)
def genModel(name, out_dim, depth, width, vecsize, object_profiles, dense_activation="relu",
             output_activation='softmax', dropout=0.0, sphereCoords=True, weight_output=False):
    """Build a Keras (1.x functional API) classifier with one input per object profile.

    For each profile an ``Input`` of shape ``(profile.max_size, vecsize)`` is
    created and turned into two flattened branches, selected by ``name``:

    * ``'lorentz'``: whole input -> ``Lorentz`` -> ``Flatten``;
    * ``'lorentz_vsum'``: whole input -> ``Lorentz(sum_input=True)`` -> ``Flatten``;
    * ``'control_dense'``: ``Slice('[:,0:4]')`` -> ``Flatten`` -> linear
      ``Dense`` with ``4 * profile.max_size`` units;
    * ``'control'``: ``Slice('[:,0:4]')`` -> ``Flatten``.

    The second branch is the whole input flattened for ``'lorentz_vsum'``,
    and the ``Slice('[:,4:]')`` sub-tensor flattened otherwise.
    (Presumably columns 0:4 are the four-vector components; confirm against
    the ``Slice`` layer's indexing semantics.)

    All branches are concatenated and passed through ``depth`` hidden
    ``Dense`` layers of ``width`` units. Each hidden layer is followed by
    ``Dropout`` when ``dropout > 0.0``. A final ``Dense`` layer with
    ``out_dim`` units and ``output_activation`` completes the model.

    Args:
        name: Model variant (one of the four names above); also used as the
            name of the returned Keras ``Model``.
        out_dim: Number of units in the output layer.
        depth: Number of hidden ``Dense`` layers.
        width: Units per hidden ``Dense`` layer.
        vecsize: Per-object feature length (last dimension of every ``Input``).
        object_profiles: Iterable of profile objects exposing ``max_size``
            (first dimension of the corresponding ``Input``). It is consumed
            once and materialized into a list.
        dense_activation: Activation for the hidden ``Dense`` layers.
        output_activation: Activation for the output ``Dense`` layer.
        dropout: Dropout rate after each hidden layer; disabled when 0.0.
        sphereCoords: Forwarded to the ``Lorentz`` layers only.
        weight_output: Forwarded to the ``Lorentz`` layers only.

    Returns:
        A Keras ``Model`` with one input per profile and a single output.

    Raises:
        ValueError: If ``name`` is not a known variant, or if
            ``object_profiles`` is empty.
    """
    # Validate before building any layers, so that a bad name is reported
    # even when there are no profiles and no half-built graph is left behind.
    if name not in ('lorentz', 'lorentz_vsum', 'control_dense', 'control'):
        raise ValueError("Model name %r not understood." % name)

    object_profiles = list(object_profiles)
    if not object_profiles:
        # merge() cannot concatenate an empty list; fail with a clear message.
        raise ValueError("object_profiles must contain at least one profile.")

    sum_input = name == 'lorentz_vsum'

    inputs = []
    mergelist = []
    for i, profile in enumerate(object_profiles):
        suffix = str(i)
        inp = a = Input(shape=(profile.max_size, vecsize), name="input_" + suffix)
        inputs.append(inp)

        # First branch, depending on the model variant.
        if name == 'lorentz':
            b1 = Lorentz(sphereCoords=sphereCoords, weight_output=weight_output,
                         name="lorentz_" + suffix)(a)
            b1 = Flatten(name="flatten1_" + suffix)(b1)
        elif sum_input:
            b1 = Lorentz(sphereCoords=sphereCoords, weight_output=weight_output,
                         name="lorentz_" + suffix, sum_input=True)(a)
            b1 = Flatten(name="flatten1_" + suffix)(b1)
        elif name == 'control_dense':
            b1 = Slice('[:,0:4]', name='slice_1_' + suffix)(a)
            b1 = Flatten(name="4_flatten_" + suffix)(b1)
            b1 = Dense(4 * profile.max_size, activation='linear', name='4_dense_' + suffix)(b1)
        else:  # name == 'control' (guaranteed by the validation above)
            b1 = Slice('[:,0:4]', name='slice_1_' + suffix)(a)
            b1 = Flatten(name="flatten1_" + suffix)(b1)

        # Second branch: the 'vsum' variant keeps the whole input; all others
        # keep only the sub-tensor after the first four columns.
        if sum_input:
            b2 = a
        else:
            b2 = Slice('[:,4:]', name='slice_2_' + suffix)(a)
        b2 = Flatten(name="flatten_2_" + suffix)(b2)

        mergelist.append(b1)
        mergelist.append(b2)

    a = merge(mergelist, mode='concat', name="merge")
    for i in range(depth):
        a = Dense(width, activation=dense_activation, name="dense_" + str(i))(a)
        if dropout > 0.0:
            a = Dropout(dropout, name="dropout_" + str(i))(a)

    dense_out = Dense(out_dim, activation=output_activation, name='main_output')(a)
    model = Model(input=inputs, output=dense_out, name=name)
    return model