Example #1
0
def get_data(manifest, manifest_root, batch_size, subset_pct, rng_seed):
    '''
    Load the training and validation sets using the aeon loader.

    Arguments:
        manifest (list): Manifest file paths; index 1 is the training
            manifest and index 0 is the validation ("test") manifest
        manifest_root (string): Root directory of the manifest files
        batch_size (int): Mini-batch size
        subset_pct (float): Percentage of the data to use (0-100)
        rng_seed (int): Seed for the random number generator

    Returns:
        tuple: (train_set, valid_set) data loaders
    '''
    # Sanity-check the manifest paths before constructing any loaders.
    assert 'train' in manifest[1], "Missing train manifest"
    assert 'test' in manifest[0], "Missing validation manifest"

    train_set = make_train_loader(manifest[1], manifest_root, batch_size,
                                  subset_pct, rng_seed)
    # Validation loader is deterministic, so it takes no RNG seed.
    valid_set = make_validation_loader(manifest[0], manifest_root, batch_size,
                                       subset_pct)

    return train_set, valid_set
Example #2
0
# Command-line setup; --subset_pct limits how much of the training data is used.
parser = NeonArgparser(__doc__, default_config_files=config_files)
parser.add_argument('--subset_pct',
                    type=float,
                    default=100,
                    help='subset of training dataset to use (percentage)')
args = parser.parse_args()

# Default to a fixed seed of 0 when no --rng_seed was supplied.
random_seed = 0 if args.rng_seed is None else args.rng_seed
model, cost = create_network()

# setup data provider
# Both manifest files must be present before the loaders are constructed.
assert 'train' in args.manifest, "Missing train manifest"
assert 'test' in args.manifest, "Missing validation manifest"

train = make_train_loader(args.manifest['train'], args.manifest_root, model.be,
                          args.subset_pct, random_seed)
valid = make_test_loader(args.manifest['test'], args.manifest_root, model.be,
                         args.subset_pct)

# setup callbacks
callbacks = Callbacks(model, eval_set=valid, **args.callback_args)

# gradient descent with momentum, weight decay, and learning rate decay schedule
# NOTE(review): Schedule presumably multiplies the learning rate by 0.1 at
# epochs 6, 12, 18, ... — confirm against neon's Schedule documentation.
# Biases get a separate optimizer with a doubled base rate (0.006 vs 0.003)
# and no weight decay.
learning_rate_sched = Schedule(list(range(6, args.epochs, 6)), 0.1)
opt_gdm = GradientDescentMomentum(0.003,
                                  0.9,
                                  wdecay=0.005,
                                  schedule=learning_rate_sched)
opt_biases = GradientDescentMomentum(0.006, 0.9, schedule=learning_rate_sched)
opt = MultiOptimizer({'default': opt_gdm, 'Bias': opt_biases})
Example #3
0
File: train.py  Project: rlugojr/neon
# Command-line setup: --depth controls the 9n+2 network depth and
# --subset_pct limits how much of the training data is consumed.
parser = NeonArgparser(__doc__, default_config_files=config_files)
parser.add_argument('--depth',
                    type=int,
                    default=2,
                    help='depth of each stage (network depth will be 9n+2)')
parser.add_argument('--subset_pct',
                    type=float,
                    default=100,
                    help='subset of training dataset to use (percentage)')
args = parser.parse_args()

# Fall back to a fixed seed of 0 when none was supplied.
random_seed = args.rng_seed or 0

# Check that the proper manifest sets have been supplied
assert 'train' in args.manifest, "Missing train manifest"
assert 'val' in args.manifest, "Missing validation manifest"

model, cost = create_network(args.depth)

# Data providers: augmented training set plus a held-out validation set.
train = make_train_loader(args.manifest['train'], args.manifest_root,
                          model.be, args.subset_pct, random_seed)
test = make_validation_loader(args.manifest['val'], args.manifest_root,
                              model.be, args.subset_pct)

# Un-augmented pass over the training data, used only to re-estimate
# batch-norm statistics.
tune_set = make_tuning_loader(args.manifest['train'], args.manifest_root,
                              model.be)

# Callbacks: report misclassification on the validation set, and retune
# batch norm (inserted first so it runs before evaluation callbacks).
callbacks = Callbacks(model, eval_set=test, metric=Misclassification(),
                      **args.callback_args)
callbacks.add_callback(BatchNormTuneCallback(tune_set), insert_pos=0)

# SGD with momentum, weight decay, and a stepped learning-rate schedule.
opt = GradientDescentMomentum(0.1, 0.9, wdecay=0.0001,
                              schedule=Schedule([82, 124], 0.1))
model.fit(train, optimizer=opt, num_epochs=args.epochs, cost=cost,
          callbacks=callbacks)
Example #4
0
# Load default options from whale_subm.cfg when it sits next to this script.
subm_config = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                           'whale_subm.cfg')
config_files = [subm_config] if os.path.exists(subm_config) else []
parser = NeonArgparser(__doc__, default_config_files=config_files)
parser.add_argument('--submission_file',
                    help='where to write prediction output')
args = parser.parse_args()

model, cost_obj = create_network()

# Required inputs: an 'all' training manifest, a 'test' manifest, and an
# output path.  NOTE(review): asserts are stripped under `python -O`, so
# these validations would then silently pass.
assert 'all' in args.manifest, "Missing train manifest"
assert 'test' in args.manifest, "Missing test manifest"
assert args.submission_file is not None, "Must supply a submission file to output scores to"

neon_logger.display('Performing train and test in submission mode')
# Train on the full 'all' manifest, optionally mixing in a 'noise' manifest
# if one was provided, then score the unlabeled test manifest.
train = make_train_loader(args.manifest['all'],
                          args.manifest_root,
                          model.be,
                          noise_file=args.manifest.get('noise'))
test = make_test_loader(args.manifest['test'], args.manifest_root, model.be)

model.fit(dataset=train,
          cost=cost_obj,
          optimizer=Adadelta(),
          num_epochs=args.epochs,
          callbacks=Callbacks(model, **args.callback_args))

# Write one score per test example using column 1 of the model outputs.
# NOTE(review): assumes outputs have at least two columns (e.g. a binary
# softmax, column 1 = positive class) — confirm against create_network().
preds = model.get_outputs(test)
np.savetxt(args.submission_file, preds[:, 1], fmt='%.5f')
Example #5
0
File: train.py  Project: StevenLOL/neon
from neon.optimizers import RMSProp
from neon.transforms import Misclassification
from neon.callbacks.callbacks import Callbacks
from network import create_network
from data import make_train_loader, make_val_loader

# Load default options from whale_eval.cfg when it sits next to this script.
eval_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'whale_eval.cfg')
config_files = [eval_config] if os.path.exists(eval_config) else []
parser = NeonArgparser(__doc__, default_config_files=config_files)
args = parser.parse_args()

model, cost_obj = create_network()

# Both manifests are required before building any loaders.
assert 'train' in args.manifest, "Missing train manifest"
assert 'val' in args.manifest, "Missing val manifest"

# Training loader may mix in an optional 'noise' manifest if one was given.
train = make_train_loader(args.manifest['train'], args.manifest_root, model.be,
                          noise_file=args.manifest.get('noise'))

neon_logger.display('Performing train and test in validation mode')
val = make_val_loader(args.manifest['val'], args.manifest_root, model.be)
metric = Misclassification()

# Train with RMSProp, tracking misclassification on the validation set
# after each epoch via the eval_set callback.
model.fit(dataset=train,
          cost=cost_obj,
          optimizer=RMSProp(learning_rate=1e-4),
          num_epochs=args.epochs,
          callbacks=Callbacks(model, eval_set=val, metric=metric, **args.callback_args))

# Report the final validation error as a percentage.
neon_logger.display('Misclassification error = %.1f%%' % (model.eval(val, metric=metric) * 100))