Example #1
def main():
    logger.debug('copying and importing client module')

    script_path = fs_tracker.get_artifact('clientscript')
    script_name = os.path.basename(script_path)
    new_script_path = os.path.join(os.getcwd(), '_clientscript.py')
    shutil.copy(script_path, new_script_path)
    script_path = new_script_path
    logger.debug("script path: " + script_path)

    mypath = os.path.dirname(script_path)
    sys.path.append(mypath)
    # os.path.splitext(os.path.basename(script_path))[0]
    module_name = '_clientscript'

    client_module = importlib.import_module(module_name)
    logger.debug('loading args')
    with open(fs_tracker.get_artifact('args'), 'rb') as f:
        args = pickle.loads(f.read())

    logger.debug('getting file mappings')
    artifacts = fs_tracker.get_artifacts()

    logger.debug('calling client function')
    retval = client_module.clientFunction(args, artifacts)

    logger.debug('saving the return value')
    with open(fs_tracker.get_artifact('retval'), 'wb') as f:
        f.write(pickle.dumps(retval))
Example #2
    def pack_weights(self, persist_weights, model, metrics, verbose):
        """
        Save weights if persist_weights flag is True
        """

        if model is None:
            # This happens if a ResourceExhaustedError is caught
            metrics['weights_l2norm'] = None
            return

        if persist_weights:
            metrics['weights_l2norm'] = self.l2norm(model.get_weights())
            weights_file = os.path.join(get_artifact('modeldir'),
                                        self.get_model_name())
            if verbose:
                print("Saving weights file to {}".format(weights_file))
            model.save(weights_file)
            if verbose:
                print("Saving complete")
Example #3
    def unpack_weights(self, verbose):
        """
        Load weights if present
        """

        try:
            weights_file = os.path.join(get_artifact('modeldir'),
                                        self.get_model_name())
            if verbose:
                print("Loading weights from {}".format(weights_file))
            weights = self.load_model_weights(weights_file)
            if verbose:
                print("Loaded successfully, L2 norm of weights = {}".format(
                    self.l2norm(weights)))
            return weights
        except BaseException as exception:
            if verbose:
                print("Weight loading failed due to {}".format(exception))
                print("unpack_weights returns None")
            return None
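Both pack_weights and unpack_weights call a self.l2norm helper that is not shown here. A minimal standalone sketch of what such a helper presumably computes, assuming the weights come in as a list of numpy arrays (the format Keras's model.get_weights() returns):

import numpy as np

def l2norm(weights):
    # L2 norm over all weight tensors, treated as one flat vector
    return float(np.sqrt(sum(np.sum(np.square(w)) for w in weights)))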
Example #4
def clientFunction(args, files):
    print('client function call with args ' + str(args) + ' and files ' +
          str(files))

    modelfile = 'model.dat'
    filename = files.get('model') or \
        os.path.join(fs_tracker.get_artifact('modeldir'), modelfile)

    print("Trying to load file {}".format(filename))

    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            args = pickle.loads(f.read()) + 1

    else:
        print("Trying to write file {}".format(filename))
        with open(filename, 'wb') as f:
            f.write(pickle.dumps(args, protocol=2))

    return args
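A quick local sketch of how clientFunction behaves across repeated calls when the same model path is supplied in the files mapping (the path below is hypothetical and no studio runner is involved): the first call pickles args into the file, and each later call unpickles the stored value and adds 1.

tmp_model = '/tmp/model.dat'                      # hypothetical location
print(clientFunction(1, {'model': tmp_model}))    # file absent: writes 1, returns 1
print(clientFunction(0, {'model': tmp_model}))    # file present: reads 1, returns 2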
Example #5
File: serve_main.py Project: zuma89/studio
def main():
    argparser = argparse.ArgumentParser(description='Serve studio model')

    argparser.add_argument('--wrapper',
                           '-w',
                           help='python script with function create_model ' +
                           'that takes modeldir (that is, the directory ' +
                           'where the experiment saves checkpoints etc.) ' +
                           'and returns a dict -> dict function (model). ' +
                           'By default, studio-serve will try to determine ' +
                           'this function automatically.',
                           default=None)

    argparser.add_argument('--port',
                           help='port to run Flask server on',
                           type=int,
                           default=5000)

    argparser.add_argument('--host', help='host name.', default='0.0.0.0')

    argparser.add_argument(
        '--killafter',
        help='Shut down after this many seconds of inactivity',
        default=3600)

    options = argparser.parse_args(sys.argv[1:])

    global model

    modeldir = fs_tracker.get_artifact('modeldata')
    if options.wrapper:
        module_name = re.sub(r'\.py\Z', '', options.wrapper)
        wrapper_module = importlib.import_module(module_name)
        model = wrapper_module.create_model(modeldir)
    else:
        model = auto_generate_model(modeldir)

    restart_killtimer(int(options.killafter))
    app.run(host=options.host, port=options.port)
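The --wrapper help text above describes the expected contract: a Python module that exposes create_model(modeldir) and returns a dict -> dict function. A minimal sketch of such a wrapper, assuming the experiment left a pickled predictor named model.pkl (a hypothetical file name) in modeldir:

# my_wrapper.py -- hypothetical wrapper module for the --wrapper option above
import os
import pickle

def create_model(modeldir):
    # assumption: the experiment saved a pickled predictor as model.pkl
    with open(os.path.join(modeldir, 'model.pkl'), 'rb') as f:
        predictor = pickle.load(f)

    def model(inputs):
        # dict in, dict out, as described in the help text
        return {key: predictor.predict(value) for key, value in inputs.items()}

    return model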
Example #6
def main():
    logger.setLevel(logs.DEBUG)
    logger.debug('copying and importing client module')
    logger.debug('getting file mappings')

    artifacts = fs_tracker.get_artifacts()
    files = {}
    logger.debug("Artifacts = {}".format(artifacts))

    for tag, path in six.iteritems(artifacts):
        if tag not in {'workspace', 'modeldir', 'tb', '_runner'}:
            if os.path.isfile(path):
                files[tag] = path
            elif os.path.isdir(path):
                dirlist = os.listdir(path)
                if any(dirlist):
                    files[tag] = os.path.join(path, dirlist[0])

    logger.debug("Files = {}".format(files))
    script_path = files['clientscript']

    # script_name = os.path.basename(script_path)
    new_script_path = os.path.join(os.getcwd(), '_clientscript.py')
    shutil.copy(script_path, new_script_path)

    script_path = new_script_path
    logger.debug("script path: " + script_path)

    mypath = os.path.dirname(script_path)
    sys.path.append(mypath)
    # os.path.splitext(os.path.basename(script_path))[0]
    module_name = '_clientscript'

    client_module = importlib.import_module(module_name)
    logger.debug('loading args')

    args_path = files['args']

    with open(args_path, 'rb') as f:
        args = pickle.loads(f.read())

    logger.debug('calling client function')
    retval = client_module.clientFunction(args, files)

    logger.debug('saving the return value')
    retval_path = fs_tracker.get_artifact('retval')
    if os.path.isdir(fs_tracker.get_artifact('clientscript')):
        # on go runner:
        logger.debug("Running in a go runner, creating {} for retval".format(
            retval_path))
        try:
            os.mkdir(retval_path)
        except OSError:
            logger.debug('retval dir present')

        retval_path = os.path.join(retval_path, 'retval')
        logger.debug("New retval_path is {}".format(retval_path))

    logger.debug('Saving retval')
    with open(retval_path, 'wb') as f:
        f.write(pickle.dumps(retval, protocol=2))
    logger.debug('Done')
Example #7
from studio import fs_tracker
import numpy as np

if fs_tracker.get_artifact('lr') is not None:
    lr = np.load(fs_tracker.get_artifact('lr'))
else:
    lr = np.random.random(10)

print "fitness: %s" % np.abs(np.sum(lr))
Example #8
import os
import pickle

import numpy as np

no_samples = 100
dim_samples = 5

learning_rate = 0.01
no_steps = 10

X = np.random.random((no_samples, dim_samples))
y = np.random.random((no_samples, ))

w = np.random.random((dim_samples, ))

for step in range(no_steps):
    yhat = X.dot(w)
    err = (yhat - y)
    dw = err.dot(X)
    w -= learning_rate * dw
    loss = 0.5 * err.dot(err)

    print("step = {}, loss = {}, L2 norm = {}".format(step, loss, w.dot(w)))

    #    with open(os.path.expanduser('~/weights/lr_w_{}_{}.pck'
    #                                 .format(step, loss)), 'w') as f:
    #        f.write(pickle.dumps(w))

    from studio import fs_tracker
    with open(
            os.path.join(fs_tracker.get_artifact('weights'),
                         'lr_w_{}_{}.pck'.format(step, loss)), 'wb') as f:
        f.write(pickle.dumps(w))
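Since the loop above is plain gradient descent on a linear least-squares objective, one optional sanity check (a sketch, not part of the original example) is to compare the final weights with the closed-form solution:

w_star, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
print("distance to least-squares solution = {}".format(np.linalg.norm(w - w_star)))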
Example #9
from studio import fs_tracker
import numpy as np

try:
    lr = np.load(fs_tracker.get_artifact('lr'))
except BaseException:
    lr = np.random.random(10)

print("fitness: %s" % np.abs(np.sum(lr)))
Example #10
#!/usr/bin/env python
import os
import matplotlib.image as mpimg
import cv2
import numpy as np
import torch
import subprocess
import tempfile

try:
    from studio import fs_tracker
    DEFAULT_ZIP_PATH = fs_tracker.get_artifact('data')
    # IMG_DIR = fs_tracker.get_artifact('data')

except ImportError:
    fs_tracker = None
    DEFAULT_ZIP_PATH = 'data/img_align_celeba_attr.zip'

IMG_DIR = os.path.join(tempfile.gettempdir(), 'data')
DEFAULT_ATTR_PATH = os.path.join(IMG_DIR, 'attributes.txt')

IMG_ZIP_PATH = os.environ.get('IMG_ZIP_PATH', DEFAULT_ZIP_PATH)

IMG_ATTR_PATH = os.environ.get('IMG_ATTR_PATH', DEFAULT_ATTR_PATH)
IMG_SIZE = 128

IMG_PATH = os.path.join(IMG_DIR, 'images_%i_%i.pth' % (IMG_SIZE, IMG_SIZE))
IMG20K_PATH = os.path.join(IMG_DIR,
                           'images_%i_%i_20000.pth' % (IMG_SIZE, IMG_SIZE))
ATTR_PATH = os.path.join(IMG_DIR, 'attributes.pth')
Example #11
import sys
from studio import fs_tracker

print(fs_tracker.get_artifact('f'))
with open(fs_tracker.get_artifact('f'), 'r') as f:
    print(f.read())

if len(sys.argv) > 1:
    with open(fs_tracker.get_artifact('f'), 'w') as f:
        f.write(sys.argv[1])

sys.stdout.flush()
Example #12
params = parser.parse_args()

params.img_sz = preprocess.IMG_SIZE

# check parameters
check_attr(params)
assert len(params.name.strip()) > 0
assert params.n_skip <= params.n_layers - 1
assert params.deconv_method in ['convtranspose', 'upsampling', 'pixelshuffle']
assert 0 <= params.smooth_label < 0.5
assert not params.ae_reload or os.path.isfile(params.ae_reload)
assert not params.lat_dis_reload or os.path.isfile(params.lat_dis_reload)
assert not params.ptc_dis_reload or os.path.isfile(params.ptc_dis_reload)
assert not params.clf_dis_reload or os.path.isfile(params.clf_dis_reload)

eval_clf_artifact = fs_tracker.get_artifact('eval_clf') if fs_tracker else '.'

if params.eval_clf != '':
    eval_clf_path = os.path.join(eval_clf_artifact, params.eval_clf)
else:
    eval_clf_path = eval_clf_artifact

assert os.path.isfile(eval_clf_path)

assert params.lambda_lat_dis == 0 or params.n_lat_dis > 0
assert params.lambda_ptc_dis == 0 or params.n_ptc_dis > 0
assert params.lambda_clf_dis == 0 or params.n_clf_dis > 0

# initialize experiment / load dataset
logger = initialize_exp(params)
data, attributes = load_images(params)
Example #13
import glob
import os
from studio import fs_tracker
import pickle

weights_list = sorted(
    glob.glob(os.path.join(fs_tracker.get_artifact('w'), '*.pck')))

print('*****')
print(weights_list[-1])
with open(weights_list[-1], 'rb') as f:
    w = pickle.load(f)

print(w.dot(w))
print('*****')
Example #14
parser.add_argument(
    "--debug",
    type=bool_flag,
    default=True,
    help="Debug mode (only load a subset of the whole dataset)")

params = parser.parse_args()

# check parameters
assert params.n_images >= 1 and params.n_interpolations >= 2

# create logger / load trained model
logger = create_logger(None)

if fs_tracker:
    model_path = os.path.join(fs_tracker.get_artifact('model'),
                              params.model_path)
else:
    model_path = params.model_path

assert os.path.isfile(model_path), "model_path {} is not a file".format(
    model_path)
ae = torch.load(model_path, map_location=lambda storage, loc: storage).eval()

# restore main parameters
params.debug = False
params.batch_size = 32
params.v_flip = False
params.h_flip = False
params.img_sz = ae.img_sz
params.attr = ae.attr
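The map_location lambda above simply forces the checkpoint onto the CPU; recent PyTorch versions also accept the shorter string form, so an equivalent (hedged) alternative would be:

ae = torch.load(model_path, map_location='cpu').eval()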