def _parse_json(self, json_response):
    """Load *json_response* into ``self._data`` and mirror every key as an attribute.

    A falsy ``id`` is not rejected: a warning is printed and ``id`` is
    normalised to None, so partially-populated objects remain usable.
    """
    parsed = DotDict(json_response)
    self._data = parsed
    if not parsed.id:
        # TODO: Should be able to pass non-preserved api JSON object as well ?
        print("WARN: Missing User ID for parsing data")
        parsed.id = None
    for attr_name, attr_value in parsed.items():
        setattr(self, attr_name, attr_value)
def __init__(self, hp, net_arch, loss_f, rank=0, world_size=1):
    """Prepare the network, optimizer and bookkeeping state for training.

    Args:
        hp: hyper-parameter tree; reads ``hp.model.device`` and
            ``hp.train.optimizer`` (its ``mode`` plus the kwargs stored under
            that mode's key).
        net_arch: the network module; it is moved to ``hp.model.device``.
        loss_f: loss callable stored as ``self.loss_f``.
        rank: process rank, used as the device id when wrapping with DDP.
        world_size: number of processes; 0 disables the DDP wrap.

    Raises:
        Exception: if the optimizer mode is anything other than "adam".
    """
    self.hp = hp
    self.device = hp.model.device
    self.net = net_arch.to(self.device)
    self.rank = rank
    self.world_size = world_size
    # Wrap in DDP (presumably torch DistributedDataParallel) for non-CPU runs.
    wrap_in_ddp = self.device != "cpu" and self.world_size != 0
    if wrap_in_ddp:
        self.net = DDP(self.net, device_ids=[self.rank])
    self.input = None
    self.GT = None
    self.step = 0
    self.epoch = -1

    # init optimizer: only Adam is supported
    optimizer_cfg = hp.train.optimizer
    optimizer_mode = optimizer_cfg.mode
    if optimizer_mode != "adam":
        raise Exception("%s optimizer not supported" % optimizer_mode)
    self.optimizer = torch.optim.Adam(
        self.net.parameters(), **(optimizer_cfg[optimizer_mode]))

    # init loss
    self.loss_f = loss_f
    self.log = DotDict()
class Resource(object):
    """Parent for API and Endpoint resources; implement generalized delete-type methods here.

    Subclasses choose their REST path by overriding ``APPLICATION`` and
    ``RESOURCE`` (combined by get_endpoint()) and may override ``_parsers``
    to special-case individual JSON keys in _parse_json().
    """
    # Host and port are read from default['connection'] once, when the class is defined.
    BASE = "https://{host}:{port}/api/am/".format(
        host=default['connection']['hostname'],
        port=default['connection']['port'])
    VERSION = 1.0
    APPLICATION = "publisher"
    RESOURCE = ''

    def __init__(self):
        self._data = None  # raw Resource data
        self.id = None  # Resource ID, None until persist via REST API, call save() will set ID
        self.client = None

    def delete(self):
        pass

    def _parse_json(self, json_response):
        """Store *json_response* in ``self._data`` and copy its keys onto the instance.

        Keys present in ``self._parsers`` are passed to the matching callable
        instead of being assigned directly. A falsy ``id`` is normalised to
        None after printing a warning.
        """
        self._data = DotDict(json_response)
        if not self._data.id:
            # TODO: Should be able to pass non-preserved api JSON object as well ?
            print("WARN: Missing API ID for parsing data")
            self._data.id = None
        parsers = self._parsers  # subclasses build this dict per access; fetch it once
        for key, val in self._data.items():
            if key in parsers:
                parsers[key](val)
                continue
            self.__setattr__(key, val)

    @property
    def _parsers(self):
        """Map JSON keys to callables that need special parsing.

        Keys not listed are simply assigned as attributes by _parse_json().
        The base class needs no special parsing; subclasses override this.
        (Previously this did ``raise NotImplemented()``, which is itself a
        TypeError and broke any subclass that did not override it.)
        """
        return {}

    @classmethod
    def get_endpoint(cls):
        """Return the collection URL: BASE + '<APPLICATION>/v<VERSION>' + RESOURCE."""
        return cls.BASE + "{application}/v{api_version}".format(
            api_version=cls.VERSION, application=cls.APPLICATION) + cls.RESOURCE

    @classmethod
    def get(cls, resource_id, client=RestClient()):
        """Fetch one resource by ID and build an instance from the response JSON.

        Note: the default RestClient() is created once, at class-definition
        time, and shared by every call that does not pass *client*.

        Raises:
            Exception: if the server does not answer 200 (building an instance
                from an error body would fail anyway).
        """
        res = client.session.get(cls.get_endpoint() + "/{}".format(resource_id),
                                 verify=client.verify)
        if res.status_code != 200:
            print("Warning Error while getting the Resource {}\nERROR: {}".
                  format(resource_id, res.content))
            print("Status code: {}".format(res.status_code))
            raise Exception("Error while getting the Resource {} CODE: {}".format(
                resource_id, res.status_code))
        return cls(**res.json())

    @staticmethod
    def all():
        pass

    @staticmethod
    def delete_all(client=None):
        """No-op in the base class; subclasses override it to delete all their resources."""
        pass
def get_rotated_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, task_num=10, *args, **kwargs):
    """Return the Rotated-MNIST continual-learning loader for one task of one split.

    The train/val/test loaders are built on the first call and cached in the
    module-level ``_rmnist_loaders``. Later calls reuse the cache, so
    ``batch_size`` and the FUZZY flag only take effect on the first call.

    Args:
        cfg: config queried for ``EXTERNAL.OCL.FUZZY``; when truthy,
            ``FuzzyCLDataLoader`` is used instead of ``CLDataLoader``.
        split: one of 'train', 'val' or 'test'.
        filter_obj: sequence whose first element indexes the split's loader.
        batch_size: batch size used when the loaders are first built.
        task_num: unused; kept for interface compatibility.

    Returns:
        ``loader[filter_obj[0]]`` for the requested split.

    Raises:
        ValueError: if *split* is not 'train', 'val' or 'test' (previously
            this silently returned None).
        TypeError: if *filter_obj* is None (checked before loading any data).
    """
    global _rmnist_loaders
    split_position = {'train': 0, 'val': 1, 'test': 2}
    if split not in split_position:
        raise ValueError(
            "split must be one of 'train', 'val' or 'test', got {!r}".format(split))
    if filter_obj is None:
        raise TypeError("filter_obj must be a sequence whose first element selects the task")
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    if not _rmnist_loaders:
        data = get_rotated_mnist(d)
        loader_cls = FuzzyCLDataLoader if fuzzy else CLDataLoader
        train_loader, val_loader, test_loader = [
            loader_cls(elem, batch_size, train=t)
            for elem, t in zip(data, [True, False, False])]
        _rmnist_loaders = train_loader, val_loader, test_loader
    return _rmnist_loaders[split_position[split]][filter_obj[0]]
class Application(Resource):
    """Store-side Application resource, persisted via the /applications REST endpoint."""
    CONST = DotDict(application_const)
    RESOURCE = "/applications"
    APPLICATION = "store"

    def __init__(self, name, description, throttlingTier='', **kwargs):
        """Create an application; extra kwargs (e.g. server JSON) are parsed onto the instance."""
        super().__init__()
        self.name = name
        self.description = description
        self.throttlingTier = throttlingTier
        self.subscriber = None
        self.permission = "[{\"groupId\" : 1000, \"permission\" : [\"READ\",\"UPDATE\"]},{\"groupId\" : 1001, \"permission\" : [\"READ\",\"UPDATE\"]}]"
        self.lifeCycleStatus = None
        self.keys = []
        if kwargs:
            self._parse_json(kwargs)
            # Applications are identified by 'applicationId' (save() reads it the
            # same way); mirror it into self.id when no 'id' was supplied.
            if self.id is None and kwargs.get('applicationId'):
                self.id = kwargs['applicationId']

    @property
    def _parsers(self):
        """No JSON keys need special parsing; every key becomes an attribute."""
        return {}

    def save(self):
        """POST this application and record the server-assigned ``applicationId`` as ``id``.

        Raises:
            Exception: if the server does not answer 201.
        """
        headers = {
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        }
        data = self.to_json()
        res = self.client.session.post(Application.get_endpoint(), data=json.dumps(data),
                                       verify=self.client.verify, headers=headers)
        print("Status code: {}".format(res.status_code))
        if res.status_code != 201:
            # Raw body: an error response is not guaranteed to be JSON.
            print(res.content)
            raise Exception("Fail to save the application via REST API")
        self._data = DotDict(res.json())
        self.id = self._data.applicationId
        return self

    @staticmethod
    def delete_all(client=RestClient()):
        """Delete every application listed by the store API, warning on individual failures.

        NOTE(review): this default client is RestClient() while set_rest_client()
        uses RestClient(APPLICATION) — confirm whether the store scope is needed here.
        """
        res = client.session.get(Application.get_endpoint(), verify=client.verify)
        if not res.status_code == 200:
            print(res.json())
            raise Exception("Error getting Applications list")
        print("Status code: {}".format(res.status_code))
        apps_list = res.json()['list']
        for app in apps_list:
            res = client.session.delete(Application.get_endpoint() + "/{}".format(app['applicationId']),
                                        verify=client.verify)
            if res.status_code != 200:
                print("Warning Error while deleting the Application {}\nERROR: {}".format(
                    app['name'], res.content))
                print("Status code: {}".format(res.status_code))

    # TODO: Make this a generic to_json method for all Resources
    def to_json(self):
        """Return the JSON-serialisable payload sent by save()."""
        temp = {
            'name': self.name,
            'description': self.description,
            'throttlingTier': self.throttlingTier
        }
        return temp

    def set_rest_client(self, client=RestClient(APPLICATION)):
        self.client = client
def __init__(self, name, version, context, service_url=None, **kwargs):
    """Initialise an API resource.

    When *service_url* is given, inline 'http' endpoints named
    ``<type>_<name>`` are created for the sandbox and production types, both
    pointing at *service_url*. Any extra keyword arguments are treated as
    server JSON and parsed via _parse_json().
    """
    super().__init__()
    self.client = None  # REST API Client
    self.id = None
    self._data = None
    self.name = name
    self.context = context
    self.version = version
    self.endpoint = DotDict()
    if service_url:
        endpoint_types = Endpoint.CONST.TYPES
        for ep_type in (endpoint_types.SANDBOX, endpoint_types.PRODUCTION):
            self.endpoint[ep_type] = Endpoint(
                "{}_{}".format(ep_type, name), 'http', service_url)
    if kwargs:
        self._parse_json(kwargs)
def save(self):
    """POST this application to the store REST API and record its new ID.

    Returns:
        self, with ``_data`` set to the response body and ``id`` to its
        ``applicationId``.

    Raises:
        Exception: if the server does not answer 201.
    """
    headers = {
        'Accept': 'application/json',
        'Content-Type': 'application/json'
    }
    data = self.to_json()
    res = self.client.session.post(Application.get_endpoint(), data=json.dumps(data),
                                   verify=self.client.verify, headers=headers)
    print("Status code: {}".format(res.status_code))
    if res.status_code != 201:
        # Raw body: an error response is not guaranteed to be JSON.
        print(res.content)
        raise Exception("Fail to save the application via REST API")
    self._data = DotDict(res.json())
    self.id = self._data.applicationId
    return self
def __init__(self, hp, net_arch, loss_f):
    """Prepare the network, optimizer and bookkeeping state (single-process variant).

    Args:
        hp: hyper-parameter tree; reads ``hp.model.device`` and
            ``hp.train.optimizer`` (its ``mode`` plus the kwargs stored under
            that mode's key).
        net_arch: network module, moved to ``hp.model.device``.
        loss_f: loss callable stored as ``self.loss_f``.

    Raises:
        Exception: if the optimizer mode is anything other than "adam".
    """
    self.hp = hp
    self.device = hp.model.device
    self.net = net_arch.to(self.device)
    self.input = None
    self.GT = None
    self.step = 0
    self.epoch = -1

    # init optimizer: only Adam is supported
    optimizer_cfg = hp.train.optimizer
    optimizer_mode = optimizer_cfg.mode
    if optimizer_mode != "adam":
        raise Exception("%s optimizer not supported" % optimizer_mode)
    adam_kwargs = optimizer_cfg[optimizer_mode]
    self.optimizer = torch.optim.Adam(self.net.parameters(), **adam_kwargs)

    # init loss
    self.loss_f = loss_f
    self.log = DotDict()
def get_split_mini_imagenet_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
    """Return the split Mini-ImageNet continual-learning loader for one task of one split.

    The train/val/test loaders are built on the first call and cached in the
    module-level ``_cache_mini_imagenet``. Later calls reuse the cache, so
    ``batch_size`` and the FUZZY flag only take effect on the first call.

    Args:
        cfg: config queried for ``EXTERNAL.OCL.FUZZY``; when truthy,
            ``FuzzyCLDataLoader`` is used instead of ``CLDataLoader``.
        split: one of 'train', 'val' or 'test'.
        filter_obj: sequence whose first element indexes the split's loader.
        batch_size: batch size used when the loaders are first built.

    Returns:
        ``loader[filter_obj[0]]`` for the requested split.

    Raises:
        ValueError: if *split* is not 'train', 'val' or 'test' (previously
            this silently returned None).
        TypeError: if *filter_obj* is None (checked before loading any data).
    """
    global _cache_mini_imagenet
    split_position = {'train': 0, 'val': 1, 'test': 2}
    if split not in split_position:
        raise ValueError(
            "split must be one of 'train', 'val' or 'test', got {!r}".format(split))
    if filter_obj is None:
        raise TypeError("filter_obj must be a sequence whose first element selects the task")
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    if not _cache_mini_imagenet:
        data = get_miniimagenet(d)
        loader_cls = FuzzyCLDataLoader if fuzzy else CLDataLoader
        train_loader, val_loader, test_loader = [
            loader_cls(elem, batch_size, train=t)
            for elem, t in zip(data, [True, False, False])]
        _cache_mini_imagenet = train_loader, val_loader, test_loader
    return _cache_mini_imagenet[split_position[split]][filter_obj[0]]
def get_split_cifar100_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
    """Return the split CIFAR-100 continual-learning loader for one task of one split.

    The train/val/test loaders are built on the first call and cached in the
    module-level ``_cache_cifar100``. Later calls reuse the cache, so
    ``batch_size`` and the FUZZY flag only take effect on the first call.

    Args:
        cfg: config passed to ``get_split_cifar100`` and queried for
            ``EXTERNAL.OCL.FUZZY``; when truthy, ``FuzzyCLDataLoader`` is used
            instead of ``CLDataLoader``.
        split: one of 'train', 'val' or 'test'.
        filter_obj: sequence whose first element indexes the split's loader.
        batch_size: batch size used when the loaders are first built.

    Returns:
        ``loader[filter_obj[0]]`` for the requested split.

    Raises:
        ValueError: if *split* is not 'train', 'val' or 'test' (previously
            this silently returned None).
        TypeError: if *filter_obj* is None (checked before loading any data).
    """
    global _cache_cifar100
    split_position = {'train': 0, 'val': 1, 'test': 2}
    if split not in split_position:
        raise ValueError(
            "split must be one of 'train', 'val' or 'test', got {!r}".format(split))
    if filter_obj is None:
        raise TypeError("filter_obj must be a sequence whose first element selects the task")
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    if not _cache_cifar100:
        data = get_split_cifar100(d, cfg)
        loader_cls = FuzzyCLDataLoader if fuzzy else CLDataLoader
        train_loader, val_loader, test_loader = [
            loader_cls(elem, batch_size, train=t)
            for elem, t in zip(data, [True, False, False])]
        _cache_cifar100 = train_loader, val_loader, test_loader
    return _cache_cifar100[split_position[split]][filter_obj[0]]
class API(Resource):
    """Publisher-side API resource, persisted via the /apis REST endpoint.

    Endpoints are kept in ``self.endpoint`` keyed by endpoint type. Changes
    made through set_endpoint() / set_policies() are PUT to the server
    immediately, so those require an already-saved API (``self.id`` set).
    """
    RESOURCE = "/apis"

    def __init__(self, name, version, context, service_url=None, **kwargs):
        """Create an API.

        If *service_url* is given, inline 'http' sandbox and production
        endpoints pointing at it are added. Extra kwargs (server JSON) are
        parsed via _parse_json().
        """
        super().__init__()
        self.client = None  # REST API Client
        self.id = None
        self._data = None
        self.name = name
        self.context = context
        self.version = version
        self.endpoint = DotDict()
        if service_url:
            inline_endpoint_name = "{}_{}".format(Endpoint.CONST.TYPES.SANDBOX, name)
            self.endpoint[Endpoint.CONST.TYPES.SANDBOX] = Endpoint(
                inline_endpoint_name, 'http', service_url)
            inline_endpoint_name = "{}_{}".format(
                Endpoint.CONST.TYPES.PRODUCTION, name)
            self.endpoint[Endpoint.CONST.TYPES.PRODUCTION] = Endpoint(
                inline_endpoint_name, 'http', service_url)
        if kwargs:
            self._parse_json(kwargs)

    def set_rest_client(self, client=RestClient()):
        self.client = client

    def save(self):
        """POST the API unless it already has an ID; parse the created resource back in.

        Raises:
            Exception: if the server does not answer 201.
        """
        if self.id:
            print("WARN: API is already persist")
            return self
        res = self.client.session.post(API.get_endpoint(), json=self.to_json(),
                                       verify=self.client.verify)
        if res.status_code != 201:
            print(res.content)
            print(res.status_code)
            raise Exception("An error occurred while creating an API CODE: " +
                            str(res.status_code))
        self._parse_json(res.json())
        print("Status code: {}".format(res.status_code))
        return self

    def to_json(self):
        """Build the request payload.

        The payload holds name/version/context, every attribute whose key came
        from the parsed server JSON, and the serialised endpoint mapping.
        Global endpoints (those with an id) are referenced by 'key'; the rest
        are sent 'inline'.
        """
        temp = {
            'name': self.name,
            'version': self.version,
            'context': self.context,
        }
        # _data is None until _parse_json() has run (e.g. a freshly built API).
        keys = self._data.keys() if self._data is not None else ()
        for key, value in self.__dict__.items():
            if key in keys:
                temp[key] = value
        endpoints = []
        for epType, endpoint in self.endpoint.items():
            serialize_endpoint = {"type": epType}
            if endpoint.id:
                serialize_endpoint['key'] = endpoint.id
            else:
                serialize_endpoint['inline'] = endpoint.to_json()
            endpoints.append(serialize_endpoint)
        if endpoints:
            temp['endpoint'] = endpoints
        return temp

    def _parse_endpoint(self, endpoint_json):
        """Rebuild ``self.endpoint`` from the server's list of {'type', 'inline'|'key'} entries."""
        for endpoint in endpoint_json:
            endpoint_data = endpoint.get('inline') or endpoint.get('key')
            # NOTE(review): when only 'key' is present this is presumably an endpoint
            # ID string, and Endpoint(**<str>) raises TypeError — confirm payloads.
            self.endpoint[endpoint['type']] = Endpoint(**endpoint_data)

    def set_endpoint(self, endpoint_type, endpoint):
        """Attach *endpoint* under *endpoint_type* and PUT the updated API.

        Raises:
            Exception: for an unknown type, a non-Endpoint value, or an unsaved
                global (non-inline) endpoint.
        """
        if endpoint_type not in Endpoint.CONST.TYPES.values():
            raise Exception("Endpoint type should be one of these {}".format(
                Endpoint.CONST.TYPES.values()))
        if not isinstance(endpoint, Endpoint):
            raise Exception("endpoint should be an instance of Endpoint")
        if endpoint_type != Endpoint.CONST.TYPES.INLINE and not endpoint.id:
            raise Exception(
                "Global endpoint should have persist before mapping it to an API"
            )
        self.endpoint[endpoint_type] = endpoint
        self._data.endpoint = self.endpoint
        res = self.client.session.put(API.get_endpoint() + "/" + self.id,
                                      json=self.to_json(), verify=self.client.verify)
        if res.status_code != 200:
            print("Something went wrong when updating endpoints")
            print(res)
        return self

    def set_policies(self, policy_data):
        """Append one policy (or extend with a list of policies) and PUT the updated API."""
        # 'policies' only exists as an attribute once it has been parsed from server JSON.
        self.policies = getattr(self, 'policies', None) or []
        if type(policy_data) is list:
            self.policies.extend(policy_data)
        else:
            self.policies.append(policy_data)
        self._data["policies"] = self.policies
        res = self.client.session.put(API.get_endpoint() + "/" + self.id,
                                      json=self.to_json(), verify=self.client.verify)
        if res.status_code != 200:
            print("Something went wrong when updating policies")
            print(res)
        return self

    def delete(self):
        res = self.client.session.delete(API.get_endpoint() + "/{}".format(self.id),
                                         verify=self.client.verify)
        if res.status_code != 200:
            print("Warning Error while deleting the API {}\nERROR: {}".format(
                self.name, res.content))
            print("Status code: {}".format(res.status_code))

    def change_lc(self, state):
        """Move the API to lifecycle *state* if exactly one available transition targets it.

        Raises:
            Exception: if the lifecycle cannot be fetched or *state* is not a
                valid transition. (Previously these were ``raise "<str>"``,
                which itself fails with TypeError.)
        """
        api_id = self.id
        data = {"action": state, "apiId": api_id}
        lcs = self.client.session.get(
            API.get_endpoint() + "/{apiId}/lifecycle".format(apiId=api_id),
            verify=self.client.verify)
        if not lcs.ok:
            raise Exception("Can't find Lifecycle for the api {}".format(api_id))
        lcs = lcs.json()
        available_transitions = [
            transition for transition in lcs['availableTransitionBeanList']
            if transition['targetState'] == state
        ]
        if len(available_transitions) != 1:
            raise Exception("Invalid transition state valid ones are = {}".format(
                lcs['availableTransitionBeanList']))
        res = self.client.session.post(API.get_endpoint() + "/change-lifecycle",
                                       params=data, verify=self.client.verify)
        print("Status code: {}".format(res.status_code))

    @property
    def _parsers(self):
        parsers = {
            # if need special parsing other than simply assigning as attribute
            'endpoint': self._parse_endpoint
        }
        return parsers

    @staticmethod
    def delete_all(client=RestClient()):
        """Delete every API listed by the publisher, warning on individual failures."""
        res = client.session.get(API.get_endpoint(), verify=client.verify)
        if not res.status_code == 200:
            print(res.json())
            raise Exception("Error getting APIs list")
        print("Status code: {}".format(res.status_code))
        apis_list = res.json()['list']
        for api in apis_list:
            res = client.session.delete(API.get_endpoint() + "/{}".format(api['id']),
                                        verify=client.verify)
            if res.status_code != 200:
                print("Warning Error while deleting the API {}\nERROR: {}".
                      format(api['name'], res.content))
                print("Status code: {}".format(res.status_code))

    @staticmethod
    def all(client=RestClient()):
        """Return every API listed by the publisher as API instances."""
        res = client.session.get(API.get_endpoint(), verify=client.verify)
        if not res.status_code == 200:
            print(res.json())
            raise Exception("Error getting APIs list")
        print("Status code: {}".format(res.status_code))
        return [API(**api) for api in res.json()['list']]
def train_loop(rank, hp, world_size=1):
    """Run training for one process.

    Only rank 0 creates the logger and writer; other ranks pass None to the
    helpers. Any exception raised in the epoch loop is logged (rank 0),
    printed with its traceback and swallowed after cleanup().

    Args:
        rank: process index; also used as the CUDA device id in CUDA mode.
        hp: hyper-parameter mapping, re-wrapped as a DotDict.
        world_size: number of processes; 0 disables setup() and DDP.
    """
    # reload hp
    hp = DotDict(hp)
    if hp.model.device.lower() == "cuda" and world_size != 0:
        setup(hp, rank, world_size)

    logger = None
    writer = None
    if rank == 0:
        logger = make_logger(hp)
        # tensorboard / wandb writer
        writer = Writer(hp, hp.log.log_dir)
        config_dump = yaml.dump(hp.to_dict())
        logger.info("Config:")
        logger.info(config_dump)
        if hp.data.train_dir == "" or hp.data.test_dir == "":
            logger.error("train or test data directory cannot be empty.")
            raise Exception("Please specify directories of data")
        logger.info("Set up train process")

    # Re-checked here on purpose: setup() receives hp and could in principle modify it.
    if hp.model.device.lower() == "cuda" and world_size != 0:
        hp.model.device = rank
        torch.cuda.set_device(rank)
    else:
        hp.model.device = hp.model.device.lower()

    def log_info(message):
        # Non-zero ranks have no logger.
        if logger is not None:
            logger.info(message)

    log_info("Making train dataloader...")
    train_loader = create_dataloader(hp, DataloaderMode.train, rank, world_size)
    log_info("Making test dataloader...")
    test_loader = create_dataloader(hp, DataloaderMode.test, rank, world_size)

    model = Model(hp, Net_arch(hp), torch.nn.MSELoss(), rank, world_size)

    if hp.load.resume_state_path is None:
        log_info("Starting new training run.")
    else:
        model.load_training_state(logger)

    try:
        epoch_step = 1 if hp.data.divide_dataset_per_gpu else world_size
        for model.epoch in itertools.count(model.epoch + 1, epoch_step):
            if model.epoch > hp.train.num_iter:
                break
            train_model(hp, model, train_loader, writer, logger)
            if model.epoch % hp.log.chkpt_interval == 0:
                model.save_network(logger)
                model.save_training_state(logger)
            test_model(hp, model, test_loader, writer)
        # NOTE(review): cleanup() also runs when setup() was skipped (CPU / world_size 0);
        # confirm it tolerates an uninitialised process group.
        cleanup()
        log_info("End of Train")
    except Exception as e:
        log_info("Exiting due to exception: %s" % e)
        traceback.print_exc()
        cleanup()
def _worker_init(hdf_file_name, devices, lock, generations, nn_class, params, player_change_callback=None):
    """Initialise this worker process's self-play environment in the global ``_env_``.

    Builds one model per (nn_class, generation, params) triple; a single model
    plays against itself. The worker picks its torch device from *devices* by
    pid, and odd pids swap the two players (models, params and generations
    together) so both seatings get played. Finally it creates a fresh asyncio
    loop and schedules the batched NN proxy's ``run()`` on it.

    Args:
        hdf_file_name: output HDF file name, stored as ``_env_.hdf_file_name``.
        devices: list of torch device names; indexed by ``pid % len(devices)``.
        lock: lock guarding HDF writes, stored as ``_env_.hdf_lock``.
        generations: model generation (or a list/tuple of them when comparing).
        nn_class: model class (or a list/tuple of classes when comparing).
        params: parameter tree (or a list/tuple of them when comparing).
        player_change_callback: optional callback stored on ``_env_``.
    """
    global _env_
    import multiprocessing as mp
    import self_play  # noqa: F401  kept from the original; possibly imported for side effects — TODO confirm
    from nn import NeuralNetWrapper
    from utils.proxies import AsyncBatchedProxy

    pid = mp.current_process().pid
    _env_ = DotDict({})

    if not isinstance(nn_class, (list, tuple)):
        nn_class, generations, params = (nn_class,), (generations,), (params,)
        # Set explicitly: read below when choosing the proxy cache size.
        _env_.compare_models = False
    else:
        _env_.compare_models = True
        assert len(nn_class) == len(generations) == len(params)

    pytorch_device = devices[pid % len(devices)]
    players_params = {}
    models = {}
    for i, (model_cls, generation, model_params) in enumerate(zip(nn_class, generations, params)):
        models[i] = model_cls(model_params)
        if generation != 0:
            # Self-play (single model) loads the previous generation's weights.
            models[i].load_parameters(generation - (1 if len(nn_class) == 1 else 0),
                                      to_device=pytorch_device)
        players_params[i] = model_params

    if len(models) == 1:
        models[1] = models[0]
        players_params[1] = players_params[0]
        generations = [generations[0], generations[0]]
    players_params[0].nn.pytorch_device = pytorch_device
    players_params[1].nn.pytorch_device = pytorch_device

    # shuffle the players based on the pid, important for computing the Elo score
    if pid % 2 != 0:
        models[0], models[1] = models[1], models[0]
        # Previously the second assignment copied players_params[0] back instead of
        # the saved value, leaving both players with the same params.
        players_params[0], players_params[1] = players_params[1], players_params[0]
        generations = list(reversed(generations))

    _env_.models = models
    _env_.players_params = DotDict(players_params)
    _env_.generations = generations
    _env_.params = players_params[0]
    _env_.player_change_callback = player_change_callback
    _env_.name = 'w%i' % pid
    _env_.hdf_file_name = hdf_file_name
    _env_.hdf_lock = lock

    logger.info("Worker %s uses device %s", _env_.name, _env_.params.nn.pytorch_device)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    sp_params = _env_.params.self_play
    _env_.nn_wrapper = NeuralNetWrapper(None, _env_.params)
    _env_.nnet = AsyncBatchedProxy(_env_.nn_wrapper, batch_size=sp_params.nn_batch_size,
                                   timeout=sp_params.nn_batch_timeout,
                                   batch_builder=sp_params.nn_batch_builder,
                                   cache_size=0 if _env_.compare_models else 400000)
    _env_.tasks = []
    fut = asyncio.ensure_future(_env_.nnet.run(), loop=loop)
    fut.add_done_callback(_fut_cb)  # re-raise exception if occured in nnet
    _env_.tasks.append(fut)
class User(object):
    """User resource managed through the SCIM2 REST endpoint /api/identity/scim2/v1.0/Users.

    Requests authenticate with the admin credentials from ``default["admin"]``
    and pass ``verify=False`` (``self.verify`` is False), i.e. TLS certificate
    verification is disabled.
    """
    def __init__(self, userName, password=None, **kwargs):
        self.auth = HTTPBasicAuth(default["admin"]["name"], default["admin"]["password"])
        self.userName = userName
        self.password = password
        self.api_endpoint = "https://{host}:{port}/api/identity/scim2/v1.0/Users".format(
            host=default['connection']['hostname'], port=default['connection']['port'])
        self.verify = False
        # Extra keyword arguments (e.g. a user record returned by the server) are
        # copied onto the instance via _parse_json().
        if kwargs:
            self._parse_json(kwargs)

    def serialize(self):
        """Build the JSON payload POSTed by save(); an explicit password overrides the default."""
        # NOTE(review): the "******" literals below contain no "{}" placeholder, so
        # .format(self.userName) returns them unchanged. They look like redaction
        # artifacts of the original values — restore from version control.
        # NOTE(review): "{}[email protected]" likewise looks like an e-mail-obfuscation
        # artifact rather than the intended work address — confirm.
        data = {
            "schemas": [],
            "name": {
                "familyName": "{}".format(self.userName),
                "givenName": "sam"
            },
            "userName": "******".format(self.userName),
            "password": "******".format(self.userName),
            "emails": [{
                "primary": True,
                "value": "{}@gmail.com".format(self.userName),
                "type": "home"
            }, {
                "value": "{}[email protected]".format(self.userName),
                "type": "work"
            }]
        }
        if self.password:
            data["password"] = self.password
        return data

    def _parse_json(self, json_response):
        """Store *json_response* in ``self._data`` and mirror each key as an attribute.

        A falsy ``id`` is normalised to None after printing a warning.
        """
        self._data = DotDict(json_response)
        if not self._data.id:
            # TODO: Should be able to pass non-preserved api JSON object as well ?
            print("WARN: Missing User ID for parsing data")
            self._data.id = None
            # raise Exception(
            #     "JSON response should have id property for a valid Resource")
        for key, val in self._data.items():
            self.__setattr__(key, val)

    def save(self):
        """POST the serialized user; raises Exception on a non-OK response."""
        response = requests.post(self.api_endpoint, json=self.serialize(), auth=self.auth, verify=self.verify)
        if not response.ok:
            print(response.content)
            print(response.status_code)
            # NOTE(review): SOURCE IS CORRUPTED from here on. A redaction pass replaced
            # the text between two quote characters with ******, swallowing the end of
            # save(), the header of the method that issues the per-user request
            # (presumably delete(), since delete_all() calls user.delete()) and the
            # header of the listing method (presumably all(), since delete_all() calls
            # User.all()). This span does not parse; restore it from version control.
            raise Exception("An error occurred while creating an user: "******"{}/{}".format(self.api_endpoint, self.id), auth=self.auth, verify=self.verify)
        if not response.ok:
            print(response.content)
            print(response.status_code)
            raise Exception("An error occurred while creating an user: "******"https://{host}:{port}/api/identity/scim2/v1.0/Users".format(
            host=default['connection']['hostname'], port=default['connection']['port'])
        # (remaining body of the listing method: fetch all users and wrap each as a User)
        response = requests.get(api_endpoint, auth=HTTPBasicAuth(
            default["admin"]["name"], default["admin"]["password"]), verify=False)
        if not response.ok:
            print(response.content)
            print(response.status_code)
            raise Exception("An error occurred while getting all users: " + str(response.status_code))
        return [User(**user) for user in response.json()['Resources']]

    @staticmethod
    def delete_all():
        """Delete every user from User.all() except 'admin' (case-insensitive match)."""
        for user in User.all():
            if user.userName.lower() == 'admin':
                print("WARN: Skip deleting admin user")
                continue
            print("DEBUG: deleting user {}".format(user.userName))
            user.delete()
# Experiment preset "simple": a nested DotDict of settings for the Boxes game
# (BoxesState), self-play data generation, Elo evaluation and NN training.
# NOTE(review): "_exp_" in the paths below looks like a placeholder that is
# substituted with an experiment name elsewhere — confirm.
simple = DotDict({
    "data_root": "data/_exp_",
    "hdf_file": "data/_exp_/sp_data.hdf",
    "tensorboard_log": "data/tboard/_exp_",
    "game": {
        "clazz": BoxesState,
        # presumably configures a 3x3 board through BoxesState's static fields — confirm
        "init": partial(BoxesState.init_static_fields, ((3, 3),)),
    },
    "self_play": {
        "num_games": 2000,
        "n_workers": 20,
        "games_per_workers": 25,
        "reuse_mcts_tree": True,
        "noise": (0.8, 0.25),  # alpha, coeff
        # batching settings for the NN proxy; _worker_init passes these to AsyncBatchedProxy
        "nn_batch_size": 48,
        "nn_batch_timeout": 0.05,
        "nn_batch_builder": nn_batch_builder,
        "pytorch_devices": ["cuda:1", "cuda:0"],  # get_cuda_devices_list(),
        "mcts": {
            "mcts_num_read": 800,
            "mcts_cpuct": (1.25, 19652),  # CPUCT, CPUCT_BASE
            # NOTE(review): looks like a {move_index: temperature} schedule — confirm
            "temperature": {0: 1.0, 12: 0.02},
            "max_async_searches": 64,
        }
    },
    "elo": {
        "hdf_file": "data/_exp_/elo_data.hdf",
        "n_games": 20,
        "n_workers": 10,
        "games_per_workers": 2,
        # presumably merged over "self_play" for Elo games (no noise, no tree reuse,
        # more MCTS reads) — confirm against the consumer
        "self_play_override": {
            "reuse_mcts_tree": False,
            "noise": (0.0, 0.0),
            "mcts": {
                "mcts_num_read": 1200
            }
        }
    },
    "nn": {
        "model_class": SimpleNN,
        "pytorch_device": "cuda:0",
        # "{}" is presumably filled with the generation number — confirm
        "chkpts_filename": "data/_exp_/model_gen{}.pt",
        "train_params": {
            "pos_average": True,
            "symmetries": SymmetriesGenerator(),
            "nb_epochs": 10,
            "max_samples_per_gen":100*4096,  # approx samples for 10 generations
            "train_split": 0.9,
            "train_batch_size": 4096,
            "val_batch_size": 4096,
            # NOTE(review): keys look like generation numbers mapped to learning rates — confirm
            "lr_scheduler": GenerationLrScheduler({0: 1e-2, 20: 1e-3, 50: 1e-4}),
            "lr": 1e-2,
            "optimizer_params": {
                "momentum": 0.9,
                "weight_decay": 1e-4,
            }
        },
        "model_parameters": None
    }
})
def _wnet_history():
    """Return an empty per-epoch history: one list per '<metric>_<phase>' key."""
    return {'{}_{}'.format(metric, phase): []
            for metric in ('loss', 'psnr', 'ssim', 'vif')
            for phase in ('train', 'valid')}


def _wnet_mean(values):
    """Return the mean of *values* as a float, or NaN (without numpy's warning) if empty."""
    return float(np.mean(values)) if len(values) else float('nan')


def _wnet_magnitude_flat(tensor):
    """Compute complex magnitudes of interleaved real/imag channels, flattened to one channel.

    Even channels are real parts and odd channels imaginary parts. Batch and
    complex-channel axes are then merged and a singleton channel axis is
    added, e.g. a (B, 2C, H, W) tensor becomes (B*C, 1, H, W).
    """
    magnitude = torch.sqrt(torch.square(tensor[:, 0::2]) + torch.square(tensor[:, 1::2]))
    return torch.flatten(magnitude, start_dim=0, end_dim=1).unsqueeze(1)


def _wnet_loss(loss_name, y_pred, y_true, ssim_module, msssim_loss, vif_loss):
    """Return the training loss selected by *loss_name* (``params.lossFunction``).

    Raises:
        ValueError: for an unknown *loss_name*.
    """
    if loss_name == 'mse':
        return F.mse_loss(y_pred, y_true)
    if loss_name == 'l1':
        return F.l1_loss(y_pred, y_true)
    if loss_name == 'ssim':  # standard SSIM
        return 0.16 * F.l1_loss(y_pred, y_true) + 0.84 * (1 - ssim_module(y_pred, y_true))
    if loss_name in ('msssim', 'vif', 'mse+vif', 'l1+vif', 'msssim+vif'):
        pred_mag = _wnet_magnitude_flat(y_pred)
        true_mag = _wnet_magnitude_flat(y_true)
        if loss_name == 'msssim':
            return msssim_loss(pred_mag, true_mag)
        if loss_name == 'vif':
            return vif_loss(pred_mag, true_mag)
        if loss_name == 'mse+vif':
            return 0.15 * F.mse_loss(pred_mag, true_mag) + 0.85 * vif_loss(pred_mag, true_mag)
        if loss_name == 'l1+vif':
            # The L1 term is taken on the raw real/imag channels, not the magnitudes.
            return 0.146 * F.l1_loss(y_pred, y_true) + 0.854 * vif_loss(pred_mag, true_mag)
        return 0.66 * msssim_loss(pred_mag, true_mag) + 0.33 * vif_loss(pred_mag, true_mag)
    raise ValueError("Unknown lossFunction: {!r}".format(loss_name))


def _wnet_sos_images(y, kspace):
    """Return root-sum-of-squares images from a (B, 2C, H, W) real/imag tensor.

    If *kspace* is true, the complex channels are first transformed to the
    image domain with a 2-D inverse FFT over the last two axes.
    """
    arr = y.detach().cpu().numpy()
    arr = arr[:, ::2, :, :] + 1j * arr[:, 1::2, :, :]
    if kspace:
        arr = np.fft.ifft2(arr, axes=(2, 3))
    return np.sqrt((np.abs(arr) ** 2).sum(axis=1))


def train_net(params):
    """Train a WNet with the settings in *params*, writing a checkpoint every epoch.

    Each epoch runs a 'train' and a 'valid' pass. Batch losses that are NaN,
    or not below twice the best validation loss so far, are neither logged
    nor back-propagated (spike guard). SSIM/PSNR/VIF on root-sum-of-squares
    images are computed on batches where ``i % params.verbose_gap == 0``, and
    on every validation batch once ``epoch > params.verbose_delay``. The
    epoch's checkpoint is flagged as best when the validation score
    ``0.2*PSNR + 0.4*SSIM + 0.4*VIF`` (higher is better) improves.
    """
    # Initialize Parameters
    params = DotDict(params)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    verbose = _wnet_history()
    log_metrics = True
    ssim_module = SSIM()
    vifLoss = VIFLoss(sigma_n_sq=0.4, data_range=1.)
    msssimLoss = MultiScaleSSIMLoss(data_range=1.)
    # The combined PSNR/SSIM/VIF score is higher-is-better, so start at -inf.
    # (It previously started at 100 and was compared with '>', keeping the WORST epoch.)
    best_validation_metrics = -math.inf

    train_generator, val_generator = data_loaders(params)
    loaders = {"train": train_generator, "valid": val_generator}

    wnet_identifier = params.mask_URate[0:2] + "WNet_dense=" + str(int(params.dense)) + "_" + params.architecture + "_" \
        + params.lossFunction + '_lr=' + str(params.lr) + '_ep=' + str(params.num_epochs) + '_complex=' \
        + str(int(params.complex_net)) + '_' + 'edgeModel=' + str(int(params.edge_model)) \
        + '(' + str(params.num_edge_slices) + ')_date=' + (datetime.now()).strftime("%d-%m-%Y_%H-%M-%S")

    os.makedirs(params.model_save_path, exist_ok=True)
    print("\n\nModel will be saved at:\n", params.model_save_path)
    print("WNet ID: ", wnet_identifier)

    wnet, optimizer, best_validation_loss, preTrainedEpochs = generate_model(params, device)

    # Adding writer for tensorboard. Also start tensorboard, which tries to access logs in the runs directory
    writer = init_tensorboard(iter(train_generator), wnet, wnet_identifier, device)
    # Architectures whose name ends in 'k' output k-space; metrics need the image domain.
    kspace_output = params.architecture[-1] == 'k'

    for epoch in trange(preTrainedEpochs, params.num_epochs):
        for phase in ('train', 'valid'):
            if phase == 'train':
                wnet.train()
            else:
                wnet.eval()

            for i, data in enumerate(loaders[phase]):
                x, y_true, _, _, fname, slice_num = data
                x = x.to(device, dtype=torch.float)
                y_true = y_true.to(device, dtype=torch.float)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    y_pred = wnet(x)
                    loss = _wnet_loss(params.lossFunction, y_pred, y_true,
                                      ssim_module, msssimLoss, vifLoss)
                loss_value = loss.item()

                if not math.isnan(loss_value) and loss_value < 2 * best_validation_loss:  # avoid nan/spike values
                    verbose['loss_' + phase].append(loss_value)
                    writer.add_scalar('Loss/' + phase + '_epoch_' + str(epoch), loss_value, i)

                if log_metrics and (i % params.verbose_gap == 0 or
                                    (phase == 'valid' and epoch > params.verbose_delay)):
                    sos_true = _wnet_sos_images(y_true, kspace_output)
                    sos_pred = _wnet_sos_images(y_pred, kspace_output)
                    # Normalise both images by the per-slice maximum of the reference.
                    sos_true_max = sos_true.max(axis=(1, 2), keepdims=True)
                    ssim, psnr, vif = metrics(sos_pred, sos_true)
                    ssim_normed, psnr_normed, vif_normed = metrics(
                        sos_pred / sos_true_max, sos_true / sos_true_max)
                    verbose['ssim_' + phase].append(np.mean(ssim_normed))
                    verbose['psnr_' + phase].append(np.mean(psnr_normed))
                    verbose['vif_' + phase].append(np.mean(vif_normed))
                    print("Epoch: ", epoch)
                    print("SSIM: ", np.mean(ssim))
                    print("PSNR: ", np.mean(psnr))
                    print("VIF: ", np.mean(vif))
                    print("SSIM_normed: ", verbose['ssim_' + phase][-1])
                    print("PSNR_normed: ", verbose['psnr_' + phase][-1])
                    print("VIF_normed: ", verbose['vif_' + phase][-1])
                    writer.add_scalar('SSIM/' + phase + '_epoch_' + str(epoch),
                                      verbose['ssim_' + phase][-1], i)
                    writer.add_scalar('PSNR/' + phase + '_epoch_' + str(epoch),
                                      verbose['psnr_' + phase][-1], i)
                    writer.add_scalar('VIF/' + phase + '_epoch_' + str(epoch),
                                      verbose['vif_' + phase][-1], i)
                    print('Loss ' + phase + ': ', loss_value)

                if phase == 'train' and loss_value < 2 * best_validation_loss:
                    loss.backward()
                    optimizer.step()

        # Calculate Averages
        psnr_mean = _wnet_mean(verbose['psnr_valid'])
        ssim_mean = _wnet_mean(verbose['ssim_valid'])
        vif_mean = _wnet_mean(verbose['vif_valid'])
        validation_metrics = 0.2 * psnr_mean + 0.4 * ssim_mean + 0.4 * vif_mean
        valid_avg_loss_of_current_epoch = _wnet_mean(verbose['loss_valid'])
        writer.add_scalar('AvgLoss/+train_epoch_' + str(epoch), _wnet_mean(verbose['loss_train']), epoch)
        writer.add_scalar('AvgLoss/+valid_epoch_' + str(epoch), valid_avg_loss_of_current_epoch, epoch)
        writer.add_scalar('AvgSSIM/+train_epoch_' + str(epoch), _wnet_mean(verbose['ssim_train']), epoch)
        writer.add_scalar('AvgSSIM/+valid_epoch_' + str(epoch), ssim_mean, epoch)
        writer.add_scalar('AvgPSNR/+train_epoch_' + str(epoch), _wnet_mean(verbose['psnr_train']), epoch)
        writer.add_scalar('AvgPSNR/+valid_epoch_' + str(epoch), psnr_mean, epoch)
        writer.add_scalar('AvgVIF/+train_epoch_' + str(epoch), _wnet_mean(verbose['vif_train']), epoch)
        writer.add_scalar('AvgVIF/+valid_epoch_' + str(epoch), vif_mean, epoch)
        verbose = _wnet_history()

        # Save Networks/Checkpoints
        is_best = validation_metrics > best_validation_metrics
        if is_best:
            best_validation_metrics = validation_metrics
            # A NaN threshold would make every later 'loss < 2 * best' check False and
            # silently stop training, so keep the old one if no valid loss was recorded.
            if math.isfinite(valid_avg_loss_of_current_epoch):
                best_validation_loss = valid_avg_loss_of_current_epoch
        save_checkpoint(
            wnet, params.model_save_path, wnet_identifier, {
                'epoch': epoch + 1,
                'state_dict': wnet.state_dict(),
                'best_validation_loss': best_validation_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best)
def parse(self, value):
    """Deserialize the JSON text ``value`` and wrap the result in a DotDict.

    Args:
        value: JSON-encoded string (passed straight to ``json.loads``).

    Returns:
        DotDict: The decoded JSON object.
    """
    parsed = json.loads(value)
    return DotDict(parsed)
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 3)) axes[0].imshow(gt.numpy(), cmap="gray") axes[0].axis('off') axes[1].imshow(test_image, cmap="gray") axes[1].axis('off') plt.savefig('img' + str(slice_idx) + str(fname[0]) + '.png', dpi=100) print('saved slice ' + str(slice_idx) + ' to image') elif save_slice: plt.imshow(test_image, cmap="gray") plt.axis('off') plt.savefig('img' + str(slice_idx) + str(fname[0]) + '.png', dpi=100) print('saved slice ' + str(slice_idx) + ' to image') f.close() print('saved reconstruction to ' + str(fname[0])) tmp = torch.Tensor() if __name__ == "__main__": params = config_wnet.config params['batch_size'] = 1 params['slice_cut_val'] = (50, 50) params['norm'] = True params['architecture'] = 'iiiiii' net = r"..\BestVal=True_R5WNet_dense=1_iiiiii_msssim+vif_lr=0.0005_ep=50_complex=1_edgeModel=0(0)_date=24-07-2020_13-28-38.pth" run(DotDict(params), net, val=False) # set val to True to use Val Set
def initopts():
    """Build the default option set as a DotDict.

    The option names suggest text-preprocessing settings (stop words,
    punctuation, lemmatization, number replacement, EOS token and
    feature-vector flags); confirm against the consumers of these options.

    Returns:
        DotDict: A fresh options object with every default assigned.
    """
    # (name, default) pairs, applied in this exact order.
    defaults = (
        ("stopwords_file", ""),
        # Key spelling ("puncuation") is kept as-is because callers read it.
        ("remove_puncuation", False),
        ("remove_stop_words", False),
        ("lemmatize_words", False),
        ("num_replacement", "[NUM]"),
        ("to_lowercase", False),
        # Nums are important, since rumour may be lying about count
        ("replace_nums", False),
        ("eos", "[EOS]"),
        ("add_eos", True),
        ("returnNERvector", True),
        ("returnDEPvector", True),
        ("returnbiglettervector", True),
        ("returnposvector", True),
    )
    options = DotDict()
    for option_name, default_value in defaults:
        # setattr is exactly equivalent to ``options.<name> = value``.
        setattr(options, option_name, default_value)
    return options
class Endpoint(Resource):
    """Endpoint resource exposed under ``/endpoints`` of the REST API.

    Holds a single-URL ``endpointConfig``, an ``endpointSecurity`` block, a
    name, a type (``http``/``https``) and a max TPS value. Provides helpers to
    create (``save``), delete and bulk-delete endpoints through a RestClient.
    """

    CONST = DotDict(endpoint_const)
    RESOURCE = "/endpoints"

    def __init__(self, name, type, service_url='', maxTps=10, **kwargs):
        """Describe an endpoint locally; nothing is sent until ``save()``.

        Args:
            name: Endpoint name.
            type: Endpoint type; must be ``'http'`` or ``'https'``.
            service_url: URL stored in the single ``endpointConfig`` entry.
            maxTps: Stored as ``max_tps`` and sent as ``maxTps``.
            **kwargs: Extra fields (e.g. the JSON that ``Resource.get``
                unpacks into the constructor). When present they are applied
                through ``_parse_json`` and may override the defaults set here.

        Raises:
            Exception: If ``type`` is not ``'http'`` or ``'https'``.
        """
        super().__init__()
        if type not in ['http', 'https']:
            raise Exception("endpoint_type should be either http or https")
        # Single-URL configuration; the ``service_url`` property reads and
        # writes ``list[0]['url']``.
        self.endpointConfig = {
            "endpointType": "SINGLE",
            "list": [{
                "url": service_url,
                "timeout": "1000",
                "attributes": []
            }]
        }
        self.endpointSecurity = {'enabled': False}
        self.name = name
        self.max_tps = maxTps
        self.type = type
        if kwargs:
            self._parse_json(kwargs)

    def is_secured(self):
        """Return the ``enabled`` flag of ``endpointSecurity``."""
        return self.endpointSecurity['enabled']

    def set_rest_client(self, client=RestClient()):
        """Attach the REST client used by ``save`` and ``delete``.

        NOTE: the default ``RestClient()`` is evaluated once, when the class
        body runs, so every call that omits ``client`` shares that instance.
        """
        self.client = client

    def set_security(self, security_type, username, password):
        """Enable security on this endpoint with the given credentials.

        Args:
            security_type: One of ``Endpoint.CONST.SECURITY_TYPES.values()``.
            username: Username stored in ``endpointSecurity``.
            password: Password stored in ``endpointSecurity``.

        Returns:
            dict: The new ``endpointSecurity`` mapping.

        Raises:
            Exception: If ``security_type`` is not an allowed value.
        """
        if security_type not in Endpoint.CONST.SECURITY_TYPES.values():
            raise Exception(
                "Invalid security type, please provide one of the following {}".
                format(Endpoint.CONST.SECURITY_TYPES))
        self.endpointSecurity = {
            'enabled': True,
            'type': security_type,
            'username': username,
            'password': password
        }
        return self.endpointSecurity

    @property
    def _parsers(self):
        """Per-key parsers consulted by ``Resource._parse_json``.

        ``endpointConfig`` parsing is currently disabled, so that key is set
        as-is like any other field.
        """
        parsers = {
            # 'endpointConfig': self._parse_endpointConfig
        }
        return parsers

    def _parse_endpointConfig(self, value):
        """Decode a JSON-encoded endpoint config into ``self.endpointConfig``."""
        endpoint_config_json = json.loads(value)
        self.endpointConfig = DotDict(endpoint_config_json)

    @property
    def service_url(self):
        """URL of the first (and only) entry in ``endpointConfig['list']``."""
        return self.endpointConfig['list'][0]['url']

    @service_url.setter
    def service_url(self, value):
        self.endpointConfig['list'][0]['url'] = value

    def save(self):
        """POST this endpoint to the REST API and record the returned ID.

        Requires ``self.client`` to be set (e.g. via ``set_rest_client``).

        Returns:
            Endpoint: ``self``, so calls can be chained.

        Raises:
            Exception: If the API does not answer with HTTP 201.
        """
        headers = {
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        }
        data = self.to_json()
        res = self.client.session.post(Endpoint.get_endpoint(),
                                       data=json.dumps(data),
                                       verify=self.client.verify,
                                       headers=headers)
        print("Status code: {}".format(res.status_code))
        if res.status_code != 201:
            print(res.json())
            raise Exception("Fail to save the global endpoint via REST API")
        self._data = DotDict(res.json())
        self.id = self._data.id
        return self

    def delete(self):
        """DELETE this endpoint by ``self.id``; prints a warning on failure."""
        res = self.client.session.delete(Endpoint.get_endpoint() +
                                         "/{}".format(self.id),
                                         verify=self.client.verify)
        if res.status_code != 200:
            print("Warning Error while deleting the endpoint {}\nERROR: {}".format(
                self.name, res.content))
            print("Status code: {}".format(res.status_code))

    def to_json(self):
        """Return the request payload for this endpoint as a plain dict."""
        temp = {
            'name': self.name,
            'endpointConfig': self.endpointConfig,
            'endpointSecurity': self.endpointSecurity,
            'maxTps': self.max_tps,
            'type': self.type
        }
        return temp

    @staticmethod
    def delete_all(client=RestClient()):
        """Delete every endpoint returned by a GET on the endpoints resource.

        Failures on individual deletes are printed as warnings, not raised.
        The default ``RestClient()`` is created once at class-definition time
        and shared across calls.

        Raises:
            Exception: If listing the endpoints does not return HTTP 200.
        """
        res = client.session.get(Endpoint.get_endpoint(), verify=client.verify)
        if not res.status_code == 200:
            print(res.json())
            raise Exception("Error getting endpoints list")
        print("Status code: {}".format(res.status_code))
        endpoints_list = res.json()['list']
        for endpoint in endpoints_list:
            res = client.session.delete(Endpoint.get_endpoint() +
                                        "/{}".format(endpoint['id']),
                                        verify=client.verify)
            if res.status_code != 200:
                print("Warning Error while deleting the endpoint {}\nERROR: {}".
                      format(endpoint['name'], res.content))
                print("Status code: {}".format(res.status_code))

    @staticmethod
    def all():
        """Listing all endpoints is not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        # ``NotImplemented`` is a non-callable constant; calling it raised a
        # confusing TypeError instead of the intended error.
        raise NotImplementedError("Not implemented yet ...")
def _parse_endpointConfig(self, value):
    """Decode JSON text describing an endpoint config and store it.

    Args:
        value: JSON-encoded endpoint configuration string.

    Sets ``self.endpointConfig`` to the decoded object wrapped in a DotDict.
    """
    self.endpointConfig = DotDict(json.loads(value))