Example #1
    def __init__(self):
        super().__init__()

        logging.info("Dataset: QoENFLX")
        dataset_path = Config().data.data_path + '/QoENFLX/VideoATLAS/'
        db_files = os.listdir(dataset_path)
        db_files.sort(key=lambda var: [
            int(x) if x.isdigit() else x
            for x in re.findall(r'[^0-9]|[0-9]+', var)
        ])
        Nvideos = len(db_files)

        pre_load_train_test_data_LIVE_Netflix = sio.loadmat(
            Config().data.data_path +
            '/QoENFLX/TrainingMatrix_LIVENetflix_1000_trials.mat'
        )['TrainingMatrix_LIVENetflix_1000_trials']

        # randomly pick a trial out of the 1000
        nt_rand = np.random.choice(
            np.shape(pre_load_train_test_data_LIVE_Netflix)[1], 1)
        n_train = [
            ind for ind in range(0, Nvideos)
            if pre_load_train_test_data_LIVE_Netflix[ind, nt_rand] == 1
        ]
        n_test = [
            ind for ind in range(0, Nvideos)
            if pre_load_train_test_data_LIVE_Netflix[ind, nt_rand] == 0
        ]

        X = np.zeros((len(db_files), len(FEATURE_NAMES)))
        y = np.zeros((len(db_files), 1))

        feature_labels = list()
        for typ in FEATURE_NAMES:
            if typ == "VQA":
                feature_labels.append('STRRED_mean')
            elif typ == "R$_1$":
                feature_labels.append("ds_norm")
            elif typ == "R$_2$":
                feature_labels.append("ns")
            elif typ == "M":
                feature_labels.append("tsl_norm")
            else:
                feature_labels.append("lt_norm")

        for i, f in enumerate(db_files):
            data = sio.loadmat(dataset_path + f)
            for feat_cnt, feat in enumerate(feature_labels):
                X[i, feat_cnt] = data[feat]
            y[i] = data["final_subj_score"]

        X_train_before_scaling = X[n_train, :]
        X_test_before_scaling = X[n_test, :]
        y_train = y[n_train]
        y_test = y[n_test]

        self.trainset = copy.deepcopy(
            np.concatenate((X_train_before_scaling, y_train), axis=1))
        self.testset = copy.deepcopy(
            np.concatenate((X_test_before_scaling, y_test), axis=1))
Example #2
    def get_state(self):
        """ Get state for agent. """
        if hasattr(Config().server,
                   'synchronous') and not Config().server.synchronous:
            return self.new_state
        else:
            return np.squeeze(self.new_state.reshape(1, -1))
Example #3
    async def start_agent(self) -> None:
        """ Startup function for agent. """

        await asyncio.sleep(5)
        logging.info("[RL Agent] Contacting the central server.")

        self.sio = socketio.AsyncClient(reconnection=True)
        self.sio.register_namespace(
            RLAgentEvents(namespace='/', plato_rl_agent=self))

        uri = ""
        if hasattr(Config().server, 'use_https'):
            uri = 'https://{}'.format(Config().server.address)
        else:
            uri = 'http://{}'.format(Config().server.address)

        uri = '{}:{}'.format(uri, Config().server.port)

        logging.info("[RL Agent] Connecting to the server at %s.", uri)
        await self.sio.connect(uri)
        await self.sio.emit('agent_alive', {
            'agent': self.agent,
            'current_rl_episode': self.current_episode
        })

        logging.info("[RL Agent] Waiting to be updated with new state.")
        await self.sio.wait()
Example #4
    def get_record_items_values(self):
        """Get the values that will be recorded in the result CSV file."""
        return {
            'global_round': self.current_global_round,
            'round': self.current_round,
            'accuracy': self.accuracy * 100,
            'average_accuracy': self.average_accuracy * 100,
            'edge_agg_num': Config().algorithm.local_rounds,
            'local_epoch_num': Config().trainer.epochs,
            'elapsed_time': self.wall_time - self.initial_wall_time,
            'comm_time': max([report.comm_time
                              for (report, __, __) in self.updates]),
            'round_time': max([report.training_time + report.comm_time
                               for (report, __, __) in self.updates]),
        }
Example #5
    def __init__(self, model=None):
        """Initializing the trainer with the provided model.

        Arguments:
        model: The model to train.
        """
        super().__init__()

        if hasattr(Config().trainer, 'cpuonly') and Config().trainer.cpuonly:
            mindspore.context.set_context(mode=mindspore.context.PYNATIVE_MODE,
                                          device_target='CPU')
        else:
            mindspore.context.set_context(mode=mindspore.context.PYNATIVE_MODE,
                                          device_target='GPU')

        if model is None:
            self.model = models_registry.get()
        else:
            self.model = model

        # Initializing the loss criterion
        loss_criterion = SoftmaxCrossEntropyWithLogits(sparse=True,
                                                       reduction='mean')

        # Initializing the optimizer
        optimizer = nn.Momentum(self.model.trainable_params(),
                                Config().trainer.learning_rate,
                                Config().trainer.momentum)

        self.mindspore_model = mindspore.Model(
            self.model,
            loss_criterion,
            optimizer,
            metrics={"Accuracy": Accuracy()})
Example #6
    def __init__(self, datasource, client_id, testing):
        super().__init__()

        self.client_id = client_id

        np.random.seed(self.random_seed)

        # obtain the dataset information
        if testing:
            dataset = datasource.get_test_set()
        else:
            dataset = datasource.get_train_set()

        # The list of labels (targets) for all the examples
        self.targets_list = datasource.targets

        self.dataset_size = len(dataset)

        indices = list(range(self.dataset_size))

        np.random.shuffle(indices)

        # Concentration parameter to be used in the Dirichlet distribution
        concentration = Config().data.client_quantity_concentration if hasattr(
            Config().data, 'client_quantity_concentration') else 1.0

        min_partition_size = Config().data.min_partition_size
        total_clients = Config().clients.total_clients

        self.subset_indices = self.sample_quantity_skew(
            dataset_indices=indices,
            dataset_size=self.dataset_size,
            min_partition_size=min_partition_size,
            concentration=concentration,
            num_clients=total_clients)[client_id]
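
The sample_quantity_skew helper is not shown in this snippet. Below is a hypothetical sketch of what such a quantity-skew split could look like, drawing per-client partition sizes from a Dirichlet distribution and re-drawing until every partition reaches min_partition_size; the function body, the retry strategy, and the 1-based keys are assumptions, not the actual Plato implementation.

import numpy as np

def sample_quantity_skew(dataset_indices, dataset_size, min_partition_size,
                         concentration, num_clients):
    """Hypothetical quantity-skew split via a Dirichlet draw over clients."""
    # Assumes min_partition_size * num_clients is comfortably below dataset_size,
    # otherwise the rejection loop below may not terminate.
    while True:
        proportions = np.random.dirichlet(
            np.repeat(concentration, num_clients))
        sizes = (proportions * dataset_size).astype(int)
        if sizes.min() >= min_partition_size:
            break
    split_points = np.cumsum(sizes)[:-1]
    partitions = np.split(np.asarray(dataset_indices), split_points)
    # Keyed by 1-based client id, matching the [client_id] lookup above.
    return {cid + 1: part for cid, part in enumerate(partitions)}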
Example #7
    async def wrap_up_processing_reports(self):
        """Wrap up processing the reports with any additional work."""
        if self.do_personalization_test:
            if hasattr(Config(), 'results'):
                new_row = []
                for item in self.recorded_items:
                    item_value = {
                        'round':
                        self.current_round,
                        'accuracy':
                        self.accuracy * 100,
                        'personalization_accuracy':
                        self.personalization_accuracy * 100,
                        'training_time':
                        self.training_time,
                        'round_time':
                        time.perf_counter() - self.round_start_time
                    }[item]
                    new_row.append(item_value)

                result_csv_file = Config().result_dir + 'result.csv'

                csv_processor.write_csv(result_csv_file, new_row)

            self.do_personalization_test = False

        else:
            self.training_time = max(
                [report.training_time for (report, __) in self.updates])
Example #8
    def get_done(self):
        """ Get done condition for agent. """
        if (Config().algorithm.mode == 'train' and
                self.current_step >= Config().algorithm.steps_per_episode):
            logging.info("[RL Agent] Episode #%d ended.", self.current_episode)
            return True
        return False
Example #9
    def __init__(self, datasource, client_id, testing):
        super().__init__()

        # Different clients should have a different bias across the labels
        np.random.seed(self.random_seed * int(client_id))

        self.partition_size = Config().data.partition_size

        # Concentration parameter to be used in the Dirichlet distribution
        concentration = Config().data.concentration if hasattr(
            Config().data, 'concentration') else 1.0

        if testing:
            target_list = datasource.get_test_set().targets
        else:
            # The list of labels (targets) for all the examples
            target_list = datasource.targets()

        class_list = datasource.classes()

        target_proportions = np.random.dirichlet(
            np.repeat(concentration, len(class_list)))

        if np.isnan(np.sum(target_proportions)):
            target_proportions = np.repeat(0, len(class_list))
            target_proportions[random.randint(0, len(class_list) - 1)] = 1

        self.sample_weights = target_proportions[target_list]
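
A minimal, self-contained sketch of the same label-skew idea outside the Plato Config machinery; the class count, concentration value, toy targets, and subset size below are assumptions chosen for illustration.

import numpy as np

# Assumed toy setup: 10 classes, concentration 0.5, and a random target list.
num_classes = 10
concentration = 0.5
targets = np.random.randint(0, num_classes, size=1000)

# One Dirichlet draw gives this client's preference over the classes.
proportions = np.random.dirichlet(np.repeat(concentration, num_classes))

# Each example is weighted by how strongly this client prefers its label;
# sampling with these weights yields a label-skewed (non-IID) subset.
sample_weights = proportions[targets]
subset = np.random.choice(len(targets), size=200, replace=False,
                          p=sample_weights / sample_weights.sum())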
Example #10
    def get_model(model_type):
        """Obtaining an instance of the RegNet model."""

        # If True, will return a RegNet model pre-trained on ImageNet
        pretrained = Config().trainer.pretrained if hasattr(
            Config().trainer, 'pretrained') else False

        if model_type == 'regnet_x_400mf':
            return torchvision.models.regnet_x_400mf(pretrained=pretrained)
        if model_type == 'regnet_x_800mf':
            return torchvision.models.regnet_x_800mf(pretrained=pretrained)
        if model_type == 'regnet_x_1_6gf':
            return torchvision.models.regnet_x_1_6gf(pretrained=pretrained)
        if model_type == 'regnet_x_3_2gf':
            return torchvision.models.regnet_x_3_2gf(pretrained=pretrained)
        if model_type == 'regnet_x_8gf':
            return torchvision.models.regnet_x_8gf(pretrained=pretrained)
        if model_type == 'regnet_x_16gf':
            return torchvision.models.regnet_x_16gf(pretrained=pretrained)
        if model_type == 'regnet_x_32gf':
            return torchvision.models.regnet_x_32gf(pretrained=pretrained)
        if model_type == 'regnet_y_400mf':
            return torchvision.models.regnet_y_400mf(pretrained=pretrained)
        if model_type == 'regnet_y_800mf':
            return torchvision.models.regnet_y_800mf(pretrained=pretrained)
        if model_type == 'regnet_y_1_6gf':
            return torchvision.models.regnet_y_1_6gf(pretrained=pretrained)
        if model_type == 'regnet_y_3_2gf':
            return torchvision.models.regnet_y_3_2gf(pretrained=pretrained)
        if model_type == 'regnet_y_8gf':
            return torchvision.models.regnet_y_8gf(pretrained=pretrained)
        if model_type == 'regnet_y_16gf':
            return torchvision.models.regnet_y_16gf(pretrained=pretrained)
        if model_type == 'regnet_y_32gf':
            return torchvision.models.regnet_y_32gf(pretrained=pretrained)
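
Since every branch above only looks up a constructor on torchvision.models by its name, the chain can be collapsed with getattr. A hedged sketch, assuming the same model_type strings and a torchvision release that still accepts the pretrained= keyword (newer releases use weights= instead); the function name is an assumption.

import torchvision

def get_regnet(model_type, pretrained=False):
    """Look up a RegNet constructor on torchvision.models by name."""
    if not model_type.startswith('regnet'):
        raise ValueError(f'Not a RegNet variant: {model_type}')
    constructor = getattr(torchvision.models, model_type, None)
    if constructor is None:
        raise ValueError(f'No such model: {model_type}')
    return constructor(pretrained=pretrained)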
Example #11
    def save_model(self, filename=None, location=None):
        """Saving the model to a file."""
        model_dir = (Config().params['model_dir']
                     if location is None else location)
        model_name = Config().trainer.model_name

        try:
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
        except FileExistsError:
            pass

        if filename is not None:
            model_path = f'{model_dir}/{filename}'
        else:
            model_path = f'{model_dir}/{model_name}.pth'

        torch.save(self.model.state_dict(), model_path)

        if self.client_id == 0:
            logging.info("[Server #%d] Model saved to %s.", os.getpid(),
                         model_path)
        else:
            logging.info("[Client #%d] Model saved to %s.", self.client_id,
                         model_path)
Example #12
def get():
    """Get the model with the provided name."""
    model_name = Config().trainer.model_name
    model_type = model_name.split('_')[0]
    model = None

    if model_name == 'yolov5':
        from plato.models import yolo
        return yolo.Model.get_model()

    if model_name == 'HuggingFace_CausalLM':
        from transformers import AutoModelForCausalLM, AutoConfig

        model_checkpoint = Config().trainer.model_checkpoint
        config_kwargs = {
            "cache_dir": None,
            "revision": 'main',
            "use_auth_token": None,
        }
        config = AutoConfig.from_pretrained(model_checkpoint, **config_kwargs)
        return AutoModelForCausalLM.from_pretrained(
            model_checkpoint, config=config, cache_dir='./models/huggingface')

    else:
        for name, registered_model in registered_models.items():
            if name.startswith(model_type):
                model = registered_model.get_model(model_name)

    if model is None:
        raise ValueError('No such model: {}'.format(model_name))

    return model
Example #13
    async def periodic_task(self):
        """ A periodic task that is executed from time to time, determined by
        'server:periodic_interval' in the configuration. """
        # Call the async function that defines a customized periodic task, if any
        _task = getattr(self, "customize_periodic_task", None)
        if callable(_task):
            await self.customize_periodic_task()

        # If we are operating in asynchronous mode, aggregate the model updates received so far.
        if hasattr(Config().server,
                   'synchronous') and not Config().server.synchronous:
            if len(self.updates) > 0:
                logging.info(
                    "[Server #%d] %d client reports received in asynchronous mode. Processing.",
                    os.getpid(), len(self.updates))
                if self.action_applied and not self.clients_selected:
                    await self.select_clients()
                    self.clients_selected = True
                if self.action_applied and self.clients_selected:
                    await self.process_reports()
                    await self.wrap_up()
                    self.action_applied = False
                    self.clients_selected = False
            else:
                logging.info(
                    "[Server #%d] No client reports have been received. Nothing to process.",
                    os.getpid())
Example #14
    def __init__(self):
        super().__init__()

        (ds_train, ds_test), ds_info = tfds.load(
            'fashion_mnist',
            split=['train', 'test'],
            shuffle_files=True,
            as_supervised=True,
            with_info=True,
        )

        ds_train = ds_train.map(
            DataSource.normalize_img,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds_train = ds_train.cache()
        ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
        ds_train = ds_train.batch(Config().trainer.batch_size)
        ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

        ds_test = ds_test.map(DataSource.normalize_img,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds_test = ds_test.batch(Config().trainer.batch_size)
        ds_test = ds_test.cache()
        ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

        self.trainset = ds_train
        self.testset = ds_test
Example #15
def read_csv_to_dict(result_csv_file: str) -> Dict[str, List]:
    """Read a CSV file and write the values that need to be plotted
    into a dictionary."""
    result_dict: Dict[str, List] = {}

    plot_pairs = Config().results.plot
    plot_pairs = [x.strip() for x in plot_pairs.split(',')]

    for pairs in plot_pairs:
        pair = [x.strip() for x in pairs.split('-')]
        for item in pair:
            if item not in result_dict:
                result_dict[item] = []

    with open(result_csv_file, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            for item in result_dict:
                if item in (
                        'round',
                        'global_round',
                        'local_epochs',
                ):
                    result_dict[item].append(int(row[item]))
                else:
                    result_dict[item].append(float(row[item]))

    return result_dict
Example #16
    def process_server_response(self, server_response):
        """Additional client-specific processing on the server response."""
        if 'local_epoch_num' in server_response:
            # Update the number of local epochs
            local_epoch_num = server_response['local_epoch_num']
            Config().trainer = Config().trainer._replace(
                epochs=local_epoch_num)
Example #17
def plot_figures_from_dict(result_csv_file: str, result_dir: str):
    """Plot figures with dictionary of results."""
    result_dict = read_csv_to_dict(result_csv_file)

    plot_pairs = Config().results.plot
    plot_pairs = [x.strip() for x in plot_pairs.split(',')]

    for pairs in plot_pairs:
        figure_file_name = result_dir + pairs + '.pdf'
        pair = [x.strip() for x in pairs.split('-')]
        x_y_labels: List = []
        x_y_values: Dict[str, List] = {}
        for item in pair:
            label = {
                'round': 'Round',
                'accuracy': 'Accuracy (%)',
                'elapsed_time': 'Wall clock time elapsed (s)',
                'round_time': 'Training time in each round (s)',
                'global_round': 'Global training round',
                'local_epoch_num': 'Local epochs',
                'edge_agg_num': 'Aggregation rounds on edge servers'
            }[item]
            x_y_labels.append(label)
            x_y_values[label] = result_dict[item]

        x_label = x_y_labels[0]
        y_label = x_y_labels[1]
        x_value = x_y_values[x_label]
        y_value = x_y_values[y_label]
        plot(x_label, x_value, y_label, y_value, figure_file_name)
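
The plot helper called on the last line is not included in this snippet. A hypothetical matplotlib sketch of what it might do; the function body and styling are assumptions.

import matplotlib.pyplot as plt

def plot(x_label, x_value, y_label, y_value, figure_file_name):
    """Hypothetical helper: draw a single x-y curve and save it to a file."""
    fig, ax = plt.subplots()
    ax.plot(x_value, y_value)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    fig.savefig(figure_file_name)
    plt.close(fig)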
Example #18
    def personalize_client_model(self, personalize_train_set):
        """"Run one step of gradient descent to personalze a client's model. """
        personalized_model = copy.deepcopy(self.model)
        personalized_model.to(self.device)
        personalized_model.train()

        loss_criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(personalized_model.parameters(),
                                    lr=Config().trainer.learning_rate,
                                    momentum=Config().trainer.momentum,
                                    weight_decay=Config().trainer.weight_decay)

        examples = personalize_train_set[0]
        labels = personalize_train_set[1]

        examples, labels = examples.to(self.device), labels.to(self.device)
        optimizer.zero_grad()

        outputs = personalized_model(examples)

        loss = loss_criterion(outputs, labels)

        loss.backward()

        optimizer.step()

        return personalized_model
Example #19
    def calc_corr(self, updates):
        """ Calculate the node contribution based on the angle
            between local gradient and global gradient.
        """
        correlations, contribs = [None] * len(updates), [None] * len(updates)

        # Update the baseline model weights
        curr_global_grads = self.process_grad(self.algorithm.extract_weights())
        if self.last_global_grads is None:
            self.last_global_grads = np.zeros(len(curr_global_grads))
        global_grads = np.subtract(curr_global_grads, self.last_global_grads)
        self.last_global_grads = curr_global_grads

        # Compute angles in radian between local and global gradients
        for i, update in enumerate(updates):
            local_grads = self.process_grad(update)
            inner = np.inner(global_grads, local_grads)
            norms = np.linalg.norm(global_grads) * np.linalg.norm(local_grads)
            correlations[i] = np.arccos(np.clip(inner / norms, -1.0, 1.0))

        for i, correlation in enumerate(correlations):
            self.local_correlations[i] = correlation
            self.local_correlations[i] = (
                (self.current_round - 1) /
                self.current_round) * self.local_correlations[i] + (
                    1 / self.current_round) * correlation
            # Non-linear mapping to node contribution
            contribs[i] = Config().algorithm.alpha * (
                1 - math.exp(-math.exp(-Config().algorithm.alpha *
                                       (self.local_correlations[i] - 1))))

        return contribs
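
The heart of this contribution measure is the angle between two flattened gradient vectors. A self-contained numpy sketch of that computation and of the same non-linear mapping used above; the alpha value and the toy gradient vectors are assumptions.

import math
import numpy as np

def gradient_angle(global_grads, local_grads):
    """Angle in radians between the global and a local gradient vector."""
    inner = np.inner(global_grads, local_grads)
    norms = np.linalg.norm(global_grads) * np.linalg.norm(local_grads)
    return np.arccos(np.clip(inner / norms, -1.0, 1.0))

alpha = 5.0                            # assumed stand-in for Config().algorithm.alpha
global_grads = np.random.randn(100)    # toy global gradient
local_grads = np.random.randn(100)     # toy local gradient

angle = gradient_angle(global_grads, local_grads)
contribution = alpha * (1 - math.exp(-math.exp(-alpha * (angle - 1))))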
Example #20
    def __init__(self, datasource, client_id, testing):
        super().__init__(datasource, client_id, testing)

        assert hasattr(Config().data, 'non_iid_clients')
        non_iid_clients = Config().data.non_iid_clients

        if isinstance(non_iid_clients, int):
            # If only one client's dataset is non-iid
            self.non_iid_clients_list = [int(non_iid_clients)]
        else:
            self.non_iid_clients_list = [
                int(x.strip()) for x in non_iid_clients.split(',')
            ]

        if int(client_id) not in self.non_iid_clients_list:
            if testing:
                target_list = datasource.get_test_set().targets
            else:
                target_list = datasource.targets()
            class_list = datasource.classes()
            self.sample_weights = np.array([
                1 / len(class_list) for _ in range(len(class_list))
            ])[target_list]

            # Different iid clients should have a different random seed for Generator
            self.random_seed = self.random_seed * int(client_id)
Example #21
    async def wrap_up_processing_reports(self):
        """Wrap up processing the reports with any additional work."""
        if hasattr(Config(), 'results'):
            new_row = []
            for item in self.recorded_items:
                item_value = self.get_record_items_values()[item]
                new_row.append(item_value)

            if Config().is_edge_server():
                result_csv_file = f"{Config().params['result_dir']}/edge_{os.getpid()}.csv"
            else:
                result_csv_file = f"{Config().params['result_dir']}/{os.getpid()}.csv"

            csv_processor.write_csv(result_csv_file, new_row)

        if Config().is_edge_server():
            # When a certain number of aggregations are completed, an edge client
            # needs to be signaled to send a report to the central server
            if self.current_round == Config().algorithm.local_rounds:
                logging.info(
                    '[Server #%d] Completed %s rounds of local aggregation.',
                    os.getpid(),
                    Config().algorithm.local_rounds)
                self.model_aggregated.set()

                self.current_round = 0
                self.current_global_round += 1
Example #22
    def resume_from_checkpoint(self):
        """ Resume a training session from a previously saved checkpoint. """
        logging.info(
            "[%s] Resume a training session from a previously saved checkpoint.",
            self)

        # Loading important data in the server for resuming its session
        checkpoint_dir = Config.params['checkpoint_dir']

        states_to_load = ['current_round', 'numpy_prng_state', 'prng_state']
        variables_to_load = {}

        for i, state in enumerate(states_to_load):
            with open(f"{checkpoint_dir}/{state}.pkl",
                      'rb') as checkpoint_file:
                variables_to_load[i] = pickle.load(checkpoint_file)

        self.current_round = variables_to_load[0]
        self.resumed_session = True
        numpy_prng_state = variables_to_load[1]
        prng_state = variables_to_load[2]

        np.random.set_state(numpy_prng_state)
        random.setstate(prng_state)

        model_name = Config().trainer.model_name if hasattr(
            Config().trainer, 'model_name') else 'custom'
        filename = f"checkpoint_{model_name}_{self.current_round}.pth"
        self.trainer.load_model(filename, checkpoint_dir)
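
For context, a hedged sketch of the saving side that would produce the three .pkl files this method reads back. The file names mirror the states_to_load list; the function name and everything else are assumptions.

import pickle
import random
import numpy as np

def save_checkpoint_states(checkpoint_dir, current_round):
    """Persist the round counter and both PRNG states as individual .pkl files."""
    states_to_save = {
        'current_round': current_round,
        'numpy_prng_state': np.random.get_state(),
        'prng_state': random.getstate(),
    }
    for name, value in states_to_save.items():
        with open(f"{checkpoint_dir}/{name}.pkl", 'wb') as checkpoint_file:
            pickle.dump(value, checkpoint_file)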
Example #23
    def train(self, trainset, sampler, cut_layer=None) -> float:
        """The main training loop in a federated learning workload.

        Arguments:
        trainset: The training dataset.
        sampler: The sampler that extracts a partition for this client.
        cut_layer (optional): The layer which training should start from.
        """
        config = Config().trainer._asdict()
        config['run_id'] = Config().params['run_id']
        if hasattr(Config().trainer, 'max_concurrency'):
            # reserved for mp.Process
            self.start_training()
            tic = time.perf_counter()
            self.train_process(config, trainset, sampler, cut_layer)
            toc = time.perf_counter()
            self.pause_training()
        else:
            tic = time.perf_counter()
            self.train_process(config, trainset, sampler, cut_layer)
            toc = time.perf_counter()

        training_time = toc - tic

        return training_time
Example #24
    def load_data(self) -> None:
        """Generating data and loading them onto this client."""
        logging.info("[%s] Loading its data source...", self)

        if self.datasource is None or (hasattr(Config().data, 'reload_data')
                                       and Config().data.reload_data):
            self.datasource = datasources_registry.get(
                client_id=self.client_id)

        self.data_loaded = True

        logging.info("[%s] Dataset size: %s", self,
                     self.datasource.num_train_examples())

        # Setting up the data sampler
        self.sampler = samplers_registry.get(self.datasource, self.client_id)

        if hasattr(Config().trainer, 'use_mindspore'):
            # MindSpore requires samplers to be used while constructing
            # the dataset
            self.trainset = self.datasource.get_train_set(self.sampler)
        else:
            # PyTorch uses samplers when loading data with a data loader
            self.trainset = self.datasource.get_train_set()

        if Config().clients.do_test:
            # Set the testset if local testing is needed
            self.testset = self.datasource.get_test_set()
            if hasattr(Config().data, 'testset_sampler'):
                # Set the sampler for test set
                self.testset_sampler = samplers_registry.get(self.datasource,
                                                             self.client_id,
                                                             testing=True)
Example #25
    def update_policy(self):
        """ Update agent if needed in training mode. """
        logging.info("[RL Agent] Updating the policy.")
        if len(self.policy.replay_buffer) > Config().algorithm.batch_size:
            # TD3-LSTM
            critic_loss, actor_loss = self.policy.update()

            new_row = []
            for item in self.recorded_rl_items:
                item_value = {
                    'episode': self.current_episode,
                    'actor_loss': actor_loss,
                    'critic_loss': critic_loss
                }[item]
                new_row.append(item_value)

            episode_result_csv_file = f"{Config().params['result_dir']}/{os.getpid()}_episode_result.csv"
            csv_processor.write_csv(episode_result_csv_file, new_row)

        episode_reward_csv_file = f"{Config().params['result_dir']}/{os.getpid()}_episode_reward.csv"
        csv_processor.write_csv(episode_reward_csv_file, [
            self.current_episode, self.current_step,
            mean(self.pre_acc), self.episode_reward
        ])

        # Reinitialize the previous accuracy queue
        for _ in range(5):
            self.pre_acc.append(0)

        if self.current_episode % Config().algorithm.log_interval == 0:
            self.policy.save_model(self.current_episode)
Example #26
    def __init__(self, model=None):
        """Initializing the trainer with the provided model.

        Arguments:
        model: The model to train.
        """
        super().__init__()

        if model is None:
            model = models_registry.get()

        # Use data parallelism if multiple GPUs are available and the configuration specifies it
        if Config().is_parallel():
            logging.info("Using Data Parallelism.")
            # DataParallel will divide and allocate batch_size to all available GPUs
            self.model = nn.DataParallel(model)
        else:
            self.model = model

        if (hasattr(Config().trainer, 'differential_privacy')
                and Config().trainer.differential_privacy):
            logging.info("Using differential privacy during training.")

            errors = ModuleValidator.validate(self.model, strict=False)
            if len(errors) > 0:
                self.model = ModuleValidator.fix(self.model)
                errors = ModuleValidator.validate(self.model, strict=False)
                assert len(errors) == 0

            self.model = GradSampleModule(self.model)
Example #27
    def prune_updates(self, previous_weights):
        """ Prune aggregated updates. """

        updates = self.compute_weight_updates(previous_weights)
        updates_model = models_registry.get()
        updates_model.load_state_dict(updates, strict=True)

        parameters_to_prune = []
        for _, module in updates_model.named_modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                parameters_to_prune.append((module, 'weight'))

        if (hasattr(Config().clients, 'pruning_method')
                and Config().clients.pruning_method == 'random'):
            pruning_method = prune.RandomUnstructured
        else:
            pruning_method = prune.L1Unstructured

        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=pruning_method,
            amount=Config().clients.pruning_amount,
        )

        for module, name in parameters_to_prune:
            prune.remove(module, name)

        return updates_model.cpu().state_dict()
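
A minimal, standalone sketch of the same global unstructured pruning applied to a toy model; the layer sizes and the 20% pruning amount are assumptions.

import torch
from torch.nn.utils import prune

# Toy model standing in for the aggregated-update model.
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(),
                            torch.nn.Linear(16, 2))

parameters_to_prune = [(module, 'weight') for module in model.modules()
                       if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))]

# Zero out the 20% of weights with the smallest L1 magnitude, globally.
prune.global_unstructured(parameters_to_prune,
                          pruning_method=prune.L1Unstructured,
                          amount=0.2)

# Make the pruning permanent so state_dict() contains plain weight tensors again.
for module, name in parameters_to_prune:
    prune.remove(module, name)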
Example #28
    def __init__(self, datasource, client_id, testing):
        super().__init__()
        if testing:
            dataset = datasource.get_test_set()
        else:
            dataset = datasource.get_train_set()

        self.dataset_size = len(dataset)
        indices = list(range(self.dataset_size))
        np.random.seed(self.random_seed)
        np.random.shuffle(indices)

        partition_size = Config().data.partition_size
        total_clients = Config().clients.total_clients
        total_size = partition_size * total_clients

        # add extra samples to make it evenly divisible, if needed
        if len(indices) < total_size:
            while len(indices) < total_size:
                indices += indices[:(total_size - len(indices))]
        else:
            indices = indices[:total_size]
        assert len(indices) == total_size

        # Compute the indices of data in the subset for this client
        self.subset_indices = indices[(int(client_id) -
                                       1):total_size:total_clients]
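
A compact sketch of the same round-robin partitioning on toy numbers; the dataset size, partition size, and client count below are assumptions.

import numpy as np

dataset_size = 103        # assumed toy dataset size
partition_size = 10       # assumed samples per client
total_clients = 10
total_size = partition_size * total_clients

indices = list(range(dataset_size))
np.random.shuffle(indices)

# Pad by reusing the first indices, or truncate, until exactly total_size remain.
while len(indices) < total_size:
    indices += indices[:total_size - len(indices)]
indices = indices[:total_size]

# Client k (1-based) takes every total_clients-th index starting at k - 1.
partitions = {k: indices[k - 1:total_size:total_clients]
              for k in range(1, total_clients + 1)}
assert all(len(p) == partition_size for p in partitions.values())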
Example #29
    def get_model():
        """Obtaining an instance of this model provided that the name is valid."""
        if hasattr(Config().trainer, 'model_config'):
            return Model(Config().trainer.model_config,
                         Config().data.num_classes)
        else:
            return Model('yolov5s.yaml', Config().data.num_classes)
Example #30
    def __init__(self, state_dim, action_dim, hidden_size):
        super(RNNCritic, self).__init__()
        self.action_dim = action_dim

        # Q1 architecture
        if hasattr(Config().server,
                   'synchronous') and not Config().server.synchronous:
            self.l1 = nn.LSTM(state_dim + 1, hidden_size, batch_first=True)
        else:
            self.l1 = nn.LSTM(state_dim + action_dim,
                              hidden_size,
                              batch_first=True)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, 1)

        # Q2 architecture
        if hasattr(Config().server,
                   'synchronous') and not Config().server.synchronous:
            self.l4 = nn.LSTM(state_dim + 1, hidden_size, batch_first=True)
        else:
            self.l4 = nn.LSTM(state_dim + action_dim,
                              hidden_size,
                              batch_first=True)
        self.l5 = nn.Linear(hidden_size, hidden_size)
        self.l6 = nn.Linear(hidden_size, 1)