Example #1
# Assumed available: SeparableConv2d_BNReLU, a separable Conv2d + BN + ReLU block defined elsewhere in this project.
import torch


def get_towers(module_list: torch.nn.ModuleList,
               path_head,
               inchannels,
               outchannels,
               towernum=8,
               kernel_list=[3, 5, 0]):
    """Append a tower of separable-conv blocks to module_list.

    path_head[1] holds one kernel-choice index per tower; the last entry of
    kernel_list (0) encodes the "skip this tower" choice.
    """
    num_choice_kernel = len(kernel_list)
    for tower_idx in range(towernum):
        # Kernel choice selected by the searched path for this tower.
        block_idx = path_head[1][tower_idx]
        kernel_sz = kernel_list[block_idx]
        if tower_idx == 0:
            # The first tower maps inchannels -> outchannels and may not be skipped.
            assert kernel_sz != 0
            padding = (kernel_sz - 1) // 2
            module_list.append(
                SeparableConv2d_BNReLU(inchannels,
                                       outchannels,
                                       kernel_size=kernel_sz,
                                       stride=1,
                                       padding=padding,
                                       dilation=1))
        else:
            # Later towers keep outchannels; the last kernel choice means "skip".
            if block_idx != num_choice_kernel - 1:
                assert kernel_sz != 0
                padding = (kernel_sz - 1) // 2
                module_list.append(
                    SeparableConv2d_BNReLU(outchannels,
                                           outchannels,
                                           kernel_size=kernel_sz,
                                           stride=1,
                                           padding=padding,
                                           dilation=1))
    return module_list
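A minimal call sketch follows, under the assumption that path_head is a pair whose second element lists the per-tower kernel choices; the channel counts and choice indices are illustrative, not from the original source.

# Call sketch (hypothetical values): build one 8-tower head.
import torch

towers = torch.nn.ModuleList()
# Choice index per tower: 0 -> 3x3, 1 -> 5x5, 2 -> skip (kernel 0).
path_head = (0, [0, 1, 2, 0, 1, 2, 0, 1])
towers = get_towers(towers, path_head, inchannels=256, outchannels=256)
print(len(towers))  # towers 2 and 5 chose "skip", so 6 blocks are appended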
Example #2
from collections import OrderedDict

import numpy as np
import torch
import flwr as fl  # Flower; fl.common.Weights is a list of NumPy ndarrays


def set_weights(model: torch.nn.ModuleList,
                weights: fl.common.Weights) -> None:
    """Set model weights from a list of NumPy ndarrays."""
    # Pair state_dict keys with the ndarrays in order; atleast_1d keeps scalar entries loadable.
    state_dict = OrderedDict({
        k: torch.Tensor(np.atleast_1d(v))
        for k, v in zip(model.state_dict().keys(), weights)
    })
    model.load_state_dict(state_dict, strict=True)
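A small usage sketch, assuming a toy linear model (not from the original source): the ndarrays must follow the model's state_dict key order.

# Usage sketch: load plain NumPy arrays into a model, in state_dict order.
import numpy as np
import torch

model = torch.nn.Linear(4, 2)
# Order matters: here it is "weight" first, then "bias".
zeroed = [np.zeros_like(v.numpy()) for v in model.state_dict().values()]
set_weights(model, zeroed)
print(model.weight.abs().sum().item())  # 0.0 after loading the zeroed weights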
Example #3
import torch
import torch.nn as nn
from tqdm import tqdm

# `accuracy` (a top-k accuracy helper) is assumed to be defined elsewhere in this project.


def train(
    net: torch.nn.ModuleList,
    trainloader: torch.utils.data.DataLoader,
    epochs: int,
    device: torch.device,
) -> None:
    """Train the network."""

    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adadelta(net.parameters(), lr=1.0)

    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")

    # Train the network
    net.train()  # put layers such as dropout/batch-norm into training mode
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        acc1 = 0.0
        acc5 = 0.0
        for i, data in enumerate(tqdm(trainloader), 0):
            images, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # accumulate running statistics
            running_loss += loss.item()
            tmp1, tmp2 = accuracy(outputs, labels, topk=(1, 5))
            acc1, acc5 = acc1 + tmp1, acc5 + tmp2
            if i % 5 == 4:  # print every 5 mini-batches
                print(
                    "[%d, %5d] loss: %.3f acc1: %.3f acc5: %.3f" % (
                        epoch + 1,
                        i + 1,
                        running_loss / (i + 1),
                        acc1 / (i + 1),
                        acc5 / (i + 1),
                    ),
                    flush=True,
                )
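An invocation sketch with random toy data, assuming the accuracy helper used above is available; the model, tensor shapes, and class count below are illustrative only.

# Invocation sketch (toy data): train a small classifier for one epoch.
import torch
from torch.utils.data import DataLoader, TensorDataset

xs = torch.randn(64, 10)
ys = torch.randint(0, 5, (64,))
loader = DataLoader(TensorDataset(xs, ys), batch_size=8)
net = torch.nn.Linear(10, 5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train(net.to(device), loader, epochs=1, device=device)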
Example #4
def get_weights(model: torch.nn.ModuleList) -> fl.common.Weights:
    """Get model weights as a list of NumPy ndarrays."""
    # Move each state_dict tensor to CPU and convert it to a NumPy array.
    return [val.cpu().numpy() for _, val in model.state_dict().items()]
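A round-trip sketch combining this with set_weights from Example #2, assuming two identically shaped placeholder models; this mirrors how parameters would be exported and re-applied, but the models are illustrative.

# Round-trip sketch: copy parameters from one model into an identical one.
import torch

src = torch.nn.Linear(8, 3)
dst = torch.nn.Linear(8, 3)
set_weights(dst, get_weights(src))  # dst now holds src's parameters
assert torch.equal(src.weight, dst.weight)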