def get_towers(module_list: torch.nn.ModuleList, path_head, inchannels, outchannels, towernum=8,
               kernel_list=[3, 5, 0]):
    """Append the tower layers selected by ``path_head`` to ``module_list``.

    For each of the ``towernum`` tower positions, ``path_head[1][position]``
    is used as an index into ``kernel_list`` to pick a kernel size, and a
    ``SeparableConv2d_BNReLU`` with that kernel size (stride 1, dilation 1,
    padding ``(kernel - 1) // 2``) is appended.

    The first position always produces a layer mapping ``inchannels`` to
    ``outchannels``. Every later position maps ``outchannels`` to
    ``outchannels`` and is skipped entirely when its choice is the last
    index of ``kernel_list``.

    ``kernel_list`` is only read here, never modified.

    Returns:
        The same ``module_list`` object that was passed in.

    Raises:
        AssertionError: if a layer that is not skipped resolves to a kernel
            size of 0 (this includes the first position).
    """
    skip_choice = len(kernel_list) - 1
    for layer_no in range(towernum):
        choice = path_head[1][layer_no]
        ksize = kernel_list[choice]
        is_first = layer_no == 0

        # Only layers after the first may be dropped via the skip choice.
        if not is_first and choice == skip_choice:
            continue

        assert (ksize != 0)
        in_ch = inchannels if is_first else outchannels
        module_list.append(
            SeparableConv2d_BNReLU(in_ch, outchannels, kernel_size=ksize, stride=1,
                                   padding=(ksize - 1) // 2, dilation=1))
    return module_list
def set_weights(model: torch.nn.ModuleList, weights: fl.common.Weights) -> None:
    """Load a list of NumPy ndarrays into ``model`` as its state dict.

    The arrays in ``weights`` are paired positionally with the keys of
    ``model.state_dict()``. Each array is passed through ``np.atleast_1d``
    (so 0-d values become 1-d) and wrapped in a ``torch.Tensor`` before the
    resulting mapping is handed to ``load_state_dict`` with ``strict=True``.
    """
    param_names = model.state_dict().keys()

    new_state = OrderedDict()
    for name, array in zip(param_names, weights):
        new_state[name] = torch.Tensor(np.atleast_1d(array))

    model.load_state_dict(new_state, strict=True)
def train(
    net: torch.nn.ModuleList,
    trainloader: torch.utils.data.DataLoader,
    epochs: int,
    device: torch.device,
) -> None:
    """Train ``net`` on ``trainloader`` for ``epochs`` passes.

    Uses cross-entropy loss with an Adadelta optimizer (``lr=1.0``). Each
    batch is read as ``batch[0]`` (inputs) and ``batch[1]`` (labels), both
    moved to ``device``. Every fifth mini-batch, the running mean of the
    loss and of the two values returned by
    ``accuracy(outputs, labels, topk=(1, 5))`` is printed for the current
    epoch.
    """
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adadelta(net.parameters(), lr=1.0)

    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")

    for epoch_no in range(1, epochs + 1):
        # Per-epoch accumulators; reset at the start of every pass.
        loss_sum = 0.0
        top1_sum = 0.0
        top5_sum = 0.0

        for step, batch in enumerate(tqdm(trainloader), start=1):
            images = batch[0].to(device)
            labels = batch[1].to(device)

            # Clear stale gradients, then forward / backward / update.
            opt.zero_grad()
            outputs = net(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            opt.step()

            loss_sum += loss.item()
            top1, top5 = accuracy(outputs, labels, topk=(1, 5))
            # Plain rebinding (not +=) so a tensor accumulator is never
            # updated in place.
            top1_sum = top1_sum + top1
            top5_sum = top5_sum + top5

            # Report only on every fifth mini-batch.
            if step % 5 != 0:
                continue
            print(
                "[%d, %5d] loss: %.3f acc1: %.3f acc5: %.3f"
                % (
                    epoch_no,
                    step,
                    loss_sum / step,
                    top1_sum / step,
                    top5_sum / step,
                ),
                flush=True,
            )
def get_weights(model: torch.nn.ModuleList) -> fl.common.Weights:
    """Return the model's state-dict values as a list of NumPy ndarrays.

    Entries appear in the iteration order of ``model.state_dict()``; each
    tensor is moved to the CPU before being converted with ``.numpy()``.
    """
    state = model.state_dict()
    return [tensor.cpu().numpy() for tensor in state.values()]