예제 #1
0
def find_checkpoint(checkpoint_str):
    """Resolve a checkpoint specifier to a ``(checkpoint, steps)`` pair.

    The specifier is either a bare log prefix (``"prefix"``), meaning "use the
    last checkpoint listed for that run", or ``"prefix:steps"``, meaning "use
    the checkpoint whose step label equals ``steps``".

    Relies on the module-level names ``path``, ``args`` and ``sys``.  Unless
    ``args.clear_steps`` is set, the ``tag`` attribute of ``args`` is
    overwritten with the value stored in the last experiment-info entry of
    the run's log, and the restored value is printed.

    Args:
        checkpoint_str: ``None``, ``"prefix"`` or ``"prefix:steps"``.

    Returns:
        ``(checkpoint, steps)``.  ``(None, 0)`` when ``checkpoint_str`` is
        ``None``; otherwise the matching checkpoint entry and its step count
        as an ``int`` (``0`` when ``args.clear_steps`` is set).

    Raises:
        StopIteration: if ``steps`` was given but no checkpoint carries it.
    """
    # Defaults for the "no checkpoint requested" case.  Without binding
    # `checkpoint` here the final return raised UnboundLocalError.
    checkpoint = None
    steps = 0
    if checkpoint_str is None:
        return checkpoint, steps

    if ':' in checkpoint_str:
        prefix, steps = checkpoint_str.split(':')
    else:
        prefix, steps = checkpoint_str, None

    log_file, run_id = path.find_log(prefix)
    checkpoints = path.find_checkpoints(run_id)
    if steps is None:
        # No explicit step requested: take the last listed checkpoint.
        checkpoint, steps = checkpoints[-1]
    else:
        requested = steps
        try:
            checkpoint, steps = next(
                entry for entry in checkpoints if entry[1] == requested)
        except StopIteration:
            print('The steps not found in checkpoints', steps, checkpoints)
            sys.stdout.flush()
            raise

    # Converted even when the counter is cleared below, so a malformed step
    # label is still reported.
    steps = int(steps)
    if args.clear_steps:
        steps = 0
    else:
        _, exp_info = path.read_log(log_file)
        exp_info = exp_info[-1]
        # Only these settings are restored from the logged experiment info.
        for k in ('tag', ):
            if k in args.__dict__ and k in exp_info:
                # NOTE(security): eval() executes whatever text the log file
                # holds; only safe while log files are trusted.  Consider
                # ast.literal_eval if the stored values are plain literals.
                setattr(args, k, eval(exp_info[k]))
                print('{}={}, '.format(k, exp_info[k]), end='')
        print()
    sys.stdout.flush()
    return checkpoint, steps
예제 #2
0
파일: main.py 프로젝트: yamaru12345/DF-VO2
	# NOTE(review): this excerpt starts inside an enclosing block whose header
	# is not visible here; `prefix` and `steps` are assumed to be set above.
	# Look up the log file and run id belonging to the given prefix.
	log_file, run_id = path.find_log(prefix)	
	if steps is None:
		# No explicit step requested: take the last listed checkpoint.
		checkpoint, steps = path.find_checkpoints(run_id)[-1]
	else:
		checkpoints = path.find_checkpoints(run_id)
		try:
			# First entry whose second element equals the requested steps.
			checkpoint, steps = next(filter(lambda t : t[1] == steps, checkpoints))
		except StopIteration:
			# Report what was requested and what is available, then propagate.
			print('The steps not found in checkpoints', steps, checkpoints)
			sys.stdout.flush()
			raise StopIteration
	steps = int(steps)
	if args.clear_steps:
		# Reuse the checkpoint but restart the step counter from zero.
		steps = 0
	else:
		# Restore selected settings from the last experiment-info entry of the log.
		_, exp_info = path.read_log(log_file)
		exp_info = exp_info[-1]
		for k in args.__dict__:
			# Only 'tag' is restored from the log.
			if k in exp_info and k in ('tag',):
				# NOTE(review): eval() runs the logged text as Python; this is
				# only safe while log files are trusted.
				setattr(args, k, eval(exp_info[k]))
				print('{}={}, '.format(k, exp_info[k]), end='')
		print()
	sys.stdout.flush()
# Build a new run id unless an existing checkpoint is resumed with its step
# counter kept.
if not (args.checkpoint is not None and not args.clear_steps):
	# Host name, minute-resolution timestamp and GPU device string are hashed
	# into a 3-character tag that prefixes the timestamp.
	uid = ''.join((
		socket.gethostname(),
		logger.FileLog._localtime().strftime('%b%d-%H%M'),
		args.gpu_device,
	))
	tag = hashlib.sha224(uid.encode()).hexdigest()[:3]
	run_id = '{}{}'.format(tag, logger.FileLog._localtime().strftime('%b%d-%H%M'))

# initiate
from network import get_pipeline