Exemplo n.º 1
	def load(self,
		ydim, n_words,
		dim_proj = 128,
		encoder = 'lstm',
		use_dropout = True,
		fname_model = FNAME_MODEL,
		):
		'''Rebuild the LSTM prediction graph and load trained weights into it.

		ydim, n_words, dim_proj, encoder, use_dropout -- model options handed to
			lstmtool.init_params / lstmtool.build_model; presumably they must
			match the values used at training time (TODO confirm).
		fname_model -- file passed to lstmtool.load_params.

		On return self.f_pred and self.f_pred_prob hold the functions returned
		by lstmtool.build_model.
		'''
		# Must stay the first statement: at this point locals() holds exactly
		# self and the arguments, which become the option dict for lstmtool.
		model_options = locals().copy()
		params = lstmtool.init_params(model_options)
		lstmtool.load_params(fname_model, params)
		tparams = lstmtool.init_tparams(params)

		use_noise, x, mask, y, f_pred_prob, f_pred, cost = lstmtool.build_model(tparams, model_options)

		self.f_pred = f_pred
		self.f_pred_prob = f_pred_prob
Exemplo n.º 2
Arquivo: lstm.py Projeto: liangxh/idu
	def load(self,
		encoder = 'lstm',
		# NOTE(review): default assumed to be the FNAME_MODEL module constant -- confirm.
		fname_model = FNAME_MODEL,
		):
		'''Rebuild the LSTM prediction graph from a saved config plus weights.

		Architecture hyper-parameters (e.g. ydim, n_words, dim_proj,
		use_dropout) are not arguments here: they are read from the pickle
		'<fname_model>.pkl' and used for every option not passed explicitly.
		The weights themselves are loaded from fname_model through
		lstmtool.load_params.

		On return self.f_pred and self.f_pred_prob hold the functions returned
		by lstmtool.build_model.
		'''
		# Must stay the first statement: at this point locals() holds exactly
		# self and the arguments.
		model_options = locals().copy()

		# Binary mode: the config is expected to be written with pickle protocol
		# -1 (the highest, binary protocol). Unpickling can execute arbitrary
		# code, so only load files from a trusted source.
		with open('%s.pkl'%(fname_model), 'rb') as fobj:
			train_params = cPickle.load(fobj)
		# Saved values fill in the missing options; explicit arguments win.
		for key, value in train_params.items():
			model_options.setdefault(key, value)

		params = lstmtool.init_params(model_options, None)
		lstmtool.load_params(fname_model, params)
		tparams = lstmtool.init_tparams(params)

		use_noise, x, mask, y, f_pred_prob, f_pred, cost = lstmtool.build_model(tparams, model_options)

		self.f_pred = f_pred
		self.f_pred_prob = f_pred_prob
Exemplo n.º 3
Arquivo: lstm.py Projeto: liangxh/idu
	def train(self,
		dataset, Wemb, ydim,
		# model params
		use_dropout = True,
		reload_model = False,
		fname_model = None,
		# training params
		validFreq = 1000,
		saveFreq = 1000,
		patience = 10,
		max_epochs = 5000,
		decay_c = 0.,
		lrate = 0.0001,
		batch_size = 16,
		valid_batch_size = 64,
		optimizer = lstmtool.adadelta,
		noise_std = 0.,
		# debug params
		dispFreq = 10,
		):
		'''Train the LSTM classifier on `dataset` and keep the best weights.

		dataset -- (train, valid, test); each split is indexed as split[0]
			(inputs) and split[1] (labels).
		Wemb -- word-embedding matrix; Wemb.shape[1] is used as dim_proj and
			it is passed to lstmtool.init_params. Required: None is rejected.
		ydim -- forwarded to lstmtool via the model options (presumably the
			number of output classes -- confirm).
		reload_model -- if True, start from the parameters stored in
			fname_model; returns None when that file does not exist.
		fname_model -- target of np.savez checkpoints; the model config is
			pickled to '<fname_model>.pkl'. If None, nothing is written.
		validFreq, saveFreq, dispFreq -- period, in updates, of validation,
			checkpointing and progress logging; None for validFreq/saveFreq
			means roughly once per epoch.
		patience -- early-stopping patience, counted in validation rounds.
		max_epochs, batch_size, valid_batch_size -- epoch limit and minibatch
			sizes for training and for evaluation.
		decay_c -- if > 0, adds decay_c * sum(U ** 2) to the cost.
		lrate -- learning rate passed to the optimizer's update function.
		optimizer -- callable(lr, tparams, grads, x, mask, y, cost) returning
			(f_grad_shared, f_update).
		use_dropout, noise_std -- only forwarded to lstmtool via the options.

		Returns (train_err, valid_err, test_err, elapsed_seconds) on normal
		completion, (1., 1., 1.) if the cost becomes NaN/Inf, and None when
		setup fails (missing Wemb or missing model file on reload).
		Also sets self.f_pred, self.f_pred_prob and self.tparams.
		'''
		train, valid, test = dataset

		# building model
		logger.info('building model...')

		# Wemb is needed below (dim_proj, init_params) even when reloading, so
		# it has to be checked before its first use.
		if Wemb is None:
			logger.warning('Wemb is missing for training LSTM')
			return None

		dim_proj = Wemb.shape[1] # numpy.ndarray expected

		# Snapshot of self, the arguments, the three splits and dim_proj; this
		# dict is the option set handed to lstmtool.
		model_options = locals().copy()
		model_options['dim_proj'] = dim_proj
		model_options['encoder'] = 'lstm'

		if fname_model:
			# NOTE(review): keys chosen to cover what a loader needs to rebuild
			# the network without Wemb; n_words assumes Wemb has one row per
			# vocabulary entry -- confirm against the loading code.
			model_config = {
				'ydim': ydim,
				'n_words': Wemb.shape[0],
				'dim_proj': dim_proj,
				'use_dropout': use_dropout,
				'encoder': model_options['encoder'],
			}
			# -1 selects the highest pickle protocol (binary), hence mode 'wb'.
			with open('%s.pkl'%(fname_model), 'wb') as fobj:
				cPickle.dump(model_config, fobj, -1)

		params = lstmtool.init_params(model_options, Wemb)

		if reload_model:
			if fname_model and os.path.exists(fname_model):
				lstmtool.load_params(fname_model, params)
			else:
				logger.warning('model %s not found'%(fname_model))
				return None

		tparams = lstmtool.init_tparams(params)
		use_noise, x, mask, y, f_pred_prob, f_pred, cost = lstmtool.build_model(tparams, model_options)

		# preparing functions for training
		logger.info('preparing functions')

		if decay_c > 0.:
			# L2 penalty on tparams['U'] only
			decay_c = theano.shared(lstmtool.numpy_floatX(decay_c), name='decay_c')
			cost += (tparams['U'] ** 2).sum() * decay_c

		# list() keeps the tparams order, which the optimizer pairs with grads
		grads = theano.tensor.grad(cost, wrt = list(tparams.values()))

		lr = theano.tensor.scalar(name = 'lr')
		f_grad_shared, f_update = optimizer(lr, tparams, grads, x, mask, y, cost)

		kf_valid = lstmtool.get_minibatches_idx(len(valid[0]), valid_batch_size)
		kf_test = lstmtool.get_minibatches_idx(len(test[0]), valid_batch_size)

		# None means "once per epoch"; clamp to 1 so np.mod never divides by 0
		batches_per_epoch = max(1, len(train[0]) // batch_size)
		if validFreq is None:
			validFreq = batches_per_epoch
		if saveFreq is None:
			saveFreq = batches_per_epoch

		history_errs = []  # one [valid_err, test_err] pair per validation
		best_p = None
		bad_count = 0

		uidx = 0       # number of updates done
		eidx = -1      # current epoch index; -1 until the first epoch starts
		estop = False  # early stop

		# training
		logger.info('start training...')

		start_time = time.time()

		try:
			for eidx in xrange(max_epochs):
				n_samples = 0
				kf = lstmtool.get_minibatches_idx(len(train[0]), batch_size, shuffle = True)
				for _, train_index in kf:
					uidx += 1
					# NOTE(review): assumes use_noise is lstmtool's shared dropout
					# switch (1 = training, 0 = evaluation), as in the Theano LSTM
					# tutorial this code follows -- confirm.
					use_noise.set_value(1.)

					x_batch = [train[0][t] for t in train_index]
					y_batch = [train[1][t] for t in train_index]
					x_batch, mask_batch = self.prepare_x(x_batch)
					n_samples += x_batch.shape[1]

					# compute cost and gradients, then apply the parameter update
					batch_cost = f_grad_shared(x_batch, mask_batch, y_batch)
					f_update(lrate)

					if not np.isfinite(batch_cost):
						# NaN or Inf encountered
						# NOTE(review): returns a 3-tuple, unlike the 4-tuple below.
						logger.warning('NaN detected')
						return 1., 1., 1.

					if np.mod(uidx, dispFreq) == 0:
						# display progress at $dispFreq
						logger.info('Epoch %d Update %d Cost %f'%(eidx, uidx, batch_cost))

					if fname_model and np.mod(uidx, saveFreq) == 0:
						# save the best parameters so far (else the current ones)
						logger.info('Model update')
						if best_p is not None:
							params = best_p
						else:
							params = lstmtool.unzip(tparams)
						np.savez(fname_model, history_errs = history_errs, **params)

					if np.mod(uidx, validFreq) == 0:
						# check prediction error at $validFreq
						use_noise.set_value(0.)
						logger.info('Validation ....')

						# train error is informative only (not used for selection)
						train_err = lstmtool.pred_error(f_pred, self.prepare_data, train, kf)
						valid_err = lstmtool.pred_error(f_pred, self.prepare_data, valid, kf_valid)
						test_err = lstmtool.pred_error(f_pred, self.prepare_data, test, kf_test)

						history_errs.append([valid_err, test_err])
						if valid_err <= np.array(history_errs)[:, 0].min():
							best_p = lstmtool.unzip(tparams)
							bad_count = 0
						logger.info('prediction error: train %f valid %f test %f'%(
								train_err, valid_err, test_err))

						# no improvement over the validations older than the last
						# `patience` ones counts as a bad round
						if (len(history_errs) > patience and
							valid_err >= np.array(history_errs)[:-patience, 0].min()):
							bad_count += 1
							if bad_count > patience:
								logger.info('Early stop!')
								estop = True
								break

				logger.info('%d samples seen'%(n_samples))
				if estop:
					break
		except KeyboardInterrupt:
			logger.info('training interrupted by user')

		end_time = time.time()

		# restore the best validated parameters into the shared variables
		if best_p is not None:
			lstmtool.zipp(best_p, tparams)
		else:
			best_p = lstmtool.unzip(tparams)

		use_noise.set_value(0.)
		kf_train = lstmtool.get_minibatches_idx(len(train[0]), batch_size)
		train_err = lstmtool.pred_error(f_pred, self.prepare_data, train, kf_train)
		valid_err = lstmtool.pred_error(f_pred, self.prepare_data, valid, kf_valid)
		test_err = lstmtool.pred_error(f_pred, self.prepare_data, test, kf_test)
		logger.info('prediction error: train %f valid %f test %f'%(
				train_err, valid_err, test_err))

		if fname_model:
			# the key is 'test_error' (unlike train_err / valid_err); left as is
			# so readers of previously saved files keep working
			np.savez(fname_model,
				train_err = train_err,
				valid_err = valid_err,
				test_error = test_err,
				history_errs = history_errs, **best_p)

		logger.info('totally %d epoches in %.1f sec'%(eidx + 1, end_time - start_time))

		self.f_pred_prob = f_pred_prob
		self.f_pred = f_pred
		self.tparams = tparams

		return train_err, valid_err, test_err, end_time - start_time