import argparse
import importlib
import os
import re
import sys
from glob import glob

# Note: FluidErrors, eval_checkpoint and print_errors are assumed to be
# defined or imported elsewhere in this module; they are not shown here.


def main():
    parser = argparse.ArgumentParser(
        description="Evaluates a fluid network",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--trainscript",
                        type=str,
                        required=True,
                        help="The python training script.")
    parser.add_argument(
        "--checkpoint_iter",
        type=int,
        required=False,
        help="The checkpoint iteration. The default is the last checkpoint.")
    parser.add_argument(
        "--weights",
        type=str,
        required=False,
        help="If set uses the specified weights file instead of a checkpoint.")
    parser.add_argument("--frame-skip",
                        type=int,
                        default=5,
                        help="The frame skip. Default is 5.")
    parser.add_argument("--device",
                        type=str,
                        default="cuda",
                        help="The device to use. Applies only for torch.")
    args = parser.parse_args()

    # Import the training script as a module so that its train_dir and
    # val_files attributes can be reused for evaluation.
    global trainscript
    module_name = os.path.splitext(os.path.basename(args.trainscript))[0]
    sys.path.append('.')
    trainscript = importlib.import_module(module_name)

    if args.weights is not None:
        # Evaluate an explicit weights file instead of a checkpoint.
        print('evaluating :', args.weights)
        output_path = args.weights + '_eval.json'
        if os.path.isfile(output_path):
            print('Printing previously computed results for :', args.weights)
            fluid_errors = FluidErrors()
            fluid_errors.load(output_path)
        else:
            fluid_errors = FluidErrors()
            eval_checkpoint(args.weights, trainscript.val_files, fluid_errors,
                            args)
            fluid_errors.save(output_path)
    else:
        # get a list of checkpoints

        # tensorflow checkpoints
        checkpoint_files = glob(
            os.path.join(trainscript.train_dir, 'checkpoints', 'ckpt-*.index'))
        # torch checkpoints
        checkpoint_files.extend(
            glob(
                os.path.join(trainscript.train_dir, 'checkpoints',
                             'ckpt-*.pt')))
        all_checkpoints = sorted([
            (int(re.match(r'.*ckpt-(\d+)\.(pt|index)', x).group(1)), x)
            for x in checkpoint_files
        ])

        # select the checkpoint
        if args.checkpoint_iter is not None:
            checkpoint = dict(all_checkpoints)[args.checkpoint_iter]
        else:
            checkpoint = all_checkpoints[-1]

        output_path = args.trainscript + '_eval_{}.json'.format(checkpoint[0])
        if os.path.isfile(output_path):
            print('Printing previously computed results for :', checkpoint)
            fluid_errors = FluidErrors()
            fluid_errors.load(output_path)
        else:
            print('evaluating :', checkpoint)
            fluid_errors = FluidErrors()
            eval_checkpoint(checkpoint[1], trainscript.val_files, fluid_errors,
                            args)
            fluid_errors.save(output_path)

    print_errors(fluid_errors)
    return 0
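

# Minimal entry point sketch (an assumption, not part of the original
# listing): invoke the evaluation directly from the command line, e.g.
#   python <this_script>.py --trainscript <training_script>.py \
#       [--checkpoint_iter N | --weights FILE] [--frame-skip 5] [--device cuda]
# The script and file names above are placeholders, not names from this repo.
if __name__ == '__main__':
    sys.exit(main())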