コード例 #1
0
ファイル: resume.py プロジェクト: thetianshuhuang/l2o
"""Resume Training.

Run with
```
python resume.py directory --vgpu=1
```
"""

import os
import sys

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf

import l2o
from config import ArgParser
from gpu_setup import create_distribute


# Expected invocation: resume.py <directory> [--vgpu=N]. The directory must be
# the first argument because everything after it is handed to ArgParser as
# options. Fail with a usage message instead of an IndexError (no arguments)
# or silently treating an option such as "--vgpu=2" as the directory.
if len(sys.argv) < 2 or sys.argv[1].startswith("--"):
    sys.exit("Usage: python resume.py directory [--vgpu=N]")

args = ArgParser(sys.argv[2:])
# Virtual GPU count forwarded to create_distribute (defaults to 1).
vgpus = args.pop_get("--vgpu", default=1, dtype=int)
distribute = create_distribute(vgpus=vgpus)

# Rebuild the training strategy from the configuration identified by the
# directory argument and resume training inside the distribution scope.
with distribute.scope():
    strategy = l2o.strategy.build_from_config(sys.argv[1])
    strategy.train()
コード例 #2
0
ファイル: baseline.py プロジェクト: thetianshuhuang/l2o
import os
import sys
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf

import l2o
from config import ArgParser, get_eval_problem
from gpu_setup import create_distribute

# Parse command-line options (everything after the script name) with the
# project's ArgParser.
# NOTE(review): pop_get presumably consumes each option once read — confirm
# against config.ArgParser.
args = ArgParser(sys.argv[1:])
vgpus = args.pop_get("--vgpu", default=1, dtype=int)
# NOTE(review): if pop_get converts with a plain ``bool(value)``, any non-empty
# string (including "False" or "0") becomes True. Confirm how
# config.ArgParser handles dtype=bool for "--cpu" and "--keras".
cpu = args.pop_get("--cpu", default=False, dtype=bool)
gpus = args.pop_get("--gpus", default=None)
use_keras = args.pop_get("--keras", default=True, dtype=bool)
# Build the distribution object from the hardware options above
# (presumably a tf.distribute strategy — see gpu_setup.create_distribute).
distribute = create_distribute(vgpus=vgpus, do_cpu=cpu, gpus=gpus)

# Name of the problem to evaluate; defaults to "conv_train".
problem = args.pop_get("--problem", "conv_train")

# Key into target_cfg below selecting the baseline optimizer; defaults to "adam".
target = args.pop_get("--optimizer", "adam")
target_cfg = {
    "adam": {
        "class_name": "Adam",
        "config": {
            "learning_rate": 0.005,
            "beta_1": 0.9,
            "beta_2": 0.999
        }
    },
    "rmsprop": {
        "class_name": "RMSProp",
コード例 #3
0
# Parse command-line options (everything after the script name). This happens
# before TensorFlow is imported below.
args = ArgParser(sys.argv[1:])

# Finally ready to import tensorflow
# TF_CPP_MIN_LOG_LEVEL must be set before the import to take effect; '2' hides
# TensorFlow's C++ INFO and WARNING messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import l2o
from gpu_setup import create_distribute

# Directory
# Output/weights directory; defaults to "weights" (presumably where trained
# weights are saved — confirm against later use).
directory = args.pop_get("--directory", default="weights")

# Distribute
# Virtual GPU count, memory limit and physical GPU selection forwarded to
# create_distribute.
# NOTE(review): the unit of --vram (default 12000) is not visible here —
# presumably MB per virtual GPU; confirm in gpu_setup.
vgpus = int(args.pop_get("--vgpu", default=1))
memory_limit = int(args.pop_get("--vram", default=12000))
gpus = args.pop_get("--gpus", default=None)
distribute = create_distribute(
    vgpus=vgpus, memory_limit=memory_limit, gpus=gpus)

# Pick up flags first
# Boolean switch; presumably True when "--initialize" is present on the
# command line (see ArgParser.pop_check).
initialize_only = args.pop_check("--initialize")

# Default params
# Base configuration chosen by the training strategy name (default "repeat")
# and the policy name (default "rnnprop").
strategy = args.pop_get("--strategy", "repeat")
policy = args.pop_get("--policy", "rnnprop")
default = get_default(strategy=strategy, policy=policy)

# Build overrides
# Comma-separated preset names; each one is expanded into override entries
# with get_preset and accumulated in ``overrides``.
presets = args.pop_get("--presets", "")
overrides = []
if presets != "":
    for p in presets.split(','):
        overrides += get_preset(p)