示例#1
0
文件: gbt.py 项目: liuyang0711/GBDT
def main(data_filename, stat_filename, max_iter, sample_rate, learn_rate, max_depth, split_points):
    """Train a Model on a random two-thirds split of a data set.

    The data set is loaded from ``data_filename``, the configuration is
    printed, and a tab-separated statistics file is created at
    ``stat_filename`` with a header row.  The open statistics file is handed
    to ``model.train`` together with the train/test id split.

    Args:
        data_filename: Path passed to ``DataSet``.
        stat_filename: Path of the statistics file; opened in ``"w"`` mode,
            so an existing file is truncated.
        max_iter: Forwarded to ``Model`` (printed with ``%d``).
        sample_rate: Forwarded to ``Model`` (printed with ``%f``).
        learn_rate: Forwarded to ``Model`` (printed with ``%f``).
        max_depth: Forwarded to ``Model`` (printed with ``%d``).
        split_points: Forwarded to ``Model`` (printed with ``%d``).
    """
    dataset = DataSet(data_filename)
    # A single parenthesised expression prints the same text under the
    # Python 2 print statement and the Python 3 print function.
    print("Model parameters configuration:[data_file=%s,stat_file=%s,max_iter=%d,sample_rate=%f,learn_rate=%f,max_depth=%d,split_points=%d]" % (data_filename, stat_filename, max_iter, sample_rate, learn_rate, max_depth, split_points))
    dataset.describe()
    # The context manager guarantees the statistics file is flushed and
    # closed even when model construction, sampling or training raises.
    with open(stat_filename, "w") as stat_file:
        stat_file.write("iteration\taverage loss in train data\tprediction accuracy on test data\taverage loss in test data\n")
        model = Model(max_iter, sample_rate, learn_rate, max_depth, split_points)
        # Two thirds of the instance ids (rounded down) go to training;
        # every remaining id becomes test data.
        # NOTE(review): `sample` is presumably random.sample -- confirm
        # against the file's imports.
        train_data = sample(dataset.get_instances_idset(), int(dataset.size() * 2.0 / 3.0))
        test_data = set(dataset.get_instances_idset()) - set(train_data)
        model.train(dataset, train_data, stat_file, test_data)
示例#2
0
def train(model_type, batch_size, sequence_length, frame_shape):
    """Fit an ACModel of the given type on batches produced by a DataSet.

    Args:
        model_type: Model variant name.  It is forwarded to ``ACModel``,
            embedded in the checkpoint file name and the TensorBoard log
            directory, and selects the parallel generators when it contains
            the substring ``'parallel'``.
        batch_size: Batch size handed to the generators; the training-set
            size is floor-divided by it to get the steps per epoch.
        sequence_length: Forwarded to ``DataSet``.
        frame_shape: Forwarded to ``DataSet``.

    Note:
        The epoch count is read from a name ``epochs`` that is not defined
        inside this function; it has to exist in an enclosing scope.
    """
    # NOTE(review): input_shape is a fixed literal and is not derived from
    # sequence_length / frame_shape -- confirm this is intended.
    net = ACModel(model_type, input_shape=(20, 120, 120, 3))
    data = DataSet(sequence_length, frame_shape)

    # Checkpoint files carry the epoch number and validation loss; only the
    # best model so far is kept (save_best_only).
    weights_name = model_type + '-.{epoch:03d}-{val_loss:.3f}.hdf5'
    checkpoint_cb = ModelCheckpoint(
        filepath=os.path.join('CheckPoints', weights_name),
        verbose=1,
        save_best_only=True,
    )
    tensorboard_cb = TensorBoard(
        log_dir=os.path.join('CheckPoints', 'logs', model_type),
    )

    # Parallel model types use their own generator API; all others use the
    # regular generator with the 'fn' argument.
    if 'parallel' in model_type:
        train_batches = data.parallel_generator('train', batch_size)
        val_batches = data.parallel_generator('test', batch_size)
    else:
        train_batches = data.generator('train', 'fn', batch_size)
        val_batches = data.generator('test', 'fn', batch_size)

    steps = data.size('train') // batch_size
    net.model.fit_generator(
        generator=train_batches,
        steps_per_epoch=steps,
        epochs=epochs,
        verbose=1,
        callbacks=[tensorboard_cb, checkpoint_cb],
        validation_data=val_batches,
        validation_steps=4,
        workers=1,
    )