from grid_control.gc_exceptions import RuntimeError
from grid_control.utils.webservice import readJSON
from grid_control_cms.provider_sitedb import SiteDB

def lfn2pfn(node, lfn, protocol = 'srmv2'):
	"""Resolve a logical file name (LFN) into a physical file name (PFN).

	Queries the PhEDEx data service 'lfn2pfn' endpoint and returns the 'pfn'
	entry of the first mapping contained in the reply.

	node     -- value sent as the 'node' query parameter
	lfn      -- value sent as the 'lfn' query parameter
	protocol -- value sent as the 'protocol' query parameter; defaults to
	            'srmv2', which was the previously hard-coded value

	Raises IndexError if the reply contains no mapping at all.
	"""
	query = {'node': node, 'protocol': protocol, 'lfn': lfn}
	reply = readJSON('https://cmsweb.cern.ch/phedex/datasvc/json/prod/lfn2pfn', query)
	# Only the first mapping is of interest to callers of this helper
	return reply['phedex']['mapping'][0]['pfn']


parser = optparse.OptionParser()
parser.add_option('-s', '--SE', dest='SE', default=None, help='Resolve LFN on CMS SE into PFN')
parser.add_option('', '--lfn', dest='lfn', default='/store/user/<hypernews name>', help='Name of default LFN')
parser.add_option('', '--se-prot', dest='seprot', default='srmv2', help='Name of default SE protocol')
(opts, args) = parseOptions(parser)

if opts.SE:
	if '<hypernews name>' in opts.lfn:
		token = AccessToken.getInstance('VomsProxy', getConfig(), None)
		site_db = SiteDB()
		hnName = site_db.dn_to_username(dn=token.getFQUsername())
		if not hnName:
			raise RuntimeError('Unable to map grid certificate to hypernews name!')
		opts.lfn = opts.lfn.replace('<hypernews name>', hnName)

	tmp = readJSON('https://cmsweb.cern.ch/phedex/datasvc/json/prod/lfn2pfn',
		{'node': opts.SE, 'protocol': opts.seprot, 'lfn': opts.lfn})['phedex']['mapping']
	for entry in tmp:
		if len(tmp) > 1:
			print entry['node'],
		print entry['pfn']
def realmain(opts, args):
	"""Download the storage element (SE) output files of successfully finished jobs.

	opts -- parsed command line options (token, output, threads, verify, retry,
	        the various mark* / rm* / select* switches, slowdown, shuffle, ...)
	args -- positional arguments, handed to gcSupport.initGC

	Every job selected by JobClass.SUCCESS is processed by 'processSingleJob',
	either sequentially or by up to opts.threads worker threads. A status
	overview is written to stdout at the end.

	Returns os.EX_OK if the number of jobs that were found to be already
	downloaded equals len(jobDB), otherwise os.EX_NOINPUT.
	"""
	config = gcSupport.getConfig(configDict = {'access': {'ignore warnings': 'True'}})
	token = AccessToken.getInstance(opts.token, config, 'access', OSLayer.create(config))
	(workDir, config, jobDB) = gcSupport.initGC(args)
	jobList = jobDB.getJobs(ClassSelector(JobClass.SUCCESS))

	# Create SE output dir
	if not opts.output:
		opts.output = os.path.join(workDir, 'se_output')
	if '://' not in opts.output:
		opts.output = 'file:///%s' % os.path.abspath(opts.output)

	# Counters for the final status overview: category name -> number of jobs
	infos = {}
	def incInfo(x):
		infos[x] = infos.get(x, 0) + 1

	def processSingleJob(jobNum, output):
		"""Download, optionally verify and clean up all output files of a single job.

		jobNum -- number of the job in jobDB
		output -- display object providing init / files / file / hash / error /
		          status / finish callbacks for progress reporting
		"""
		output.init(jobNum)
		job = jobDB.get(jobNum)
		# Only run over finished and not yet downloaded jobs
		if job.state != Job.SUCCESS:
			output.error('Job has not yet finished successfully!')
			return incInfo('Processing')
		if job.get('download') == 'True' and not opts.markIgnoreDL:
			if not opts.threads:
				output.error('All files already downloaded!')
			return incInfo('Downloaded')
		retry = int(job.get('download attempt', 0))
		failJob = False

		if not token.canSubmit(20*60, True):
			sys.stderr.write('Please renew access token!')
			sys.exit(os.EX_UNAVAILABLE)

		# Read the file hash entries from job info file
		fileInfoList = FileInfoProcessor().process(os.path.join(workDir, 'output', 'job_%d' % jobNum))
		# A falsy result (e.g. None) is normalised to an empty list, since the
		# display callbacks and the loops below need something with a length
		files = [(fi[FileInfoProcessor.Hash], fi[FileInfoProcessor.NameLocal],
			fi[FileInfoProcessor.NameDest], fi[FileInfoProcessor.Path]) for fi in (fileInfoList or [])]
		output.files(files)
		if not files:
			if opts.markEmptyFailed:
				failJob = True
			else:
				return incInfo('Job without output files')

		for (fileIdx, fileInfo) in enumerate(files):
			(hashRemote, name_local, name_dest, pathSE) = fileInfo
			output.file(fileIdx)

			# Copy files to local folder
			outFilePath = os.path.join(opts.output, name_dest)
			# NOTE(review): both skip conditions below leave the whole job (not just
			# this file) without committing or counting it - confirm that this is intended
			if opts.selectSE:
				if not any(s in pathSE for s in opts.selectSE):
					output.error('skip file because it is not located on selected SE!')
					return
			# se_exists hands back a process handle - its exit code (0 is treated as
			# "exists" by the mkdir check below) is only available through wait()
			if opts.skipExisting and (storage.se_exists(outFilePath).wait() == 0):
				output.error('skip file as it already exists!')
				return
			if storage.se_exists(os.path.dirname(outFilePath)).wait() != 0:
				storage.se_mkdir(os.path.dirname(outFilePath)).wait()

			# Local file that is watched during the transfer (and hashed afterwards):
			# the destination itself if it is local, otherwise a temporary file
			checkPath = 'file:///tmp/dlfs.%s' % name_dest
			if 'file://' in outFilePath:
				checkPath = outFilePath

			def monitorFile(path, lock, abort):
				# Runs in its own thread: reports the size of the growing file until
				# 'lock' becomes available and requests an abort by taking 'abort'
				# when the size did not change for 5 minutes
				path = path.replace('file://', '')
				(csize, osize, stime, otime, lttime) = (0, 0, time.time(), time.time(), time.time())
				while not lock.acquire(False): # Loop until monitor lock is available
					if csize != osize:
						lttime = time.time()
					if time.time() - lttime > 5*60: # No size change in the last 5min!
						output.error('Transfer timeout!')
						abort.acquire()
						break
					if os.path.exists(path):
						csize = os.path.getsize(path)
						output.file(fileIdx, csize, osize, stime, otime)
						(osize, otime) = (csize, time.time())
					else:
						stime = time.time()
					time.sleep(0.1)
				lock.release()

			copyAbortLock = threading.Lock()
			monitorLock = threading.Lock()
			monitorLock.acquire() # held until the copy process has finished
			monitor = utils.gcStartThread('Download monitor %s' % jobNum,
				monitorFile, checkPath, monitorLock, copyAbortLock)
			result = -1
			procCP = storage.se_copy(os.path.join(pathSE, name_dest), outFilePath, tmp = checkPath)
			while True:
				# The abort lock being taken means that the monitor signalled a timeout
				if not copyAbortLock.acquire(False):
					monitor.join()
					break
				copyAbortLock.release()
				result = procCP.poll()
				if result != -1:
					# Copy process finished - allow the monitor thread to terminate
					monitorLock.release()
					monitor.join()
					break
				time.sleep(0.02)

			if result != 0:
				output.error('Unable to copy file from SE!')
				output.error(procCP.getMessage())
				failJob = True
				break

			# Verify => compute md5hash
			if opts.verify:
				try:
					hashLocal = md5sum(checkPath.replace('file://', ''))
					if not ('file://' in outFilePath):
						# checkPath already is a file:// URL in this case
						dlfs_rm(checkPath, 'SE file')
				except KeyboardInterrupt:
					raise
				except Exception:
					# Failure to compute the hash is reported as a hash mismatch
					hashLocal = None
				output.hash(fileIdx, hashLocal)
				if hashRemote != hashLocal:
					failJob = True
			else:
				output.hash(fileIdx)

		# Ignore the first opts.retry number of failed jobs
		if failJob and opts.retry and (retry < opts.retry):
			output.error('Download attempt #%d failed!' % (retry + 1))
			job.set('download attempt', str(retry + 1))
			jobDB.commit(jobNum, job)
			return incInfo('Download attempts')

		for (fileIdx, fileInfo) in enumerate(files):
			(hashRemote, name_local, name_dest, pathSE) = fileInfo
			# Remove downloaded files in case of failure
			if (failJob and opts.rmLocalFail) or (not failJob and opts.rmLocalOK):
				output.status(fileIdx, 'Deleting file %s from local...' % name_dest)
				outFilePath = os.path.join(opts.output, name_dest)
				if storage.se_exists(outFilePath).wait() == 0:
					dlfs_rm(outFilePath, 'local file')
			# Remove SE files in case of failure
			if (failJob and opts.rmSEFail) or (not failJob and opts.rmSEOK):
				output.status(fileIdx, 'Deleting file %s...' % name_dest)
				dlfs_rm(os.path.join(pathSE, name_dest), 'SE file')
			output.status(fileIdx, None)

		if failJob:
			incInfo('Failed downloads')
			if opts.markFailed:
				# Mark job as failed to trigger resubmission
				job.state = Job.FAILED
		else:
			incInfo('Successful download')
			if opts.markDL:
				# Mark as downloaded
				job.set('download', 'True')

		# Save new job status infos
		jobDB.commit(jobNum, job)
		output.finish()
		time.sleep(float(opts.slowdown))

	if opts.shuffle:
		random.shuffle(jobList)
	else:
		jobList.sort()

	if opts.threads:
		from grid_control_gui import ansi
		# Error messages of all threads, shown below the per-job output
		errorOutput = []
		class ThreadDisplay:
			# Collects the output lines of one job (two lines per file), which
			# are redrawn by the main loop below. One instance is used per job:
			# calling files() replaces the method by the file list on the instance.
			def __init__(self):
				self.output = []
			def init(self, jobNum):
				self.jobNum = jobNum
				self.output = ['Job %5d' % jobNum, '']
			def infoline(self, fileIdx, msg = ''):
				return 'Job %5d [%i/%i] %s %s' % (self.jobNum, fileIdx + 1, len(self.files), self.files[fileIdx][2], msg)
			def files(self, files):
				(self.files, self.output, self.tr) = (files, self.output[1:], ['']*len(files))
				for x in range(len(files)):
					self.output.insert(2*x, self.infoline(x))
					self.output.insert(2*x+1, '')
			def file(self, idx, csize = None, osize = None, stime = None, otime = None):
				if otime:
					# Transfer rate relative to a reference size and reference time
					trfun = lambda sref, tref: gcSupport.prettySize(((csize - sref) / max(1, time.time() - tref)))
					self.tr[idx] = '%7s avg. - %7s/s inst.' % (gcSupport.prettySize(csize), trfun(0, stime))
					self.output[2*idx] = self.infoline(idx, '(%s - %7s/s)' % (self.tr[idx], trfun(osize, otime)))
			def hash(self, idx, hashLocal = None):
				(hashRemote, name_local, name_dest, pathSE) = self.files[idx]
				if hashLocal:
					if hashRemote == hashLocal:
						result = ansi.Console.fmt('MATCH', [ansi.Console.COLOR_GREEN])
					else:
						result = ansi.Console.fmt('FAIL', [ansi.Console.COLOR_RED])
					msg = '(R:%s L:%s) => %s' % (hashRemote, hashLocal, result)
				else:
					msg = ''
				self.output[2*idx] = self.infoline(idx, '(%s)' % self.tr[idx])
				self.output[2*idx+1] = msg
			def error(self, msg):
				errorOutput.append(msg)
			def write(self, msg):
				self.output.append(msg)
			def status(self, idx, msg):
				if msg:
					self.output[2*idx] = self.infoline(idx, '(%s)' % self.tr[idx]) + ' ' + msg
				else:
					self.output[2*idx] = self.infoline(idx, '(%s)' % self.tr[idx])
			def finish(self):
				pass

		# Jobs are taken from the end of 'todo', hence the reversal
		(active, todo) = ([], list(jobList))
		todo.reverse()
		screen = ansi.Console()
		screen.move(0, 0)
		screen.savePos()
		while True:
			screen.erase()
			screen.loadPos()
			# Drop finished threads and start new ones up to the thread limit
			active = [(t, d) for (t, d) in active if t.isAlive()]
			while len(active) < opts.threads and len(todo):
				display = ThreadDisplay()
				nextJob = todo.pop()
				active.append((utils.gcStartThread('Download %s' % nextJob,
					processSingleJob, nextJob, display), display))
			for (t, d) in active:
				sys.stdout.write(str.join('\n', d.output))
			sys.stdout.write(str.join('\n', ['=' * 50] + errorOutput))
			sys.stdout.flush()
			if len(active) == 0:
				break
			time.sleep(0.01)
	else:
		class DefaultDisplay:
			# Writes the progress of a single job directly to stdout. One instance
			# is used per job: calling files() replaces the method by the file list
			# on the instance.
			def init(self, jobNum):
				# Remember the job number for error messages
				self.jobNum = jobNum
				sys.stdout.write('Job %d: ' % jobNum)
			def files(self, files):
				self.files = files
				sys.stdout.write('The job wrote %d file%s to the SE\n' % (len(files), ('s', '')[len(files) == 1]))
			def file(self, idx, csize = None, osize = None, stime = None, otime = None):
				(hashRemote, name_local, name_dest, pathSE) = self.files[idx]
				if otime:
					# Transfer rate relative to a reference size and reference time
					tr = lambda sref, tref: gcSupport.prettySize(((csize - sref) / max(1, time.time() - tref)))
					tmp = name_dest
					if opts.showHost:
						tmp += ' [%s]' % pathSE.split('//')[-1].split('/')[0].split(':')[0]
					self.write('\r\t%s (%7s - %7s/s avg. - %7s/s inst.)' % (tmp,
						gcSupport.prettySize(csize), tr(0, stime), tr(osize, otime)))
					sys.stdout.flush()
				else:
					self.write('\t%s' % name_dest)
					sys.stdout.flush()
			def hash(self, idx, hashLocal = None):
				(hashRemote, name_local, name_dest, pathSE) = self.files[idx]
				self.write(' => %s\n' % ('\33[0;91mFAIL\33[0m', '\33[0;92mMATCH\33[0m')[hashRemote == hashLocal])
				self.write('\t\tRemote site: %s\n' % hashRemote)
				self.write('\t\t Local site: %s\n' % hashLocal)
			def error(self, msg):
				sys.stdout.write('\nJob %d: %s' % (self.jobNum, msg.strip()))
			def status(self, idx, msg):
				if msg:
					self.write('\t' + msg + '\r')
				else:
					# Blank out the previous status message
					self.write(' ' * len('\tDeleting file %s from SE...\r' % self.files[idx][2]) + '\r')
			def write(self, msg):
				sys.stdout.write(msg)
			def finish(self):
				sys.stdout.write('\n')

		for jobNum in jobList:
			processSingleJob(jobNum, DefaultDisplay())

	# Print overview
	if infos:
		sys.stdout.write('\nStatus overview:\n')
		for (state, num) in infos.items():
			if num > 0:
				sys.stdout.write('\t%20s: [%d/%d]\n' % (state, num, len(jobList)))
		sys.stdout.write('\n')

	# NOTE(review): the counter is compared against len(jobDB), while only the jobs
	# in jobList (JobClass.SUCCESS selection) are processed - confirm this is intended
	if ('Downloaded' in infos) and (infos['Downloaded'] == len(jobDB)):
		return os.EX_OK
	return os.EX_NOINPUT