Example #1
 def _parse_json(self, json_response):
     self._data = DotDict(json_response)
     if not self._data.id:  # TODO: should we also accept a non-persisted API JSON object?
         print("WARN: Missing User ID in parsed data")
         self._data.id = None
         # raise Exception(
         #     "JSON response should have id property for a valid Resource")
     for key, val in self._data.items():
         self.__setattr__(key, val)
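
Every example on this page assumes a project-local DotDict helper, whose exact implementation varies between projects. Below is a minimal sketch consistent with the usage seen here (attribute-style reads and writes, recursive wrapping of nested dicts, missing keys reading as None, and a to_dict() inverse), not the actual class from any of these codebases:

class DotDict(dict):
    """Dict whose keys are also readable and writable as attributes."""

    def __init__(self, data=None):
        super().__init__()
        for key, value in (data or {}).items():
            self[key] = DotDict(value) if isinstance(value, dict) else value

    def __getattr__(self, name):
        return self.get(name)  # missing keys read as None, so `if not d.id` works

    def __setattr__(self, name, value):
        self[name] = value

    def to_dict(self):
        return {k: v.to_dict() if isinstance(v, DotDict) else v
                for k, v in self.items()}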
Example #2
    def __init__(self, hp, net_arch, loss_f, rank=0, world_size=1):
        self.hp = hp
        self.device = self.hp.model.device
        self.net = net_arch.to(self.device)
        self.rank = rank
        self.world_size = world_size
        if self.device != "cpu" and self.world_size != 0:
            self.net = DDP(self.net, device_ids=[self.rank])
        self.input = None
        self.GT = None
        self.step = 0
        self.epoch = -1

        # init optimizer
        optimizer_mode = self.hp.train.optimizer.mode
        if optimizer_mode == "adam":
            self.optimizer = torch.optim.Adam(
                self.net.parameters(),
                **(self.hp.train.optimizer[optimizer_mode]))
        else:
            raise Exception("%s optimizer not supported" % optimizer_mode)

        # init loss
        self.loss_f = loss_f
        self.log = DotDict()
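
For orientation, the attribute paths read above (hp.model.device, hp.train.optimizer.mode, and the per-optimizer kwargs dict) imply a nested config shaped roughly as follows; the lr value is purely illustrative:

hp = DotDict({
    "model": {"device": "cuda"},
    "train": {
        "optimizer": {
            "mode": "adam",
            "adam": {"lr": 1e-3},  # unpacked as kwargs into torch.optim.Adam
        }
    },
})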
Example #3
class Resource(object):
    """
    Parent for API and Endpoint resources; generalized delete-type methods are implemented here
    """
    BASE = "https://{host}:{port}/api/am/".format(
        host=default['connection']['hostname'],
        port=default['connection']['port'])
    VERSION = 1.0
    APPLICATION = "publisher"
    RESOURCE = ''

    def __init__(self):
        self._data = None  # raw Resource data
        self.id = None  # Resource ID; None until persisted via REST API (save() sets it)
        self.client = None

    def delete(self):
        pass

    def _parse_json(self, json_response):
        self._data = DotDict(json_response)
        if not self._data.id:  # TODO: should we also accept a non-persisted API JSON object?
            print("WARN: Missing API ID in parsed data")
            self._data.id = None
            # raise Exception(
            #     "JSON response should have id property for a valid Resource")
        for key, val in self._data.items():
            if key in self._parsers:
                self._parsers[key](val)
                continue
            self.__setattr__(key, val)

    @property
    def _parsers(self):
        raise NotImplementedError()

    @classmethod
    def get_endpoint(cls):
        return cls.BASE + "{application}/v{api_version}".format(
            api_version=cls.VERSION,
            application=cls.APPLICATION) + cls.RESOURCE

    @classmethod
    def get(cls, resource_id, client=RestClient()):
        res = client.session.get(cls.get_endpoint() +
                                 "/{}".format(resource_id),
                                 verify=client.verify)
        if res.status_code != 200:
            print("Warning Error while getting the Resource {}\nERROR: {}".
                  format(resource_id, res.content))
        print("Status code: {}".format(res.status_code))
        return cls(**res.json())

    @staticmethod
    def all():
        pass

    @staticmethod
    def delete_all():
        pass
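
The _parsers property is the hook subclasses override to intercept selected JSON keys during _parse_json; every other key is assigned as a plain attribute. A hypothetical minimal subclass (Tag and _parse_labels are invented names for illustration):

class Tag(Resource):
    RESOURCE = "/tags"

    def __init__(self, name='', **kwargs):
        super().__init__()
        self.name = name
        self.labels = []
        if kwargs:
            self._parse_json(kwargs)

    def _parse_labels(self, value):
        # custom handling instead of plain attribute assignment
        self.labels = list(value)

    @property
    def _parsers(self):
        return {'labels': self._parse_labels}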
Example #4
def get_rotated_mnist_dataloader(cfg,
                                 split='train',
                                 filter_obj=None,
                                 batch_size=128,
                                 task_num=10,
                                 *args,
                                 **kwargs):
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    global _rmnist_loaders
    if not _rmnist_loaders:
        data = get_rotated_mnist(d)
        #train_loader, val_loader, test_loader = [CLDataLoader(elem, batch_size, train=t) \
        #                                         for elem, t in zip(data, [True, False, False])]
        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
                                                 for elem, t in zip(data, [True, False, False])]
        _rmnist_loaders = train_loader, val_loader, test_loader
    else:
        train_loader, val_loader, test_loader = _rmnist_loaders
    if split == 'train':
        return train_loader[filter_obj[0]]
    elif split == 'val':
        return val_loader[filter_obj[0]]
    elif split == 'test':
        return test_loader[filter_obj[0]]
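
The function reads and writes the module-level cache _rmnist_loaders (the mini-ImageNet and split-CIFAR loaders below follow the same pattern); presumably each cache is declared once at module scope, along the lines of:

_rmnist_loaders = None  # populated on first call, then reused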
Example #5
class Application(Resource):
    CONST = DotDict(application_const)
    RESOURCE = "/applications"
    APPLICATION = "store"

    def __init__(self, name, description, throttlingTier='', **kwargs):
        super().__init__()
        self.name = name
        self.description = description
        self.throttlingTier = throttlingTier

        self.subscriber = None
        self.permission = "[{\"groupId\" : 1000, \"permission\" : [\"READ\",\"UPDATE\"]},{\"groupId\" : 1001, \"permission\" : [\"READ\",\"UPDATE\"]}]"
        self.lifeCycleStatus = None
        self.keys = []
        if kwargs:
            self._parse_json(kwargs)

    def save(self):
        headers = {
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        }
        data = self.to_json()
        res = self.client.session.post(Application.get_endpoint(), data=json.dumps(data), verify=self.client.verify,
                                       headers=headers)
        print("Status code: {}".format(res.status_code))
        if res.status_code != 201:
            print(res.json())
            raise Exception("Fail to save the global endpoint via REST API")
        self._data = DotDict(res.json())
        self.id = self._data.applicationId
        return self

    @staticmethod
    def delete_all(client=RestClient()):
        res = client.session.get(Application.get_endpoint(), verify=client.verify)
        if not res.status_code == 200:
            print(res.json())
            raise Exception("Error getting Applications list")
        print("Status code: {}".format(res.status_code))
        apps_list = res.json()['list']
        for app in apps_list:
            res = client.session.delete(Application.get_endpoint() + "/{}".format(app['applicationId']), verify=client.verify)
            if res.status_code != 200:
                print("Warning Error while deleting the API {}\nERROR: {}".format(app['name'], res.content))
            print("Status code: {}".format(res.status_code))

    # TODO: Make this a generic to_json method for all Resources
    def to_json(self):
        temp = {
            'name': self.name,
            'description': self.description,
            'throttlingTier': self.throttlingTier
        }
        return temp

    def set_rest_client(self, client=RestClient(APPLICATION)):
        self.client = client
Example #6
    def __init__(self, name, version, context, service_url=None, **kwargs):
        super().__init__()
        self.client = None  # REST API Client
        self.id = None
        self._data = None

        self.name = name
        self.context = context
        self.version = version
        self.endpoint = DotDict()
        if service_url:
            inline_endpoint_name = "{}_{}".format(Endpoint.CONST.TYPES.SANDBOX,
                                                  name)
            self.endpoint[Endpoint.CONST.TYPES.SANDBOX] = Endpoint(
                inline_endpoint_name, 'http', service_url)
            inline_endpoint_name = "{}_{}".format(
                Endpoint.CONST.TYPES.PRODUCTION, name)
            self.endpoint[Endpoint.CONST.TYPES.PRODUCTION] = Endpoint(
                inline_endpoint_name, 'http', service_url)
        if kwargs:
            self._parse_json(kwargs)
Example #7
 def save(self):
     headers = {
         'Accept': 'application/json',
         'Content-Type': 'application/json'
     }
     data = self.to_json()
     res = self.client.session.post(Application.get_endpoint(), data=json.dumps(data), verify=self.client.verify,
                                    headers=headers)
     print("Status code: {}".format(res.status_code))
     if res.status_code != 201:
         print(res.json())
         raise Exception("Fail to save the global endpoint via REST API")
     self._data = DotDict(res.json())
     self.id = self._data.applicationId
     return self
Example #8
    def __init__(self, hp, net_arch, loss_f):
        self.hp = hp
        self.device = hp.model.device
        self.net = net_arch.to(self.device)
        self.input = None
        self.GT = None
        self.step = 0
        self.epoch = -1

        # init optimizer
        optimizer_mode = hp.train.optimizer.mode
        if optimizer_mode == "adam":
            self.optimizer = torch.optim.Adam(
                self.net.parameters(), **(hp.train.optimizer[optimizer_mode]))
        else:
            raise Exception("%s optimizer not supported" % optimizer_mode)

        # init loss
        self.loss_f = loss_f
        self.log = DotDict()
Example #9
def get_split_mini_imagenet_dataloader(cfg,
                                       split='train',
                                       filter_obj=None,
                                       batch_size=128,
                                       *args,
                                       **kwargs):
    global _cache_mini_imagenet
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    if not _cache_mini_imagenet:
        data = get_miniimagenet(d)
        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
                                                 for elem, t in zip(data, [True, False, False])]
        _cache_mini_imagenet = train_loader, val_loader, test_loader
    train_loader, val_loader, test_loader = _cache_mini_imagenet
    if split == 'train':
        return train_loader[filter_obj[0]]
    elif split == 'val':
        return val_loader[filter_obj[0]]
    elif split == 'test':
        return test_loader[filter_obj[0]]
Example #10
def get_split_cifar100_dataloader(cfg,
                                  split='train',
                                  filter_obj=None,
                                  batch_size=128,
                                  *args,
                                  **kwargs):
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    global _cache_cifar100
    if not _cache_cifar100:
        data = get_split_cifar100(
            d, cfg
        )  #ds_cifar10and100(batch_size=batch_size, num_workers=0, cfg=cfg, **kwargs)
        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
                                                 for elem, t in zip(data, [True, False, False])]
        _cache_cifar100 = train_loader, val_loader, test_loader
    train_loader, val_loader, test_loader = _cache_cifar100
    if split == 'train':
        return train_loader[filter_obj[0]]
    elif split == 'val':
        return val_loader[filter_obj[0]]
    elif split == 'test':
        return test_loader[filter_obj[0]]
Example #11
class API(Resource):
    RESOURCE = "/apis"

    def __init__(self, name, version, context, service_url=None, **kwargs):
        super().__init__()
        self.client = None  # REST API Client
        self.id = None
        self._data = None

        self.name = name
        self.context = context
        self.version = version
        self.endpoint = DotDict()
        if service_url:
            inline_endpoint_name = "{}_{}".format(Endpoint.CONST.TYPES.SANDBOX,
                                                  name)
            self.endpoint[Endpoint.CONST.TYPES.SANDBOX] = Endpoint(
                inline_endpoint_name, 'http', service_url)
            inline_endpoint_name = "{}_{}".format(
                Endpoint.CONST.TYPES.PRODUCTION, name)
            self.endpoint[Endpoint.CONST.TYPES.PRODUCTION] = Endpoint(
                inline_endpoint_name, 'http', service_url)
        if kwargs:
            self._parse_json(kwargs)

    def set_rest_client(self, client=RestClient()):
        self.client = client

    def save(self):
        if self.id:
            print("WARN: API is already persist")
            return self
        res = self.client.session.post(API.get_endpoint(),
                                       json=self.to_json(),
                                       verify=self.client.verify)
        if res.status_code != 201:
            print(res.content)
            print(res.status_code)
            raise Exception("An error occurred while creating an API CODE: " +
                            str(res.status_code))
        self._parse_json(res.json())
        print("Status code: {}".format(res.status_code))
        return self

    def to_json(self):
        temp = {
            'name': self.name,
            'version': self.version,
            'context': self.context,
        }
        keys = self._data.keys() if self._data else []
        for key, value in self.__dict__.items():
            if key in keys:
                temp[key] = value
        endpoints = []
        for epType, endpoint in self.endpoint.items():
            serialize_endpoint = {"type": epType}
            if endpoint.id:
                serialize_endpoint['key'] = endpoint.id
            else:
                serialize_endpoint['inline'] = endpoint.to_json()
            endpoints.append(serialize_endpoint)
        if len(endpoints):
            temp['endpoint'] = endpoints
        return temp

    def _parse_endpoint(self, endpoint_json):
        for endpoint in endpoint_json:
            endpoint_data = endpoint.get('inline') or endpoint.get('key')
            self.endpoint[endpoint['type']] = Endpoint(**endpoint_data)

    def set_endpoint(self, endpoint_type, endpoint):
        if endpoint_type not in Endpoint.CONST.TYPES.values():
            raise Exception("Endpoint type should be one of these {}".format(
                Endpoint.CONST.TYPES.values()))
        if type(endpoint) is not Endpoint:
            raise Exception("endpoint should be an instance of Endpoint")
        if endpoint_type != Endpoint.CONST.TYPES.INLINE and not endpoint.id:
            raise Exception(
                "Global endpoint should have persist before mapping it to an API"
            )
        self.endpoint[endpoint_type] = endpoint
        self._data.endpoint = self.endpoint
        res = self.client.session.put(API.get_endpoint() + "/" + self.id,
                                      json=self.to_json(),
                                      verify=self.client.verify)
        if res.status_code != 200:
            print("Something went wrong when updating endpoints")
            print(res)
        return self

    def set_policies(self, policy_data):
        self.policies = getattr(self, 'policies', None) or []
        if type(policy_data) is list:
            self.policies.extend(policy_data)
        else:
            self.policies.append(policy_data)
        self._data["policies"] = self.policies
        res = self.client.session.put(API.get_endpoint() + "/" + self.id,
                                      json=self.to_json(),
                                      verify=self.client.verify)
        if res.status_code != 200:
            print("Something went wrong when updating policies")
            print(res)
        return self

    def delete(self):
        res = self.client.session.delete(API.get_endpoint() +
                                         "/{}".format(self.id),
                                         verify=self.client.verify)
        if res.status_code != 200:
            print("Warning Error while deleting the API {}\nERROR: {}".format(
                self.name, res.content))
        print("Status code: {}".format(res.status_code))

    def change_lc(self, state):
        api_id = self.id
        data = {"action": state, "apiId": api_id}
        lcs = self.client.session.get(API.get_endpoint() +
                                      "/{apiId}/lifecycle".format(
                                          apiId=api_id))
        if lcs.ok:
            lcs = lcs.json()
            available_transitions = [
                current_state
                for current_state in lcs['availableTransitionBeanList']
                if current_state['targetState'] == state
            ]
            if len(available_transitions) == 1:
                res = self.client.session.post(API.get_endpoint() +
                                               "/change-lifecycle",
                                               params=data,
                                               verify=self.client.verify)
                print("Status code: {}".format(res.status_code))
            else:
                raise ("Invalid transition state valid ones are = {}".format(
                    lcs['availableTransitionBeanList']))
        else:
            raise ("Can't find Lifecycle for the api {}".format(api_id))

    @property
    def _parsers(self):
        parsers = {  # if need special parsing other than simply assigning as attribute
            'endpoint': self._parse_endpoint
        }
        return parsers

    @staticmethod
    def delete_all(client=RestClient()):
        res = client.session.get(API.get_endpoint(), verify=client.verify)
        if not res.status_code == 200:
            print(res.json())
            raise Exception("Error getting APIs list")
        print("Status code: {}".format(res.status_code))
        apis_list = res.json()['list']
        for api in apis_list:
            res = client.session.delete(API.get_endpoint() +
                                        "/{}".format(api['id']),
                                        verify=client.verify)
            if res.status_code != 200:
                print("Warning Error while deleting the API {}\nERROR: {}".
                      format(api['name'], res.content))
            print("Status code: {}".format(res.status_code))

    @staticmethod
    def all(client=RestClient()):
        res = client.session.get(API.get_endpoint(), verify=client.verify)
        if not res.status_code == 200:
            print(res.json())
            raise Exception("Error getting APIs list")
        print("Status code: {}".format(res.status_code))
        return [API(**api) for api in res.json()['list']]
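
Taken together, a plausible end-to-end use of this class (names, URL, and lifecycle state purely illustrative) would be:

api = API("PetStore", "1.0.0", "/petstore",
          service_url="https://backend.example.com")
api.set_rest_client()
api.save()                  # POST /apis, then parse the returned JSON
api.change_lc("Published")  # move the API through its lifecycle
api.delete()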
Example #12
def train_loop(rank, hp, world_size=1):
    # reload hp
    hp = DotDict(hp)
    if hp.model.device.lower() == "cuda" and world_size != 0:
        setup(hp, rank, world_size)
    if rank != 0:
        logger = None
        writer = None
    else:
        # set logger
        logger = make_logger(hp)
        # set writer (tensorboard / wandb)
        writer = Writer(hp, hp.log.log_dir)
        hp_str = yaml.dump(hp.to_dict())
        logger.info("Config:")
        logger.info(hp_str)
        if hp.data.train_dir == "" or hp.data.test_dir == "":
            logger.error("train or test data directory cannot be empty.")
            raise Exception("Please specify directories of data")
        logger.info("Set up train process")

    if hp.model.device.lower() == "cuda" and world_size != 0:
        hp.model.device = rank
        torch.cuda.set_device(rank)
    else:
        hp.model.device = hp.model.device.lower()

    # make dataloader
    if logger is not None:
        logger.info("Making train dataloader...")
    train_loader = create_dataloader(hp, DataloaderMode.train, rank,
                                     world_size)
    if logger is not None:
        logger.info("Making test dataloader...")
    test_loader = create_dataloader(hp, DataloaderMode.test, rank, world_size)

    # init Model
    net_arch = Net_arch(hp)
    loss_f = torch.nn.MSELoss()
    model = Model(hp, net_arch, loss_f, rank, world_size)

    # load training state
    if hp.load.resume_state_path is not None:
        model.load_training_state(logger)
    else:
        if logger is not None:
            logger.info("Starting new training run.")

    try:
        epoch_step = 1 if hp.data.divide_dataset_per_gpu else world_size
        for model.epoch in itertools.count(model.epoch + 1, epoch_step):
            if model.epoch > hp.train.num_iter:
                break
            train_model(hp, model, train_loader, writer, logger)
            if model.epoch % hp.log.chkpt_interval == 0:
                model.save_network(logger)
                model.save_training_state(logger)
            test_model(hp, model, test_loader, writer)
        cleanup()
        if logger is not None:
            logger.info("End of Train")
    except Exception as e:
        if logger is not None:
            logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
        cleanup()
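
Since train_loop takes the process rank as its first argument, it matches the torch.multiprocessing.spawn calling convention; a plausible launcher for the multi-GPU path (with hp being the raw config dict, since the function re-wraps it in DotDict) might look like:

import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    # hp: raw config dict loaded elsewhere (e.g. from a YAML file)
    world_size = torch.cuda.device_count()
    mp.spawn(train_loop, args=(hp, world_size), nprocs=world_size)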
Example #13
def _worker_init(hdf_file_name, devices, lock, generations, nn_class, params, player_change_callback=None):
    global _env_
    import multiprocessing as mp    
    import time
    import self_play
    from nn import NeuralNetWrapper
    from utils.proxies import AsyncBatchedProxy

    pid = mp.current_process().pid

    _env_ = DotDict({})

    if not isinstance(nn_class, (list,tuple)):
        nn_class, generations, params = (nn_class,), (generations,), (params,)
    else:
        _env_.compare_models = True
    assert len(nn_class) == len(generations) == len(params)

    pytorch_device = devices[pid%len(devices)]
    players_params = {}
    models = {}
    for i in range(len(generations)):
        models[i] = nn_class[i](params[i])
        if generations[i] != 0:
            models[i].load_parameters(generations[i]-(1 if len(nn_class)==1 else 0), to_device=pytorch_device)
        players_params[i] = params[i]

    if len(models) == 1:
        models[1] = models[0]
        players_params[1] = players_params[0]
        generations = [generations[0], generations[0]]

    players_params[0].nn.pytorch_device = pytorch_device
    players_params[1].nn.pytorch_device = pytorch_device

    # shuffle the players based on the pid, important for computing the Elo score
    if pid % 2 != 0:
        models[0], models[1] = models[1], models[0]
        players_params[0], players_params[1] = players_params[1], players_params[0]
        generations = list(reversed(generations))

    _env_.models = models
    _env_.players_params = DotDict(players_params)
    _env_.generations = generations
    _env_.params = players_params[0]
    _env_.player_change_callback = player_change_callback
    _env_.name = 'w%i' % pid
    _env_.hdf_file_name = hdf_file_name
    _env_.hdf_lock = lock

    logger.info("Worker %s uses device %s", _env_.name, _env_.params.nn.pytorch_device)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)            
    
    sp_params = _env_.params.self_play
    _env_.nn_wrapper = NeuralNetWrapper(None, _env_.params)
    _env_.nnet = AsyncBatchedProxy(_env_.nn_wrapper, batch_size=sp_params.nn_batch_size, 
                                   timeout=sp_params.nn_batch_timeout, 
                                   batch_builder=sp_params.nn_batch_builder,
                                   cache_size = 0 if _env_.compare_models else 400000)
    _env_.tasks = []
    fut = asyncio.ensure_future(_env_.nnet.run(), loop=loop)
    fut.add_done_callback(_fut_cb)  # re-raise the exception if one occurred in nnet
    _env_.tasks.append(fut)
Example #14
class User(object):
    def __init__(self, userName, password=None, **kwargs):
        self.auth = HTTPBasicAuth(default["admin"]["name"],
                                  default["admin"]["password"])
        self.userName = userName
        self.password = password
        self.api_endpoint = "https://{host}:{port}/api/identity/scim2/v1.0/Users".format(
            host=default['connection']['hostname'],
            port=default['connection']['port'])
        self.verify = False
        if kwargs:
            self._parse_json(kwargs)

    def serialize(self):
        data = {
            "schemas": [],
            "name": {
                "familyName": "{}".format(self.userName),
                "givenName": "sam"
            },
            "userName":
            "******".format(self.userName),
            "password":
            "******".format(self.userName),
            "emails": [{
                "primary": True,
                "value": "{}@gmail.com".format(self.userName),
                "type": "home"
            }, {
                "value": "{}[email protected]".format(self.userName),
                "type": "work"
            }]
        }
        if self.password:
            data["password"] = self.password

        return data

    def _parse_json(self, json_response):
        self._data = DotDict(json_response)
        if not self._data.id:  # TODO: should we also accept a non-persisted API JSON object?
            print("WARN: Missing User ID in parsed data")
            self._data.id = None
            # raise Exception(
            #     "JSON response should have id property for a valid Resource")
        for key, val in self._data.items():
            self.__setattr__(key, val)

    def save(self):
        response = requests.post(self.api_endpoint,
                                 json=self.serialize(),
                                 auth=self.auth,
                                 verify=self.verify)
        if not response.ok:
            print(response.content)
            print(response.status_code)
            raise Exception("An error occurred while creating a user: " +
                            str(response.status_code))
        self._parse_json(response.json())
        return self

    def delete(self):
        response = requests.delete("{}/{}".format(self.api_endpoint, self.id),
                                   auth=self.auth,
                                   verify=self.verify)
        if not response.ok:
            print(response.content)
            print(response.status_code)
            raise Exception("An error occurred while deleting the user: " +
                            str(response.status_code))

    @staticmethod
    def all():
        api_endpoint = "https://{host}:{port}/api/identity/scim2/v1.0/Users".format(
            host=default['connection']['hostname'],
            port=default['connection']['port'])
        response = requests.get(api_endpoint,
                                auth=HTTPBasicAuth(
                                    default["admin"]["name"],
                                    default["admin"]["password"]),
                                verify=False)
        if not response.ok:
            print(response.content)
            print(response.status_code)
            raise Exception("An error occurred while getting all users: " +
                            str(response.status_code))
        return [User(**user) for user in response.json()['Resources']]

    @staticmethod
    def delete_all():
        for user in User.all():
            if user.userName.lower() == 'admin':
                print("WARN: Skip deleting admin user")
                continue
            print("DEBUG: deleting user {}".format(user.userName))
            user.delete()
Example #15
simple = DotDict({
    "data_root": "data/_exp_",
    "hdf_file": "data/_exp_/sp_data.hdf",
    "tensorboard_log": "data/tboard/_exp_",
    "game": {
        "clazz": BoxesState,
        "init": partial(BoxesState.init_static_fields, ((3, 3),)),
    },
    "self_play": {
        "num_games": 2000,
        "n_workers": 20,
        "games_per_workers": 25,
        "reuse_mcts_tree": True,
        "noise": (0.8, 0.25),  # alpha, coeff
        "nn_batch_size": 48,
        "nn_batch_timeout": 0.05,
        "nn_batch_builder": nn_batch_builder,
        "pytorch_devices": ["cuda:1", "cuda:0"],  # get_cuda_devices_list(),
        "mcts": {
            "mcts_num_read": 800,
            "mcts_cpuct": (1.25, 19652),  # CPUCT, CPUCT_BASE
            "temperature": {0: 1.0, 12: 0.02},
            "max_async_searches": 64,
        }
    },
    "elo": {
        "hdf_file": "data/_exp_/elo_data.hdf",
        "n_games": 20,
        "n_workers": 10,
        "games_per_workers": 2,
        "self_play_override": {
            "reuse_mcts_tree": False,
            "noise": (0.0, 0.0),
            "mcts": {
                "mcts_num_read": 1200
            }
        }
    },
    "nn": {
        "model_class": SimpleNN,
        "pytorch_device": "cuda:0",
        "chkpts_filename": "data/_exp_/model_gen{}.pt",
        "train_params": {
            "pos_average": True,
            "symmetries": SymmetriesGenerator(),
            "nb_epochs": 10,
            "max_samples_per_gen":100*4096,  # approx samples for 10 generations
            "train_split": 0.9,
            "train_batch_size": 4096,
            "val_batch_size": 4096,
            "lr_scheduler": GenerationLrScheduler({0: 1e-2, 20: 1e-3, 50: 1e-4}),
            "lr": 1e-2,
            "optimizer_params": {
                "momentum": 0.9,
                "weight_decay": 1e-4,
            }
        },
        "model_parameters": None
    }
})
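
Assuming DotDict wraps nested dicts recursively, the values above then read and write naturally as attribute chains, for example:

print(simple.self_play.num_games)   # 2000
print(simple.nn.train_params.lr)    # 0.01
simple.elo.n_games = 40             # attribute-style write into a nested dict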
Example #16
def train_net(params):
    # Initialize Parameters
    params = DotDict(params)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    verbose = {
        key: []
        for key in ('loss_train', 'loss_valid', 'psnr_train', 'psnr_valid',
                    'ssim_train', 'ssim_valid', 'vif_train', 'vif_valid')
    }

    log_metrics = True
    ssim_module = SSIM()
    msssim_module = MSSSIM()
    vifLoss = VIFLoss(sigma_n_sq=0.4, data_range=1.)
    msssimLoss = MultiScaleSSIMLoss(data_range=1.)
    best_validation_metrics = 100

    train_generator, val_generator = data_loaders(params)
    loaders = {"train": train_generator, "valid": val_generator}

    wnet_identifier = params.mask_URate[0:2] + "WNet_dense=" + str(int(params.dense)) + "_" + params.architecture + "_" \
                      + params.lossFunction + '_lr=' + str(params.lr) + '_ep=' + str(params.num_epochs) + '_complex=' \
                      + str(int(params.complex_net)) + '_' + 'edgeModel=' + str(int(params.edge_model)) \
                      + '(' + str(params.num_edge_slices) + ')_date=' + (datetime.now()).strftime("%d-%m-%Y_%H-%M-%S")

    if not os.path.isdir(params.model_save_path):
        os.mkdir(params.model_save_path)
    print("\n\nModel will be saved at:\n", params.model_save_path)
    print("WNet ID: ", wnet_identifier)

    wnet, optimizer, best_validation_loss, preTrainedEpochs = generate_model(
        params, device)

    # data = (iter(train_generator)).next()

    # Adding writer for tensorboard. Also start tensorboard, which tries to access logs in the runs directory
    writer = init_tensorboard(iter(train_generator), wnet, wnet_identifier,
                              device)

    for epoch in trange(preTrainedEpochs, params.num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train':
                wnet.train()
            else:
                wnet.eval()

            for i, data in enumerate(loaders[phase]):

                # for i in range(10000):
                x, y_true, _, _, fname, slice_num = data
                x, y_true = x.to(device, dtype=torch.float), y_true.to(
                    device, dtype=torch.float)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    y_pred = wnet(x)
                    if params.lossFunction == 'mse':
                        loss = F.mse_loss(y_pred, y_true)
                    elif params.lossFunction == 'l1':
                        loss = F.l1_loss(y_pred, y_true)
                    elif params.lossFunction == 'ssim':
                        # standard SSIM
                        loss = 0.16 * F.l1_loss(y_pred, y_true) + 0.84 * (
                            1 - ssim_module(y_pred, y_true))
                    elif params.lossFunction == 'msssim':
                        # loss = 0.16 * F.l1_loss(y_pred, y_true) + 0.84 * (1 - msssim_module(y_pred, y_true))
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2]))
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2]))
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = msssimLoss(prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2]))
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2]))
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = vifLoss(prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'mse+vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2])).to(device)
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2])).to(device)
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = 0.15 * F.mse_loss(
                            prediction_abs_flat,
                            target_abs_flat) + 0.85 * vifLoss(
                                prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'l1+vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2])).to(device)
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2])).to(device)
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = 0.146 * F.l1_loss(
                            y_pred, y_true) + 0.854 * vifLoss(
                                prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'msssim+vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2])).to(device)
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2])).to(device)
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = 0.66 * msssimLoss(
                            prediction_abs_flat,
                            target_abs_flat) + 0.33 * vifLoss(
                                prediction_abs_flat, target_abs_flat)

                    # avoid NaN/spike values
                    if not math.isnan(loss.item()) and loss.item() < 2 * best_validation_loss:
                        verbose['loss_' + phase].append(loss.item())
                        writer.add_scalar(
                            'Loss/' + phase + '_epoch_' + str(epoch),
                            loss.item(), i)

                    if log_metrics and (
                        (i % params.verbose_gap == 0) or
                        (phase == 'valid' and epoch > params.verbose_delay)):
                        y_true_copy = y_true.detach().cpu().numpy()
                        y_pred_copy = y_pred.detach().cpu().numpy()
                        y_true_copy = y_true_copy[:, ::2, :, :] + 1j * y_true_copy[:, 1::2, :, :]
                        y_pred_copy = y_pred_copy[:, ::2, :, :] + 1j * y_pred_copy[:, 1::2, :, :]
                        if params.architecture[-1] == 'k':
                            # transform kspace to image domain
                            y_true_copy = np.fft.ifft2(y_true_copy,
                                                       axes=(2, 3))
                            y_pred_copy = np.fft.ifft2(y_pred_copy,
                                                       axes=(2, 3))

                        # Sum of squares
                        sos_true = np.sqrt(
                            (np.abs(y_true_copy)**2).sum(axis=1))
                        sos_pred = np.sqrt(
                            (np.abs(y_pred_copy)**2).sum(axis=1))
                        '''
                        # Normalization according to: extract_challenge_metrics.ipynb
                        sos_true_max = sos_true.max(axis = (1,2),keepdims = True)
                        sos_true_org = sos_true/sos_true_max
                        sos_pred_org = sos_pred/sos_true_max
                        # Normalization by normalzing with ref with max_ref and rec with max_rec, respectively
                        sos_true_max = sos_true.max(axis = (1,2),keepdims = True)
                        sos_true_mod = sos_true/sos_true_max
                        sos_pred_max = sos_pred.max(axis = (1,2),keepdims = True)
                        sos_pred_mod = sos_pred/sos_pred_max
                        '''
                        '''
                        # normalization by mean and std
                        std = sos_pred.std(axis=(1, 2), keepdims=True)
                        mean = sos_pred.mean(axis=(1, 2), keepdims=True)
                        sos_pred_std = (sos_pred-mean) / std
                        std = sos_true.std(axis=(1, 2), keepdims=True)
                        mean = sos_pred.mean(axis=(1, 2), keepdims=True)
                        sos_true_std = (sos_true-mean) / std
                        '''
                        '''
                        ssim, psnr, vif = metrics(sos_pred_org, sos_true_org)
                        ssim_mod, psnr_mod, vif_mod = metrics(sos_pred_mod, sos_true_mod)
                        '''
                        sos_true_max = sos_true.max(axis=(1, 2), keepdims=True)
                        sos_true_org = sos_true / sos_true_max
                        sos_pred_org = sos_pred / sos_true_max

                        ssim, psnr, vif = metrics(sos_pred, sos_true)
                        ssim_normed, psnr_normed, vif_normed = metrics(
                            sos_pred_org, sos_true_org)

                        verbose['ssim_' + phase].append(np.mean(ssim_normed))
                        verbose['psnr_' + phase].append(np.mean(psnr_normed))
                        verbose['vif_' + phase].append(np.mean(vif_normed))
                        '''
                        print("===Normalization according to: extract_challenge_metrics.ipynb===")
                        print("SSIM: ", verbose['ssim_'+phase][-1])
                        print("PSNR: ", verbose['psnr_'+phase][-1])
                        print("VIF: ",  verbose['vif_' +phase][-1])
                        print("===Normalization by normalzing with ref with max_ref and rec with max_rec, respectively===")
                        print("SSIM_mod: ", np.mean(ssim_mod))
                        print("PSNR_mod: ", np.mean(psnr_mod))
                        print("VIF_mod: ",  np.mean(vif_mod))
                        print("===Normalization by dividing by the standard deviation of ref and rec, respectively===")
                        '''
                        print("Epoch: ", epoch)
                        print("SSIM: ", np.mean(ssim))
                        print("PSNR: ", np.mean(psnr))
                        print("VIF: ", np.mean(vif))

                        print("SSIM_normed: ", verbose['ssim_' + phase][-1])
                        print("PSNR_normed: ", verbose['psnr_' + phase][-1])
                        print("VIF_normed: ", verbose['vif_' + phase][-1])
                        '''
                        if True: #verbose['vif_' + phase][-1] < 0.4:
                            plt.figure(figsize=(9, 6), dpi=150)
                            gs1 = gridspec.GridSpec(3, 2)
                            gs1.update(wspace=0.002, hspace=0.1)
                            plt.subplot(gs1[0])
                            plt.imshow(sos_true[0], cmap="gray")
                            plt.axis("off")
                            plt.subplot(gs1[1])
                            plt.imshow(sos_pred[0], cmap="gray")
                            plt.axis("off")
                            plt.show()
                            # plt.pause(10)
                            # plt.close()
                        '''
                        writer.add_scalar(
                            'SSIM/' + phase + '_epoch_' + str(epoch),
                            verbose['ssim_' + phase][-1], i)
                        writer.add_scalar(
                            'PSNR/' + phase + '_epoch_' + str(epoch),
                            verbose['psnr_' + phase][-1], i)
                        writer.add_scalar(
                            'VIF/' + phase + '_epoch_' + str(epoch),
                            verbose['vif_' + phase][-1], i)

                    print('Loss ' + phase + ': ', loss.item())

                    if phase == 'train':
                        if loss.item() < 2 * best_validation_loss:
                            loss.backward()
                            optimizer.step()

        # Calculate Averages
        psnr_mean = np.mean(verbose['psnr_valid'])
        ssim_mean = np.mean(verbose['ssim_valid'])
        vif_mean = np.mean(verbose['vif_valid'])
        validation_metrics = 0.2 * psnr_mean + 0.4 * ssim_mean + 0.4 * vif_mean

        valid_avg_loss_of_current_epoch = np.mean(verbose['loss_valid'])
        writer.add_scalar('AvgLoss/+train_epoch_' + str(epoch),
                          np.mean(verbose['loss_train']), epoch)
        writer.add_scalar('AvgLoss/+valid_epoch_' + str(epoch),
                          np.mean(verbose['loss_valid']), epoch)
        writer.add_scalar('AvgSSIM/+train_epoch_' + str(epoch),
                          np.mean(verbose['ssim_train']), epoch)
        writer.add_scalar('AvgSSIM/+valid_epoch_' + str(epoch), ssim_mean,
                          epoch)
        writer.add_scalar('AvgPSNR/+train_epoch_' + str(epoch),
                          np.mean(verbose['psnr_train']), epoch)
        writer.add_scalar('AvgPSNR/+valid_epoch_' + str(epoch), psnr_mean,
                          epoch)
        writer.add_scalar('AvgVIF/+train_epoch_' + str(epoch),
                          np.mean(verbose['vif_train']), epoch)
        writer.add_scalar('AvgVIF/+valid_epoch_' + str(epoch), vif_mean, epoch)

        # reset per-epoch metric lists
        for key in verbose:
            verbose[key] = []

        # Save Networks/Checkpoints
        is_best = best_validation_metrics > validation_metrics
        if is_best:
            best_validation_metrics = validation_metrics
            best_validation_loss = valid_avg_loss_of_current_epoch
        save_checkpoint(
            wnet, params.model_save_path, wnet_identifier, {
                'epoch': epoch + 1,
                'state_dict': wnet.state_dict(),
                'best_validation_loss': best_validation_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best)
Example #17
 def parse(self, value):
     return DotDict(json.loads(value))
Example #18
                    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 3))
                    axes[0].imshow(gt.numpy(), cmap="gray")
                    axes[0].axis('off')
                    axes[1].imshow(test_image, cmap="gray")
                    axes[1].axis('off')
                    plt.savefig('img' + str(slice_idx) + str(fname[0]) +
                                '.png',
                                dpi=100)
                    print('saved slice ' + str(slice_idx) + ' to image')
                elif save_slice:
                    plt.imshow(test_image, cmap="gray")
                    plt.axis('off')
                    plt.savefig('img' + str(slice_idx) + str(fname[0]) +
                                '.png',
                                dpi=100)
                    print('saved slice ' + str(slice_idx) + ' to image')

                f.close()
                print('saved reconstruction to ' + str(fname[0]))
                tmp = torch.Tensor()


if __name__ == "__main__":
    params = config_wnet.config
    params['batch_size'] = 1
    params['slice_cut_val'] = (50, 50)
    params['norm'] = True
    params['architecture'] = 'iiiiii'
    net = r"..\BestVal=True_R5WNet_dense=1_iiiiii_msssim+vif_lr=0.0005_ep=50_complex=1_edgeModel=0(0)_date=24-07-2020_13-28-38.pth"
    run(DotDict(params), net, val=False)  # set val to True to use Val Set
Example #19
def initopts():
    o = DotDict()
    o.stopwords_file = ""
    o.remove_puncuation = False
    o.remove_stop_words = False
    o.lemmatize_words = False
    o.num_replacement = "[NUM]"
    o.to_lowercase = False
    o.replace_nums = False  # numbers are important, since a rumour may lie about a count
    o.eos = "[EOS]"
    o.add_eos = True
    o.returnNERvector = True
    o.returnDEPvector = True
    o.returnbiglettervector = True
    o.returnposvector = True
    return o
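
A typical way to consume such an options factory is to take the defaults and override individual flags per run, e.g. (the file name is illustrative):

opts = initopts()
opts.remove_stop_words = True
opts.stopwords_file = "stopwords.txt"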
Example #20
class Endpoint(Resource):
    CONST = DotDict(endpoint_const)
    RESOURCE = "/endpoints"

    def __init__(self, name, type, service_url='', maxTps=10, **kwargs):
        super().__init__()
        if type not in ['http', 'https']:
            raise Exception("endpoint_type should be either http or https")

        self.endpointConfig = {
            "endpointType": "SINGLE",
            "list": [{
                "url": service_url,
                "timeout": "1000",
                "attributes": []
            }]
        }
        self.endpointSecurity = {'enabled': False}
        self.name = name
        self.max_tps = maxTps
        self.type = type
        if kwargs:
            self._parse_json(kwargs)

    def is_secured(self):
        return self.endpointSecurity['enabled']

    def set_rest_client(self, client=RestClient()):
        self.client = client

    def set_security(self, security_type, username, password):
        if security_type not in Endpoint.CONST.SECURITY_TYPES.values():
            raise Exception(
                "Invalid security type, please provide one of the following: {}".
                format(Endpoint.CONST.SECURITY_TYPES))
        self.endpointSecurity = {
            'enabled': True,
            'type': security_type,
            'username': username,
            'password': password
        }
        return self.endpointSecurity

    @property
    def _parsers(self):
        parsers = {
            # 'endpointConfig': self._parse_endpointConfig
        }
        return parsers

    def _parse_endpointConfig(self, value):
        endpoint_config_json = json.loads(value)
        self.endpointConfig = DotDict(endpoint_config_json)

    @property
    def service_url(self):
        return self.endpointConfig['list'][0]['url']

    @service_url.setter
    def service_url(self, value):
        self.endpointConfig['list'][0]['url'] = value

    def save(self):
        headers = {
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        }
        data = self.to_json()
        res = self.client.session.post(Endpoint.get_endpoint(),
                                       data=json.dumps(data),
                                       verify=self.client.verify,
                                       headers=headers)
        print("Status code: {}".format(res.status_code))
        if res.status_code != 201:
            print(res.json())
            raise Exception("Fail to save the global endpoint via REST API")
        self._data = DotDict(res.json())
        self.id = self._data.id
        return self

    def delete(self):
        res = self.client.session.delete(Endpoint.get_endpoint() +
                                         "/{}".format(self.id),
                                         verify=self.client.verify)
        if res.status_code != 200:
            print("Warning Error while deleting the API {}\nERROR: {}".format(
                self.name, res.content))
        print("Status code: {}".format(res.status_code))

    def to_json(self):
        temp = {
            'name': self.name,
            'endpointConfig': self.endpointConfig,
            'endpointSecurity': self.endpointSecurity,
            'maxTps': self.max_tps,
            'type': self.type
        }
        return temp

    @staticmethod
    def delete_all(client=RestClient()):
        res = client.session.get(Endpoint.get_endpoint(), verify=client.verify)
        if not res.status_code == 200:
            print(res.json())
            raise Exception("Error getting APIs list")
        print("Status code: {}".format(res.status_code))
        apis_list = res.json()['list']
        for api in apis_list:
            res = client.session.delete(Endpoint.get_endpoint() +
                                        "/{}".format(api['id']),
                                        verify=client.verify)
            if res.status_code != 200:
                print("Warning Error while deleting the API {}\nERROR: {}".
                      format(api['name'], res.content))
            print("Status code: {}".format(res.status_code))

    @staticmethod
    def all():
        raise NotImplementedError("Not implemented yet ...")
Example #21
 def _parse_endpointConfig(self, value):
     endpoint_config_json = json.loads(value)
     self.endpointConfig = DotDict(endpoint_config_json)