Ejemplo n.º 1
0
    def test_milestones(self):
        """Check that LambdaLR decays the learning rate at every milestone.

        With 10 iterations per epoch, the milestones ``2ep,4ep,7ep,8ep`` are
        expected to take effect after 20, 40, 70 and 80 scheduler steps, each
        multiplying the learning rate by ``lr_gamma`` (0.1).
        """
        self.assertLrEquals(0.1)

        # (scheduler steps to take, learning rate expected afterwards)
        step_plan = (
            (19, 1e-1),
            (1, 1e-2),
            (19, 1e-2),
            (1, 1e-3),
            (29, 1e-3),
            (1, 1e-4),
            (9, 1e-4),
            (1, 1e-5),
            (100, 1e-5),
        )

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)

            trainer_type = namedtuple('trainer', [
                'optimizer', 'lr_schedule', 'learning_rate', 'momentum',
                'weight_decay', 'lr_gamma', 'lr_milestone_steps'
            ])
            Config().trainer = trainer_type('SGD', 'LambdaLR', 0.1, 0.5, 0.0,
                                            0.1, '2ep,4ep,7ep,8ep')
            self.assertLrEquals(0.1)

            schedule = optimizers.get_lr_schedule(self.optimizer, 10)
            self.assertLrEquals(0.1)

            for num_steps, expected_lr in step_plan:
                for _ in range(num_steps):
                    schedule.step()
                self.assertLrEquals(expected_lr)
Ejemplo n.º 2
0
    def test_vanilla(self):
        """Without milestones, LambdaLR must keep the initial LR constant."""
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)

            fields = [
                'optimizer', 'lr_schedule', 'learning_rate', 'momentum',
                'weight_decay', 'lr_gamma'
            ]
            params = ['SGD', 'LambdaLR', 0.1, 0.5, 0.0, 0.0]
            Config().trainer = namedtuple('trainer', fields)(*params)

            lrs = optimizers.get_lr_schedule(self.optimizer, 10)
            self.assertLrEquals(0.1)
            # No 'lr_milestone_steps' is configured, so even with a zero
            # lr_gamma the test expects the LR to stay at its initial value.
            for _ in range(100):
                lrs.step()
            self.assertLrEquals(0.1)
Ejemplo n.º 3
0
    def train_process(self, config, trainset, evalset, sampler,
                      blending_weights):
        """Train the multimodal model locally and record loss trajectories.

        Each batch's per-modality losses are reweighted with
        ``self.reweight_losses`` and summed into one global loss, which is
        back-propagated. For the first and the last local epoch the average
        train loss and the evaluation loss are recorded, per modality and
        globally, so they can be used to blend gradients on the server side.
        The model is moved to the CPU and saved as
        ``{model_name}_{client_id}_{run_id}.pth`` when training finishes.

        Arguments:
        self: the trainer itself.
        config: the training configuration (accessed both through attributes
            and through item lookup below).
        trainset: the training dataset; must expose ``supported_modalities``.
        evalset: the dataset evaluated by ``self.eval_step``.
        sampler: the sampler that extracts a partition for this client.
        blending_weights: the weights handed to ``self.reweight_losses``.
        """
        use_wandb = 'use_wandb' in config
        run = None
        if use_wandb:
            import wandb

            # Start the run before training so that the per-batch
            # wandb.log() calls below report to an active run.
            run = wandb.init(project="plato",
                             group=str(config['run_id']),
                             reinit=True)

        try:
            log_interval = config.log_config["interval"]
            batch_size = config.trainer['batch_size']

            logging.info("[Client #%d] Loading the dataset.", self.client_id)

            # prepare traindata loaders
            train_loader = torch.utils.data.DataLoader(
                dataset=trainset,
                shuffle=False,
                batch_size=batch_size,
                sampler=sampler.get(),
                num_workers=config.data.get('workers_per_gpu', 1))

            # NOTE(review): the evaluation loader reuses the training
            # sampler; confirm that this is intended.
            eval_loader = torch.utils.data.DataLoader(
                dataset=evalset,
                shuffle=False,
                batch_size=batch_size,
                sampler=sampler.get(),
                num_workers=config.data.get('workers_per_gpu', 1))

            iterations_per_epoch = np.ceil(len(trainset) /
                                           batch_size).astype(int)
            epochs = config['epochs']

            # Sending the model to the device used for training
            self.model.to(self.device)
            self.model.train()
            # Initializing the optimizer
            get_optimizer = getattr(self, "get_optimizer",
                                    optimizers.get_optimizer)
            optimizer = get_optimizer(self.model)
            # Initializing the learning rate schedule, if necessary
            if hasattr(config, 'lr_schedule'):
                lr_schedule = optimizers.get_lr_schedule(
                    optimizer, iterations_per_epoch, train_loader)
            else:
                lr_schedule = None

            # operate the local training
            supported_modalities = trainset.supported_modalities
            for epoch in range(1, epochs + 1):
                epoch_train_losses = {
                    modl_nm: 0.0
                    for modl_nm in supported_modalities
                }
                num_batches = 0
                total_epoch_loss = 0
                for batch_id, (multimodal_examples,
                               labels) in enumerate(train_loader):
                    labels = labels.to(self.device)

                    optimizer.zero_grad()

                    losses = self.model.forward(
                        data_container=multimodal_examples,
                        label=labels,
                        return_loss=True)

                    weighted_losses = self.reweight_losses(
                        blending_weights, losses)

                    # Sum the reweighted modality losses into the global loss
                    weighted_global_loss = 0
                    for modl_nm in supported_modalities:
                        modl_loss = weighted_losses[modl_nm]
                        weighted_global_loss += modl_loss
                        # Accumulate detached values only: keeping the live
                        # tensors would retain every batch's autograd graph
                        # until the end of the epoch.
                        epoch_train_losses[modl_nm] += (
                            modl_loss.detach()
                            if torch.is_tensor(modl_loss) else modl_loss)

                    total_epoch_loss += weighted_global_loss.detach()

                    weighted_global_loss.backward()

                    optimizer.step()

                    if lr_schedule is not None:
                        lr_schedule.step()

                    if batch_id % log_interval == 0:
                        # `weighted_losses` is a per-modality mapping, so the
                        # reported batch loss is the summed global loss.
                        batch_loss = weighted_global_loss.item()
                        if self.client_id == 0:
                            logging.info(
                                "[Server #%d] Epoch: [%d/%d][%d/%d]\t"
                                "Loss: %.6f", os.getpid(), epoch, epochs,
                                batch_id, len(train_loader), batch_loss)
                        else:
                            if use_wandb:
                                wandb.log({"batch loss": batch_loss})

                            logging.info(
                                "[Client #%d] Epoch: [%d/%d][%d/%d]\t"
                                "Loss: %.6f", self.client_id, epoch, epochs,
                                batch_id, len(train_loader), batch_loss)
                    num_batches = batch_id + 1

                if hasattr(optimizer, "params_state_update"):
                    optimizer.params_state_update()

                # only record the first and final performance of the local
                # epochs
                if epoch == 1 or epoch == epochs:
                    # Number of batches actually seen, guarded so that an
                    # empty loader does not cause a division by zero.
                    denominator = max(num_batches, 1)
                    epoch_avg_train_loss = total_epoch_loss / denominator

                    eval_avg_losses = self.eval_step(
                        eval_data_loader=eval_loader)
                    # NOTE(review): restore training mode in case eval_step
                    # switched the model to evaluation mode (harmless if not).
                    self.model.train()
                    weighted_eval_losses = self.reweight_losses(
                        blending_weights, eval_avg_losses)
                    total_eval_loss = 0
                    for modl_nm in supported_modalities:
                        # Each trajectory holds one list per modality; create
                        # it on first use so the later append() succeeds.
                        if modl_nm not in self.mm_train_losses_trajectory:
                            self.mm_train_losses_trajectory[modl_nm] = []
                        self.mm_train_losses_trajectory[modl_nm].append(
                            epoch_train_losses[modl_nm] / denominator)

                        if modl_nm not in self.mm_val_losses_trajectory:
                            self.mm_val_losses_trajectory[modl_nm] = []
                        self.mm_val_losses_trajectory[modl_nm].append(
                            eval_avg_losses[modl_nm])

                        total_eval_loss += weighted_eval_losses[modl_nm]

                    # store the global losses
                    self.global_mm_train_losses_trajectory.append(
                        epoch_avg_train_loss)
                    self.global_mm_val_losses_trajectory.append(
                        total_eval_loss)
            self.model.cpu()

            model_type = config['model_name']
            filename = f"{model_type}_{self.client_id}_{config['run_id']}.pth"
            self.save_model(filename)
        finally:
            # Close the wandb run even when training raises.
            if run is not None:
                run.finish()
Ejemplo n.º 4
0
    def train_model(self, config, trainset, sampler, cut_layer=None):
        """ A custom trainer reporting training loss.

        Trains ``self.model`` for ``config['epochs']`` epochs. Training
        optionally starts from ``cut_layer`` and optionally runs under Opacus
        differential privacy. At the end, the loss of the last processed batch
        is saved to ``{model_name}_{client_id}.loss`` via
        ``Trainer.save_loss``.

        Arguments:
        self: the trainer itself.
        config: a dictionary of configuration parameters.
        trainset: The training dataset.
        sampler: the sampler used to draw this client's partition.
        cut_layer (optional): The layer which training should start from.
        """
        batch_size = config['batch_size']
        log_interval = 10
        tic = time.perf_counter()

        logging.info("[Client #%d] Loading the dataset.", self.client_id)
        _train_loader = getattr(self, "train_loader", None)

        if callable(_train_loader):
            train_loader = self.train_loader(batch_size, trainset, sampler,
                                             cut_layer)
        else:
            train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                                       shuffle=False,
                                                       batch_size=batch_size,
                                                       sampler=sampler)

        iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)
        epochs = config['epochs']

        # Sending the model to the device used for training
        self.model.to(self.device)
        self.model.train()

        # Initializing the loss criterion
        _loss_criterion = getattr(self, "loss_criterion", None)
        if callable(_loss_criterion):
            loss_criterion = self.loss_criterion(self.model)
        else:
            loss_criterion = torch.nn.CrossEntropyLoss()

        # Initializing the optimizer
        get_optimizer = getattr(self, "get_optimizer",
                                optimizers.get_optimizer)
        optimizer = get_optimizer(self.model)

        # Initializing the learning rate schedule, if necessary. `config` is
        # a dictionary, so the key is detected by membership: hasattr() on a
        # dict never sees its keys.
        if 'lr_schedule' in config:
            lr_schedule = optimizers.get_lr_schedule(optimizer,
                                                     iterations_per_epoch,
                                                     train_loader)
        else:
            lr_schedule = None

        use_dp = 'differential_privacy' in config and config[
            'differential_privacy']
        if use_dp:
            privacy_engine = PrivacyEngine(accountant='rdp', secure_mode=False)

            self.model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
                module=self.model,
                optimizer=optimizer,
                data_loader=train_loader,
                target_epsilon=config['dp_epsilon']
                if 'dp_epsilon' in config else 10.0,
                target_delta=config['dp_delta']
                if 'dp_delta' in config else 1e-5,
                epochs=epochs,
                # Test for the same key that is read (the check previously
                # looked for 'max_grad_norm' but read 'dp_max_grad_norm').
                max_grad_norm=config['dp_max_grad_norm']
                if 'dp_max_grad_norm' in config else 1.0,
            )

        # Loss of the most recent batch; stays None if no batch is processed.
        loss = None
        for epoch in range(1, epochs + 1):
            # Use a default training loop
            for batch_id, (examples, labels) in enumerate(train_loader):
                examples, labels = examples.to(self.device), labels.to(
                    self.device)
                if use_dp:
                    optimizer.zero_grad(set_to_none=True)
                else:
                    optimizer.zero_grad()

                if cut_layer is None:
                    outputs = self.model(examples)
                else:
                    outputs = self.model.forward_from(examples, cut_layer)

                loss = loss_criterion(outputs, labels)

                loss.backward()
                optimizer.step()

                # The schedule is built from iterations_per_epoch, i.e. it
                # counts optimizer iterations, so it advances once per batch.
                if lr_schedule is not None:
                    lr_schedule.step()

                if batch_id % log_interval == 0:
                    if self.client_id == 0:
                        logging.info(
                            "[Server #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
                            os.getpid(), epoch, epochs, batch_id,
                            len(train_loader), loss.item())
                    else:
                        logging.info(
                            "[Client #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
                            self.client_id, epoch, epochs, batch_id,
                            len(train_loader), loss.item())

            if hasattr(optimizer, "params_state_update"):
                optimizer.params_state_update()

            # Simulate client's speed
            if self.client_id != 0 and hasattr(
                    Config().clients,
                    "speed_simulation") and Config().clients.speed_simulation:
                self.simulate_sleep_time()

            # Saving the model at the end of this epoch to a file so that
            # it can later be retrieved to respond to server requests
            # in asynchronous mode when the wall clock time is simulated
            if hasattr(Config().server,
                       'request_update') and Config().server.request_update:
                self.model.cpu()
                training_time = time.perf_counter() - tic
                filename = f"{self.client_id}_{epoch}_{training_time}.pth"
                self.save_model(filename)
                self.model.to(self.device)

        # Save the training loss of the last epoch in this round
        if loss is None:
            logging.warning(
                "[Client #%d] No training batch was processed; "
                "the training loss is not saved.", self.client_id)
        else:
            model_name = config['model_name']
            filename = f'{model_name}_{self.client_id}.loss'
            Trainer.save_loss(loss.item(), filename)
Ejemplo n.º 5
0
    def train_model(self, config, trainset, sampler, cut_layer=None):
        """Run one pass over the local data and record cut-layer gradients.

        For every batch, the inputs are detached from any previous graph and
        marked as requiring gradients. After ``loss.backward()`` their
        gradients are therefore available and are appended to
        ``self.cut_layer_grad``. ``self.save_gradients()`` is called once the
        pass is complete.

        Arguments:
        self: the trainer itself.
        config: a dictionary of configuration parameters.
        trainset: The training dataset.
        sampler: the sampler used to draw this client's partition.
        cut_layer (optional): The layer which training should start from.
        """
        batch_size = config['batch_size']

        logging.info("[Client #%d] Loading the dataset.", self.client_id)
        _train_loader = getattr(self, "train_loader", None)

        if callable(_train_loader):
            train_loader = self.train_loader(batch_size, trainset, sampler,
                                             cut_layer)
        else:
            train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                                       shuffle=False,
                                                       batch_size=batch_size,
                                                       sampler=sampler)

        iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)

        # Sending the model to the device used for training
        self.model.to(self.device)
        self.model.train()

        # Initializing the loss criterion
        _loss_criterion = getattr(self, "loss_criterion", None)
        if callable(_loss_criterion):
            loss_criterion = self.loss_criterion(self.model)
        else:
            loss_criterion = nn.CrossEntropyLoss()

        # Initializing the optimizer
        get_optimizer = getattr(self, "get_optimizer",
                                optimizers.get_optimizer)
        optimizer = get_optimizer(self.model)

        # Initializing the learning rate schedule, if necessary
        if hasattr(Config().trainer, 'lr_schedule'):
            lr_schedule = optimizers.get_lr_schedule(optimizer,
                                                     iterations_per_epoch,
                                                     train_loader)
        else:
            lr_schedule = None

        logging.info("[Client #%d] Begining to train.", self.client_id)
        for examples, labels in train_loader:
            examples, labels = examples.to(self.device), labels.to(self.device)
            optimizer.zero_grad()

            # Make the (cut-layer) inputs a fresh leaf that tracks gradients,
            # so examples.grad is populated by backward() below.
            examples = examples.detach().requires_grad_(True)

            if cut_layer is None:
                outputs = self.model(examples)
            else:
                outputs = self.model.forward_from(examples, cut_layer)

            loss = loss_criterion(outputs, labels)
            # Lazy %-formatting: the message is only built if it is emitted.
            logging.info("[Client #%d] \tLoss: %.6f", self.client_id,
                         loss.item())
            loss.backward()

            # Record gradients within the cut layer
            self.cut_layer_grad.append(examples.grad.clone().detach())

            optimizer.step()

            if lr_schedule is not None:
                lr_schedule.step()

        if hasattr(optimizer, "params_state_update"):
            optimizer.params_state_update()

        self.save_gradients()
Ejemplo n.º 6
0
    def train_process(self, config, trainset, sampler, cut_layer=None):
        """The main training loop in a federated learning workload, run in
          a separate process with a new CUDA context, so that CUDA memory
          can be released after the training completes.

        If the trainer defines a callable ``train_model``, it is used instead
        of the default loop below. When 'max_concurrency' is configured, the
        trained model is saved as ``{model_name}_{client_id}_{run_id}.pth``.

        Arguments:
        self: the trainer itself.
        config: a dictionary of configuration parameters.
        trainset: The training dataset.
        sampler: the sampler that extracts a partition for this client.
        cut_layer (optional): The layer which training should start from.

        Raises:
        Any exception raised during training is logged and re-raised.
        """
        use_wandb = 'use_wandb' in config
        run = None
        if use_wandb:
            import wandb

            run = wandb.init(project="plato",
                             group=str(config['run_id']),
                             reinit=True)

        try:
            custom_train = getattr(self, "train_model", None)

            if callable(custom_train):
                self.train_model(config, trainset, sampler.get(), cut_layer)
            else:
                log_interval = 10
                batch_size = config['batch_size']

                logging.info("[Client #%d] Loading the dataset.",
                             self.client_id)
                _train_loader = getattr(self, "train_loader", None)

                if callable(_train_loader):
                    train_loader = self.train_loader(batch_size, trainset,
                                                     sampler.get(), cut_layer)
                else:
                    train_loader = torch.utils.data.DataLoader(
                        dataset=trainset,
                        shuffle=False,
                        batch_size=batch_size,
                        sampler=sampler.get())

                iterations_per_epoch = np.ceil(len(trainset) /
                                               batch_size).astype(int)
                epochs = config['epochs']

                # Sending the model to the device used for training
                self.model.to(self.device)
                self.model.train()

                # Initializing the loss criterion
                _loss_criterion = getattr(self, "loss_criterion", None)
                if callable(_loss_criterion):
                    loss_criterion = self.loss_criterion(self.model)
                else:
                    loss_criterion = nn.CrossEntropyLoss()

                # Initializing the optimizer
                get_optimizer = getattr(self, "get_optimizer",
                                        optimizers.get_optimizer)
                optimizer = get_optimizer(self.model)

                # Initializing the learning rate schedule, if necessary.
                # `config` is a dictionary: hasattr() never sees its keys.
                if 'lr_schedule' in config:
                    lr_schedule = optimizers.get_lr_schedule(
                        optimizer, iterations_per_epoch, train_loader)
                else:
                    lr_schedule = None

                use_dp = 'differential_privacy' in config and config[
                    'differential_privacy']
                if use_dp:
                    privacy_engine = PrivacyEngine(accountant='rdp',
                                                   secure_mode=False)

                    self.model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
                        module=self.model,
                        optimizer=optimizer,
                        data_loader=train_loader,
                        target_epsilon=config['dp_epsilon']
                        if 'dp_epsilon' in config else 10.0,
                        target_delta=config['dp_delta']
                        if 'dp_delta' in config else 1e-5,
                        epochs=epochs,
                        # Test for the same key that is read.
                        max_grad_norm=config['dp_max_grad_norm']
                        if 'dp_max_grad_norm' in config else 1.0,
                    )

                for epoch in range(1, epochs + 1):
                    for batch_id, (examples,
                                   labels) in enumerate(train_loader):
                        examples, labels = examples.to(self.device), labels.to(
                            self.device)
                        if use_dp:
                            optimizer.zero_grad(set_to_none=True)
                        else:
                            optimizer.zero_grad()

                        if cut_layer is None:
                            outputs = self.model(examples)
                        else:
                            outputs = self.model.forward_from(
                                examples, cut_layer)

                        loss = loss_criterion(outputs, labels)

                        loss.backward()
                        optimizer.step()

                        # The schedule is built from iterations_per_epoch,
                        # i.e. it counts iterations: advance it per batch.
                        if lr_schedule is not None:
                            lr_schedule.step()

                        if batch_id % log_interval == 0:
                            if self.client_id == 0:
                                logging.info(
                                    "[Server #%d] Epoch: [%d/%d][%d/%d]\t"
                                    "Loss: %.6f", os.getpid(), epoch, epochs,
                                    batch_id, len(train_loader), loss.item())
                            else:
                                # Same membership test as the wandb.init()
                                # guard above, so logging actually happens.
                                if use_wandb:
                                    wandb.log({"batch loss": loss.item()})

                                logging.info(
                                    "[Client #%d] Epoch: [%d/%d][%d/%d]\t"
                                    "Loss: %.6f", self.client_id, epoch,
                                    epochs, batch_id, len(train_loader),
                                    loss.item())

                    if hasattr(optimizer, "params_state_update"):
                        optimizer.params_state_update()

        except Exception:
            logging.error("Training on client #%d failed.", self.client_id)
            raise
        else:
            if 'max_concurrency' in config:
                self.model.cpu()
                model_type = config['model_name']
                filename = f"{model_type}_{self.client_id}_{config['run_id']}.pth"
                self.save_model(filename)
        finally:
            # Close the wandb run even when training raises.
            if run is not None:
                run.finish()
Ejemplo n.º 7
0
    def train_model(self, config, trainset, sampler, cut_layer):
        """ The custom training loop for Sub-FedAvg(Un).

        Trains ``self.model`` for ``config['epochs']`` epochs while keeping
        pruned weights frozen. After every backward pass, the gradient of each
        parameter whose name contains 'weight' is multiplied by the matching
        entry of ``self.mask``. After the first and after the last epoch, a
        mask is computed by ``pruning_processor.fake_prune`` on copies of the
        model and mask. Both masks are then passed to ``self.process_pruning``.

        Arguments:
        self: the trainer itself.
        config: a dictionary of configuration parameters.
        trainset: The training dataset.
        sampler: the sampler used to draw this client's partition.
        cut_layer: The layer which training should start from, or None.
        """
        batch_size = config['batch_size']
        log_interval = 10

        logging.info("[Client #%d] Loading the dataset.", self.client_id)
        _train_loader = getattr(self, "train_loader", None)

        if callable(_train_loader):
            train_loader = self.train_loader(batch_size, trainset, sampler,
                                             cut_layer)
        else:
            train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                                       shuffle=False,
                                                       batch_size=batch_size,
                                                       sampler=sampler)

        iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)
        epochs = config['epochs']

        if not self.made_init_mask:
            self.mask = pruning_processor.make_init_mask(self.model)
            self.made_init_mask = True

        # Sending the model to the device used for training
        self.model.to(self.device)
        self.model.train()

        # Initializing the loss criterion
        _loss_criterion = getattr(self, "loss_criterion", None)
        if callable(_loss_criterion):
            loss_criterion = self.loss_criterion(self.model)
        else:
            loss_criterion = torch.nn.CrossEntropyLoss()

        # Initializing the optimizer
        get_optimizer = getattr(self, "get_optimizer",
                                optimizers.get_optimizer)
        optimizer = get_optimizer(self.model)

        # Initializing the learning rate schedule, if necessary. `config` is
        # a dictionary, so hasattr() would never see the key.
        if 'lr_schedule' in config:
            lr_schedule = optimizers.get_lr_schedule(optimizer,
                                                     iterations_per_epoch,
                                                     train_loader)
        else:
            lr_schedule = None

        for epoch in range(1, epochs + 1):
            # Use a default training loop
            for batch_id, (examples, labels) in enumerate(train_loader):
                examples, labels = examples.to(self.device), labels.to(
                    self.device)

                optimizer.zero_grad()

                if cut_layer is None:
                    outputs = self.model(examples)
                else:
                    outputs = self.model.forward_from(examples, cut_layer)

                loss = loss_criterion(outputs, labels)

                loss.backward()

                # Freezing pruned weights by zeroing their gradients. Masks
                # are consumed in the order in which named_parameters()
                # yields the 'weight' parameters.
                mask_index = 0
                for name, parameter in self.model.named_parameters():
                    if 'weight' not in name:
                        continue
                    if parameter.grad is not None:
                        # Mask in place on the gradient's own device/dtype
                        # instead of round-tripping through NumPy on the CPU.
                        mask = torch.as_tensor(self.mask[mask_index],
                                               dtype=parameter.grad.dtype,
                                               device=parameter.grad.device)
                        parameter.grad.mul_(mask)
                    mask_index += 1

                optimizer.step()

                # The schedule is built from iterations_per_epoch, i.e. it
                # counts optimizer iterations, so it advances once per batch.
                if lr_schedule is not None:
                    lr_schedule.step()

                if batch_id % log_interval == 0:
                    if self.client_id == 0:
                        logging.info(
                            "[Server #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
                            os.getpid(), epoch, epochs, batch_id,
                            len(train_loader), loss.item())

            if epoch == 1:
                first_epoch_mask = pruning_processor.fake_prune(
                    self.pruning_amount, copy.deepcopy(self.model),
                    copy.deepcopy(self.mask))
            if epoch == epochs:
                last_epoch_mask = pruning_processor.fake_prune(
                    self.pruning_amount, copy.deepcopy(self.model),
                    copy.deepcopy(self.mask))

        # NOTE(review): both masks are only defined when epochs >= 1.
        self.process_pruning(first_epoch_mask, last_epoch_mask)
Ejemplo n.º 8
0
    def train_process(self, config, trainset, sampler, cut_layer=None):
        """The main training loop in a federated learning workload.

        Unless the trainer defines a callable ``train_model``, each epoch runs
        two MAML stages via ``self.training_per_stage``:

        1. A copy of the model is trained with ``learning_rate`` (alpha).
        2. ``self.model`` is trained with ``meta_learning_rate`` (beta).

        When 'max_concurrency' is configured, the trained model is saved as
        ``{model_name}_{client_id}_{run_id}.pth``.

        Arguments:
        self: the trainer itself.
        config: a dictionary of configuration parameters.
        trainset: The training dataset.
        sampler: the sampler that extracts a partition for this client.
        cut_layer (optional): The layer which training should start from.

        Raises:
        Any exception raised during training is logged and re-raised.
        """
        run = None
        if 'use_wandb' in config:
            import wandb

            run = wandb.init(project="plato",
                             group=str(config['run_id']),
                             reinit=True)

        try:
            custom_train = getattr(self, "train_model", None)

            if callable(custom_train):
                self.train_model(config, trainset, sampler.get(), cut_layer)
            else:
                log_interval = 10
                batch_size = config['batch_size']

                logging.info("[Client #%d] Loading the dataset.",
                             self.client_id)
                _train_loader = getattr(self, "train_loader", None)

                if callable(_train_loader):
                    train_loader = self.train_loader(batch_size, trainset,
                                                     sampler.get(), cut_layer)
                else:
                    train_loader = torch.utils.data.DataLoader(
                        dataset=trainset,
                        shuffle=False,
                        batch_size=batch_size,
                        sampler=sampler.get())

                iterations_per_epoch = np.ceil(len(trainset) /
                                               batch_size).astype(int)
                epochs = config['epochs']

                # Sending the model to the device used for training
                self.model.to(self.device)
                self.model.train()

                # Initializing the loss criterion
                _loss_criterion = getattr(self, "loss_criterion", None)
                if callable(_loss_criterion):
                    loss_criterion = self.loss_criterion(self.model)
                else:
                    loss_criterion = nn.CrossEntropyLoss()

                # Initializing the optimizer for the second stage of MAML
                # The learning rate here is the meta learning rate (beta)
                optimizer = torch.optim.SGD(
                    self.model.parameters(),
                    lr=Config().trainer.meta_learning_rate,
                    momentum=Config().trainer.momentum,
                    weight_decay=Config().trainer.weight_decay)

                # Initializing the schedule for meta learning rate, if
                # necessary. `config` is a dictionary, so its keys are
                # detected by membership; hasattr() never sees them.
                if 'meta_lr_schedule' in config:
                    meta_lr_schedule = optimizers.get_lr_schedule(
                        optimizer, iterations_per_epoch, train_loader)
                else:
                    meta_lr_schedule = None

                for epoch in range(1, epochs + 1):
                    # Copy the current model due to using MAML
                    current_model = copy.deepcopy(self.model)
                    # Sending this model to the device used for training
                    current_model.to(self.device)
                    current_model.train()

                    # Initializing the optimizer for the first stage of MAML
                    # The learning rate here is the alpha in the paper
                    temp_optimizer = torch.optim.SGD(
                        current_model.parameters(),
                        lr=Config().trainer.learning_rate,
                        momentum=Config().trainer.momentum,
                        weight_decay=Config().trainer.weight_decay)

                    # Initializing the learning rate schedule, if necessary
                    if 'lr_schedule' in config:
                        lr_schedule = optimizers.get_lr_schedule(
                            temp_optimizer, iterations_per_epoch, train_loader)
                    else:
                        lr_schedule = None

                    # The first stage of MAML
                    # Use half of the training dataset
                    self.training_per_stage(1, temp_optimizer, lr_schedule,
                                            train_loader, cut_layer,
                                            current_model, loss_criterion,
                                            log_interval, config, epoch,
                                            epochs)

                    # The second stage of MAML
                    # Use the other half of the training dataset
                    self.training_per_stage(2, optimizer, meta_lr_schedule,
                                            train_loader, cut_layer,
                                            self.model, loss_criterion,
                                            log_interval, config, epoch,
                                            epochs)

                    if hasattr(optimizer, "params_state_update"):
                        optimizer.params_state_update()

        except Exception:
            logging.error("Training on client #%d failed.", self.client_id)
            raise
        else:
            if 'max_concurrency' in config:
                self.model.cpu()
                model_type = config['model_name']
                filename = f"{model_type}_{self.client_id}_{config['run_id']}.pth"
                self.save_model(filename)
        finally:
            # Close the wandb run even when training raises.
            if run is not None:
                run.finish()
Ejemplo n.º 9
0
    def train_model(self, config, trainset, sampler, cut_layer=None):
        """A custom training loop for personalized FL (two-stage MAML).

        Every epoch deep-copies ``self.model`` and runs the first MAML stage
        on that copy with a freshly created SGD optimizer (learning rate
        alpha, ``Config().trainer.learning_rate``). It then runs the second
        stage on ``self.model`` with the meta optimizer (learning rate beta,
        ``Config().trainer.meta_learning_rate``), which is created once and
        persists across epochs.

        Args:
            config: Mapping of trainer settings. Reads ``'batch_size'`` and
                ``'epochs'``, and checks for the optional ``'lr_schedule'``
                and ``'meta_lr_schedule'`` keys.
            trainset: The training dataset.
            sampler: Sampler handed to the ``DataLoader`` (may be ``None``).
            cut_layer: Forwarded unchanged to ``training_per_stage``.
        """
        batch_size = config['batch_size']
        log_interval = 10

        logging.info("[Client #%d] Loading the dataset.", self.client_id)
        _train_loader = getattr(self, "train_loader", None)

        if callable(_train_loader):
            train_loader = self.train_loader(batch_size, trainset, sampler,
                                             cut_layer)
        else:
            train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                                       shuffle=False,
                                                       batch_size=batch_size,
                                                       sampler=sampler)

        # NOTE(review): this counts batches over the whole trainset rather
        # than the sampler's subset; confirm whether len(train_loader) was
        # the intended length for the LR schedules.
        iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)
        epochs = config['epochs']

        # Sending the model to the device used for training
        self.model.to(self.device)
        self.model.train()

        # Initializing the loss criterion
        _loss_criterion = getattr(self, "loss_criterion", None)
        if callable(_loss_criterion):
            loss_criterion = self.loss_criterion(self.model)
        else:
            loss_criterion = nn.CrossEntropyLoss()

        # Initializing the optimizer for the second stage of MAML
        # The learning rate here is the meta learning rate (beta)
        optimizer = torch.optim.SGD(self.model.parameters(),
                                    lr=Config().trainer.meta_learning_rate,
                                    momentum=Config().trainer.momentum,
                                    weight_decay=Config().trainer.weight_decay)

        # ``config`` is indexed by key above, so optional settings are looked
        # up with ``in``. ``hasattr`` never sees mapping keys and would keep
        # the schedule disabled regardless of the configuration.
        if 'meta_lr_schedule' in config:
            meta_lr_schedule = optimizers.get_lr_schedule(
                optimizer, iterations_per_epoch, train_loader)
        else:
            meta_lr_schedule = None

        for epoch in range(1, epochs + 1):
            # Copy the current model due to using MAML
            current_model = copy.deepcopy(self.model)
            # Sending this model to the device used for training
            current_model.to(self.device)
            current_model.train()

            # Initializing the optimizer for the first stage of MAML
            # The learning rate here is the alpha in the paper
            temp_optimizer = torch.optim.SGD(
                current_model.parameters(),
                lr=Config().trainer.learning_rate,
                momentum=Config().trainer.momentum,
                weight_decay=Config().trainer.weight_decay)

            # Initializing the learning rate schedule, if configured
            if 'lr_schedule' in config:
                lr_schedule = optimizers.get_lr_schedule(
                    temp_optimizer, iterations_per_epoch, train_loader)
            else:
                lr_schedule = None

            # The first stage of MAML
            # Use half of the training dataset
            self.training_per_stage(1, temp_optimizer, lr_schedule,
                                    train_loader, cut_layer, current_model,
                                    loss_criterion, log_interval, epoch,
                                    epochs)

            # The second stage of MAML
            # Use the other half of the training dataset
            self.training_per_stage(2, optimizer, meta_lr_schedule,
                                    train_loader, cut_layer, self.model,
                                    loss_criterion, log_interval, epoch,
                                    epochs)

            if hasattr(optimizer, "params_state_update"):
                optimizer.params_state_update()
Ejemplo n.º 10
0
    def train_model(self, config, trainset, sampler, cut_layer=None):
        """A custom trainer reporting training loss.

        Trains ``self.model`` for ``config['epochs']`` epochs. It then saves
        the loss of the last processed batch with ``Trainer.save_loss`` to
        ``"{model_name}_{client_id}_{run_id}.loss"``.

        Args:
            config: Mapping of trainer settings. Reads ``'batch_size'``,
                ``'epochs'``, ``'model_name'`` and ``'run_id'``, and checks
                for the optional ``'lr_schedule'`` and ``'max_concurrency'``
                keys.
            trainset: The training dataset.
            sampler: Sampler handed to the ``DataLoader`` (may be ``None``).
            cut_layer: If given, forward passes start from this layer via
                ``self.model.forward_from``.

        Raises:
            Exception: Any error raised while training is logged and re-raised
                unchanged.
        """
        log_interval = 10
        batch_size = config['batch_size']

        logging.info("[Client #%d] Loading the dataset.", self.client_id)

        train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                                   shuffle=False,
                                                   batch_size=batch_size,
                                                   sampler=sampler)

        # NOTE(review): this counts batches over the whole trainset rather
        # than the sampler's subset; confirm whether len(train_loader) was
        # the intended length for the LR schedule.
        iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)
        epochs = config['epochs']

        # Sending the model to the device used for training
        self.model.to(self.device)
        self.model.train()

        # Initializing the loss criterion
        _loss_criterion = getattr(self, "loss_criterion", None)
        if callable(_loss_criterion):
            loss_criterion = self.loss_criterion(self.model)
        else:
            loss_criterion = nn.CrossEntropyLoss()

        # Initializing the optimizer
        get_optimizer = getattr(self, "get_optimizer",
                                optimizers.get_optimizer)
        optimizer = get_optimizer(self.model)

        # ``config`` is indexed by key above, so the optional schedule is
        # looked up with ``in``. ``hasattr`` never sees mapping keys and
        # would keep the schedule disabled regardless of the configuration.
        if 'lr_schedule' in config:
            lr_schedule = optimizers.get_lr_schedule(optimizer,
                                                     iterations_per_epoch,
                                                     train_loader)
        else:
            lr_schedule = None

        # Loss tensor of the most recent batch. It stays None when no batch
        # is processed (empty loader or zero epochs), so the final save can
        # be guarded instead of failing on an unbound name.
        loss = None

        try:
            for epoch in range(1, epochs + 1):
                for batch_id, (examples, labels) in enumerate(train_loader):
                    examples = examples.to(self.device)
                    labels = labels.to(self.device)
                    optimizer.zero_grad()

                    if cut_layer is None:
                        outputs = self.model(examples)
                    else:
                        outputs = self.model.forward_from(examples, cut_layer)

                    loss = loss_criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    if lr_schedule is not None:
                        lr_schedule.step()

                    if batch_id % log_interval == 0:
                        # Client id 0 denotes the server, which logs its pid
                        if self.client_id == 0:
                            role, ident = "Server", os.getpid()
                        else:
                            role, ident = "Client", self.client_id
                        logging.info(
                            "[%s #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
                            role, ident, epoch, epochs, batch_id,
                            len(train_loader), loss.item())

                if hasattr(optimizer, "params_state_update"):
                    optimizer.params_state_update()

        except Exception:
            logging.info("Training on client #%d failed.", self.client_id)
            # Bare ``raise`` re-raises with the original traceback intact
            raise

        if 'max_concurrency' in config:
            self.model.cpu()
            model_type = config['model_name']
            filename = f"{model_type}_{self.client_id}_{config['run_id']}.pth"
            self.save_model(filename)

        if loss is None:
            logging.warning(
                "[Client #%d] No training batch was processed; "
                "the training loss is not saved.", self.client_id)
            return

        # Save the training loss of the last epoch in this round
        model_name = config['model_name']
        filename = f"{model_name}_{self.client_id}_{config['run_id']}.loss"
        Trainer.save_loss(loss.item(), filename)