コード例 #1
0
ファイル: trainer.py プロジェクト: sidnarayanan/SubtLeNet
#!/usr/bin/env python2.7

from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('--nepoch', type=int, default=50)
parser.add_argument('--version', type=str, default='4')
parser.add_argument('--trunc', type=int, default=7)
parser.add_argument('--limit', type=int, default=100)
parser.add_argument('--train_baseline', action='store_true')
args = parser.parse_args()

import extra_vars
from subtlenet.models import particles as train
from os import path

train.NEPOCH = args.nepoch
train.VERSION = args.version
data, dims = train.instantiate(args.trunc, args.limit)

clf_gen = train.setup_data(data)

clf = train.build_classifier(dims)

train.train(clf, 'baseline', clf_gen['train'], clf_gen['validation'])
コード例 #2
0
            'w_clf' : 0.001,
            'w_adv' : 0.1,
            }
else:
    opts = {
            'loss' : args.adv,
            'scale' : 0.1,
            'w_clf' : 0.001,
            'w_adv' : 1,
            }

clf = train.build_classifier(dims)
if args.adv is not None:
    adv = train.build_adversary(clf=clf, **opts)

preload = '%s/%s/baseline_best.h5'%(train.MODELDIR, train._APOSTLE)
if path.isfile(preload):
    print 'Pre-loading weights from',preload
    tmp_ = train.load_model(preload)
    clf.set_weights(tmp_.get_weights())
if args.train_baseline or not(path.isfile(preload)):
    train.train(clf, 'baseline', clf_gen['train'], clf_gen['validation'])

if args.adv:
    print 'Training the full adversarial stack:'
    callback_params = {
            'partial_model' : clf,
            'monitor' : lambda x : opts['w_clf'] * x.get('val_y_hat_loss') - opts['w_adv'] * x.get('val_adv_loss'), # semi-arbitrary
            }
    train.train(adv, args.adv, adv_gen['train'], adv_gen['validation'], callback_params)
コード例 #3
0
#!/usr/bin/env python2.7

from subtlenet.models import particles as train
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--nepoch', type=int, default=50)
parser.add_argument('--version', type=int, default=4)
parser.add_argument('--trunc', type=int, default=4)
parser.add_argument('--limit', type=int, default=50)
args = parser.parse_args()

train.NEPOCH = args.nepoch
train.VERSION = args.version
data, dims = train.instantiate(args.trunc, args.limit)

clf_gen = train.setup_data(data)

clf = train.build_classifier(dims)

train.train(clf, 'classifier', clf_gen['train'], clf_gen['validation'])