def generator(dataset_path, is_linear, shape, batch_size=500):
    """Replace each split's ``'data'`` with the output of a trained model's generator.

    Loads the dataset with ``load_dict_data`` and the latest checkpoint with
    ``CheckPoint().load_model()``, then builds the model's generator via
    ``model.build_generator(-2)``. It runs every sample of both splits through
    ``model.generate`` in batches and prints the batched ``model.testing``
    score for each split plus the generated array shapes.

    Args:
        dataset_path, is_linear, shape: forwarded unchanged to ``load_dict_data``.
        batch_size: number of samples passed to the model per call (default 500,
            the previously hard-coded value).

    Returns:
        (train, test): the dicts returned by ``load_dict_data``, mutated in place
        so that ``'data'`` holds the generated arrays. ``'labels'`` is untouched.

    Raises:
        IOError: if ``load_model`` reports version -1 (no trained model).
        ValueError: if a split contains no samples.
    """
    train, test = load_dict_data(dataset_path, is_linear, shape)
    sys.stdout.write('building model ...')
    cp = CheckPoint()
    v, model = cp.load_model()
    print(v)
    if v == -1:
        raise IOError('model has not been trained.')
    # NOTE(review): -2 presumably selects the layer whose output is generated
    # (e.g. the one before the classifier) -- confirm against the model class.
    model.build_generator(-2)

    def _batch_bounds(tt, verb, comment):
        """Yield (start, end) slice bounds covering every sample of ``tt['data']``.

        Writes a progress line per batch and a newline when done. The final
        batch may be shorter than ``batch_size``. The original floor-division
        loop dropped those trailing samples, which left ``'data'`` shorter than
        ``'labels'``.
        NOTE(review): if the compiled model requires exactly ``batch_size`` rows
        per call, the short final batch will fail -- confirm.
        """
        sample_num = tt['data'].shape[0]
        if sample_num == 0:
            raise ValueError('no samples to %s for %s' % (verb, comment))
        for i, start_i in enumerate(range(0, sample_num, batch_size)):
            sys.stdout.write('\r%s %s of batch %d' % (verb, comment, i))
            sys.stdout.flush()
            yield start_i, min(start_i + batch_size, sample_num)
        print('')

    def batch_generator(tt, comment):
        """Stack ``model.generate`` outputs for all samples of ``tt`` row-wise."""
        result = [model.generate(tt['data'][start_i:end_i])
                  for start_i, end_i in _batch_bounds(tt, 'generate', comment)]
        return np.vstack(result)

    def batch_test(tt, comment):
        """Average ``model.testing`` over all batches of ``tt``, weighted by batch length.

        Weighting only matters for a short final batch. With equal batches it
        equals the plain mean. Assumes ``model.testing`` returns a per-sample
        average -- TODO confirm.
        """
        scores = []
        weights = []
        for start_i, end_i in _batch_bounds(tt, 'testing', comment):
            scores.append(model.testing(tt['data'][start_i:end_i],
                                        tt['labels'][start_i:end_i], -1))
            weights.append(end_i - start_i)
        return np.average(scores, weights=weights)

    result_X_train = batch_generator(train, 'train')
    result_X_test = batch_generator(test, 'test')
    print(batch_test(train, 'train'))
    print(batch_test(test, 'test'))
    print(result_X_train.shape)
    print(result_X_test.shape)
    train['data'] = result_X_train
    test['data'] = result_X_test
    return train, test
def build_model():
    """Load the latest trained model and compile a class-prediction function.

    The prediction is ``argmax`` over axis 1 of the last layer's output.

    Returns:
        (predict, needs_is_train):
        - For a ``DropoutModel``: a theano function taking ``(x, is_train)``, and True.
        - For a ``ConvModel``: a theano function taking ``(x,)``, and False.
        - For any other model type: ``(None, True)``.

    Raises:
        IOError: if ``load_model`` reports version -1 (no trained model).
    """
    sys.stdout.write('building model ...')
    version, net = CheckPoint().load_model()
    print(version)
    if version == -1:
        raise IOError('model has not been trained.')

    class_scores = net.layers[-1].get_output()
    predicted_label = T.argmax(class_scores, axis=1)

    if isinstance(net, DropoutModel):
        # Dropout models expose an extra `is_train` input (presumably a
        # dropout on/off switch), so the caller must supply it too.
        fn_inputs, needs_is_train = [net.x, net.is_train], True
    elif isinstance(net, ConvModel):
        fn_inputs, needs_is_train = [net.x], False
    else:
        return None, True

    return theano.function(inputs=fn_inputs, outputs=predicted_label), needs_is_train