def parse_train_config(config=None):
    """Assemble the training configuration as an AttrDict.

    Values present in *config* win; anything missing falls back to the
    module-level defaults (DATASET_ROOT, BATCH_SIZE, ...).

    Args:
        config: optional mapping of overrides; a falsy value means
            "use all defaults".

    Returns:
        AttrDict with dataset, optimizer, scheduler and run settings.
    """
    overrides = config or {}
    settings = AttrDict()

    def pick(key, fallback):
        # Prefer the caller-supplied value, else the module default.
        return overrides.get(key, fallback)

    # Dataset / dataloader settings.
    settings.DATASET_ROOT = pick("DATASET_ROOT", DATASET_ROOT)
    settings.JSON_PATH = pick("JSON_PATH", "train.json")
    settings.BATCH_SIZE = pick("BATCH_SIZE", BATCH_SIZE)
    settings.IMAGE_SIZE = pick("IMAGE_SIZE", IMAGE_SIZE)
    settings.WORKERS = pick("WORKERS", WORKERS)
    settings.PIN_MEMORY = pick("PIN_MEMORY", PIN_MEMORY)
    settings.SHUFFLE = pick("SHUFFLE", True)

    # Optimizer hyper-parameters (covers both SGD- and Adam-style fields).
    settings.LEARNING_RATE = pick("LEARNING_RATE", LEARNING_RATE)
    settings.MOMENTUM = pick("MOMENTUM", MOMENTUM)
    settings.DAMPENING = pick("DAMPENING", DAMPENING)
    settings.BETAS = pick("BETAS", BETAS)
    settings.EPS = pick("EPS", EPS)
    settings.WEIGHT_DECAY = pick("WEIGHT_DECAY", WEIGHT_DECAY)

    # LR-scheduler settings.
    settings.MILESTONES = pick("MILESTONES", MILESTONES)
    settings.GAMMA = pick("GAMMA", GAMMA)

    # Run / checkpointing settings.
    settings.NUM_EPOCHS = pick("NUM_EPOCHS", NUM_EPOCHS)
    settings.TEST = pick("TEST", TEST)
    settings.OUT_PATH = pick("OUT_PATH", OUT_PATH)
    settings.LOAD_MODEL = pick("LOAD_MODEL", LOAD_MODEL)
    settings.SAVE_MODEL = pick("SAVE_MODEL", SAVE_MODEL)
    settings.CHECKPOINT_FILE = pick("CHECKPOINT_FILE", CHECKPOINT_FILE)

    return settings
def parse_detect_config(config=None):
    """Assemble the detection configuration as an AttrDict.

    Args:
        config: optional mapping of overrides; a falsy value means
            "use all defaults".

    Returns:
        AttrDict with JSON path, image size and checkpoint file.
    """
    overrides = config or {}
    settings = AttrDict()

    # Overrides win; module-level constants are the fallbacks.
    for key, fallback in (
        ("JSON", JSON),
        ("IMAGE_SIZE", IMAGE_SIZE),
        ("CHECKPOINT_FILE", CHECKPOINT_FILE),
    ):
        setattr(settings, key, overrides.get(key, fallback))

    return settings
def parse_test_config(config=None):
    """Assemble the evaluation/test configuration as an AttrDict.

    Mirrors parse_train_config but defaults to "test.json", disables
    shuffling, and forces LOAD_MODEL on by default.

    Args:
        config: optional mapping of overrides; a falsy value means
            "use all defaults".

    Returns:
        AttrDict with dataloader and checkpoint settings.
    """
    overrides = config or {}
    opts = AttrDict()

    # Dataset / dataloader settings (no shuffling for evaluation).
    opts.DATASET_ROOT = overrides.get("DATASET_ROOT", DATASET_ROOT)
    opts.JSON_PATH = overrides.get("JSON_PATH", "test.json")
    opts.BATCH_SIZE = overrides.get("BATCH_SIZE", BATCH_SIZE)
    opts.IMAGE_SIZE = overrides.get("IMAGE_SIZE", IMAGE_SIZE)
    opts.WORKERS = overrides.get("WORKERS", WORKERS)
    opts.PIN_MEMORY = overrides.get("PIN_MEMORY", PIN_MEMORY)
    opts.SHUFFLE = overrides.get("SHUFFLE", False)

    # Output / checkpoint settings — a saved model is expected here.
    opts.OUT_PATH = overrides.get("OUT_PATH", OUT_PATH)
    opts.LOAD_MODEL = overrides.get("LOAD_MODEL", True)
    opts.CHECKPOINT_FILE = overrides.get("CHECKPOINT_FILE", CHECKPOINT_FILE)

    return opts
# Example #4 (stray scrape artifacts "Beispiel #4" / "0" commented out so the module parses)
def get_params():
    """Return the RL-NTM duplication-task hyper-parameters as an AttrDict.

    NOTE: the return value is built from ``locals()``, so every local
    variable NAME below becomes a key of the result — do not rename them.
    """
    # Hard-coded, machine-specific checkpoint location — TODO make configurable.
    checkpoint_dir = '/Users/Nolsigan/PycharmProjects/rlntm-tensorflow/checkpoints'
    max_length = 6
    rnn_cell = rnn.BasicLSTMCell
    rnn_hidden = 128
    learning_rate = 0.003
    # NOTE(review): optimizer is constructed without a learning rate, so
    # `learning_rate` above is presumably consumed elsewhere — confirm.
    optimizer = tf.train.AdamOptimizer()
    gradient_clipping = 5
    batch_size = 100
    epochs = 30
    epoch_size = 100
    num_symbols = 10
    # Each input symbol is duplicated this many times in the task.
    dup_factor = 2
    mem_dim = 128
    # Allowed head movements: memory and input heads may move -1/0/+1,
    # the output head only 0/+1.
    mem_move_table = [-1, 0, 1]
    in_move_table = [-1, 0, 1]
    out_move_table = [0, 1]
    return AttrDict(**locals())
# Config
##############################################################################


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import re
import torch

from attr_dict import AttrDict

# Global configuration singleton; `cfg` is the public alias other modules import.
__C = AttrDict()
cfg = __C
# Rank of this process in distributed training (0 = master).
__C.GLOBAL_RANK = 0
# Current epoch counter, updated by the training loop.
__C.EPOCH = 0
# Absolute path to a location to keep some large files, not in this dir.
__C.ASSETS_PATH = '/home/dcg-adlr-atao-data.cosmos277/assets'

# Use class weighted loss per batch to increase loss for low pixel count classes per batch
__C.BATCH_WEIGHTING = False

# Border Relaxation Count
__C.BORDER_WINDOW = 1
# Number of epochs to use before turning off border restriction
# (-1 means the restriction is never turned off).
__C.REDUCE_BORDER_EPOCH = -1
# Comma-separated list of class ids to relax (None = no relaxation).
__C.STRICTBORDERCLASS = None