def init_ip_config(self):
    # Data loaders for the initial-pruning (IP) stage: train on user 0's shard,
    # evaluate on the full test set.
    self.ip_train_loader = get_data_loader(EXP_NAME, data_type="train",
                                           batch_size=CLIENT_BATCH_SIZE, shuffle=True,
                                           num_workers=8, user_list=[0], pin_memory=True)
    self.ip_test_loader = get_data_loader(EXP_NAME, data_type="test",
                                          num_workers=8, pin_memory=True)

    # Plain SGD (no momentum or weight decay) for the pruning warm-up.
    ip_optimizer = SGD(self.model.parameters(), lr=INIT_LR)
    self.ip_optimizer_wrapper = OptimizerWrapper(self.model, ip_optimizer)
    self.ip_control = ControlModule(model=self.model, config=config)
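# OptimizerWrapper is used in these snippets only through its constructor and step().
# The class below is a minimal sketch reconstructed from those call sites, not the
# repository's implementation: step(inputs, labels) is assumed to run one
# forward/backward/optimizer step and return each parameter's gradient, in
# named_parameters() order, as the pruning loop further down expects. The
# CrossEntropyLoss criterion is an assumption; the real wrapper may differ.
import torch
from torch import nn

class OptimizerWrapperSketch:
    """Hypothetical stand-in for OptimizerWrapper, inferred from usage."""

    def __init__(self, model, optimizer, scheduler=None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = nn.CrossEntropyLoss()  # assumed loss function

    def step(self, inputs, labels):
        # One training step; returns a list of per-parameter gradients.
        self.optimizer.zero_grad()
        loss = self.criterion(self.model(inputs), labels)
        loss.backward()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        return [p.grad.clone() for p in self.model.parameters()]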
def init_optimizer(self):
    # Full training configuration: SGD with momentum and weight decay, plus a
    # step decay chosen so the learning rate halves every LR_HALF_LIFE steps.
    self.optimizer = SGD(self.model.parameters(), lr=INIT_LR,
                         momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    self.optimizer_scheduler = lr_scheduler.StepLR(self.optimizer, step_size=STEP_SIZE,
                                                   gamma=0.5 ** (STEP_SIZE / LR_HALF_LIFE))
    self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer,
                                              self.optimizer_scheduler)
def init_optimizer(self):
    # Variant: plain SGD with the same half-life decay, applied at every step.
    self.optimizer = SGD(self.model.parameters(), lr=INIT_LR)
    self.optimizer_scheduler = lr_scheduler.StepLR(self.optimizer, step_size=1,
                                                   gamma=0.5 ** (1 / LR_HALF_LIFE))
    self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer,
                                              self.optimizer_scheduler)
def init_optimizer(self):
    # Variant: constant learning rate, no scheduler.
    self.optimizer = SGD(self.model.parameters(), lr=INIT_LR)
    self.optimizer_wrapper = OptimizerWrapper(self.model, self.optimizer)
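# The two schedulers above implement the same decay at different granularities:
# with gamma = 0.5 ** (step_size / LR_HALF_LIFE), the learning rate is halved once
# per LR_HALF_LIFE steps regardless of step_size. A quick check with illustrative
# values (these constants are not from the source):
STEP_SIZE, LR_HALF_LIFE, INIT_LR = 10, 100, 0.1  # illustrative values only
lr = INIT_LR
for step in range(1, LR_HALF_LIFE + 1):
    if step % STEP_SIZE == 0:
        lr *= 0.5 ** (STEP_SIZE / LR_HALF_LIFE)
print(lr)  # ~0.05, i.e. INIT_LR / 2 after LR_HALF_LIFE steps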
print("Initial server pruning: {} rounds on {} pre-loaded samples, adjusting every {} rounds".format(
    server_pruning_rounds, num_pre_batch * config.CLIENT_BATCH_SIZE, server_adjust_interval))

# Pre-fetch a fixed set of batches so every pruning round sees the same data.
server_loader = get_data_loader(config.EXP_NAME, data_type="train",
                                batch_size=config.CLIENT_BATCH_SIZE, shuffle=True,
                                num_workers=8, user_list=[0], pin_memory=True)
server_inputs, server_outputs = [], []
for _ in range(num_pre_batch):
    inp, out = server_loader.get_next_batch()
    server_inputs.append(inp)
    server_outputs.append(out)

server_optimizer = SGD(model.parameters(), lr=config.INIT_LR)
server_optimizer_wrapper = OptimizerWrapper(model, server_optimizer)
server_control = ControlModule(model=model, config=config)

prev_density, prev_num, prev_ind = None, 5, []
for server_i in range(1, server_pruning_rounds + 1):
    for server_inp, server_out in zip(server_inputs, server_outputs):
        # One SGD step; the wrapper returns the gradient of each parameter.
        list_grad = server_optimizer_wrapper.step(server_inp, server_out)
        for (key, param), g in zip(model.named_parameters(), list_grad):
            assert param.size() == g.size()
            # Accumulate squared gradients as the importance measure.
            server_control.accumulate(key, g ** 2)
    if server_i % server_adjust_interval == 0:
        # Re-select the retained parameters, bounding the density decrease.
        server_control.adjust(config.MAX_DEC_DIFF)
        cur_density = disp_num_params(model)
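# ControlModule appears above only through accumulate() and adjust(). The class
# below is a deliberately simplified stand-in inferred from those call sites, not
# the repository's implementation: it treats the accumulated squared gradients as
# per-weight importance scores and, on adjust(), zeroes the least important
# weights in each layer. keep_fraction is a hypothetical knob; the real module
# derives the target density from config and limits each decrease by max_dec_diff.
import torch

class SimpleControlModule:
    """Toy illustration of squared-gradient importance pruning (hypothetical)."""

    def __init__(self, model, keep_fraction=0.5):
        self.model = model
        self.keep_fraction = keep_fraction  # hypothetical parameter
        self.scores = {name: torch.zeros_like(p)
                       for name, p in model.named_parameters()}

    def accumulate(self, name, sq_grad):
        # Called with g ** 2 inside the training loop above.
        self.scores[name] += sq_grad

    def adjust(self, max_dec_diff):
        # max_dec_diff is accepted for interface compatibility; this toy rule
        # ignores it and simply keeps the top keep_fraction of each layer.
        for name, param in self.model.named_parameters():
            flat = self.scores[name].flatten()
            k = max(1, int(self.keep_fraction * flat.numel()))
            # Threshold at the k-th largest score, then zero everything below it.
            threshold = flat.kthvalue(flat.numel() - k + 1).values
            param.data.mul_((self.scores[name] >= threshold).to(param.dtype))
            self.scores[name].zero_()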