import torch

from madgrad import MADGRAD


def test_sparse():
    # Dense and sparse gradients should produce identical MADGRAD updates.
    weight = torch.randn(5, 1).cuda().requires_grad_()
    weight_sparse = weight.detach().clone().requires_grad_()
    optimizer_dense = MADGRAD([weight], lr=1e-3, momentum=0)
    optimizer_sparse = MADGRAD([weight_sparse], lr=1e-3, momentum=0)

    weight.grad = torch.rand_like(weight)
    weight.grad[0] = 0.0  # Add a zero
    weight_sparse.grad = weight.grad.to_sparse()

    optimizer_dense.step()
    optimizer_sparse.step()
    assert torch.allclose(weight, weight_sparse)

    weight.grad = torch.rand_like(weight)
    weight.grad[1] = 0.0  # Add a zero
    weight_sparse.grad = weight.grad.to_sparse()

    optimizer_dense.step()
    optimizer_sparse.step()
    assert torch.allclose(weight, weight_sparse)

    weight.grad = torch.rand_like(weight)
    weight.grad[0] = 0.0  # Add a zero
    weight_sparse.grad = weight.grad.to_sparse()

    optimizer_dense.step()
    optimizer_sparse.step()
    assert torch.allclose(weight, weight_sparse)
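
The test above compares one dense and one sparse copy of the same parameter; both optimizers use momentum=0, since sparse gradients are typically only supported on the no-momentum path. For reference, a minimal dense MADGRAD training step might look like the following sketch (assuming the `madgrad` package; the model, data, and hyperparameters are placeholders):

import torch
from madgrad import MADGRAD

# tiny stand-in model and batch, purely for illustration
model = torch.nn.Linear(10, 2)
optimizer = MADGRAD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=0)

x = torch.randn(4, 10)
target = torch.randint(0, 2, (4,))

optimizer.zero_grad()
loss = torch.nn.functional.cross_entropy(model(x), target)
loss.backward()
optimizer.step()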
Example #2
    def train(self):
        device = self.device
        print(f'Running on device: {device}, starting training...')
        print(f'Settings - epochs: {self.num_epochs}, learning rate: {self.learning_rate}')

        train_loader = self.train_loader
        valid_loader = self.valid_loader

        model = self.model.to(device)
        # select the optimizer from the integer config value
        if self.optimizer == 0:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=1e-5)
        elif self.optimizer == 1:
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=self.learning_rate,
                                          weight_decay=1e-5)
        elif self.optimizer == 2:
            optimizer = MADGRAD(model.parameters(),
                                lr=self.learning_rate,
                                weight_decay=1e-5)
        elif self.optimizer == 3:
            optimizer = AdamP(model.parameters(),
                              lr=self.learning_rate,
                              weight_decay=1e-5)
        else:
            raise ValueError(f'Unknown optimizer id: {self.optimizer}')
        criterion = torch.nn.CrossEntropyLoss().to(device)

        # optionally wrap the optimizer for stochastic weight averaging (SWA)
        if self.use_swa:
            optimizer = SWA(optimizer, swa_start=2, swa_freq=2, swa_lr=1e-5)

        # scheduler #
        scheduler_dct = {
            0: None,
            1: StepLR(optimizer, 10, gamma=0.5),
            2: ReduceLROnPlateau(optimizer, 'min', factor=0.4,
                                 patience=int(0.3 * self.early_stopping_patience)),
            3: CosineAnnealingLR(optimizer, T_max=5, eta_min=0.),
        }
        scheduler = scheduler_dct[self.scheduler]

        # early stopping
        early_stopping = EarlyStopping(patience=self.early_stopping_patience,
                                       verbose=True,
                                       path=f'checkpoint_{self.job}.pt')

        # training
        self.train_loss_lst = list()
        self.train_acc_lst = list()
        self.val_loss_lst = list()
        self.val_acc_lst = list()
        for epoch in range(1, self.num_epochs + 1):
            with tqdm(train_loader, unit='batch') as tepoch:
                avg_val_loss, avg_val_acc = None, None

                for idx, (img, label) in enumerate(tepoch):
                    tepoch.set_description(f"Epoch {epoch}")

                    model.train()
                    optimizer.zero_grad()

                    img = img.float().to(device)
                    label = label.long().to(device)

                    output = model(img)
                    loss = criterion(output, label)
                    predictions = output.argmax(dim=1, keepdim=True).squeeze()
                    correct = (predictions == label).sum().item()
                    accuracy = correct / len(img)

                    loss.backward()
                    optimizer.step()

                    # run validation once per epoch, on the last training batch
                    if idx == len(train_loader) - 1:

                        val_loss_lst, val_acc_lst = list(), list()

                        model.eval()
                        with torch.no_grad():
                            for val_img, val_label in valid_loader:
                                val_img = val_img.float().to(device)
                                val_label = val_label.long().to(device)

                                val_out = model(val_img)
                                val_loss = criterion(val_out, val_label)
                                val_pred = val_out.argmax(dim=1, keepdim=True).squeeze()
                                val_acc = (val_pred == val_label).sum().item() / len(val_img)

                                val_loss_lst.append(val_loss.item())
                                val_acc_lst.append(val_acc)

                        avg_val_loss = np.mean(val_loss_lst)
                        avg_val_acc = np.mean(val_acc_lst) * 100.

                        self.train_loss_lst.append(loss.item())
                        self.train_acc_lst.append(accuracy)
                        self.val_loss_lst.append(avg_val_loss)
                        self.val_acc_lst.append(avg_val_acc)

                    if scheduler is not None:
                        current_lr = optimizer.param_groups[0]['lr']
                    else:
                        current_lr = self.learning_rate

                    # log
                    tepoch.set_postfix(loss=loss.item(),
                                       accuracy=100. * accuracy,
                                       val_loss=avg_val_loss,
                                       val_acc=avg_val_acc,
                                       current_lr=current_lr)

                # early stopping check
                early_stopping(avg_val_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    break

                # scheduler update
                if scheduler is not None:
                    if self.scheduler == 2:
                        scheduler.step(avg_val_loss)
                    else:
                        scheduler.step()
        if self.use_swa:
            optimizer.swap_swa_sgd()
        self.model.load_state_dict(torch.load(f'checkpoint_{self.job}.pt'))
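
The EarlyStopping helper used above is not part of PyTorch. A minimal sketch that matches the interface assumed in train() (patience/verbose/path constructor arguments, called with the validation loss and the model, an early_stop flag, and a checkpoint written to path) could look like this:

import torch


class EarlyStopping:
    """Stop training when the validation loss has not improved for `patience` calls."""

    def __init__(self, patience=7, verbose=False, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.path = path
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        # Save a checkpoint whenever the validation loss improves; otherwise
        # count non-improving calls and set the stop flag once `patience` is reached.
        if self.best_loss is None or val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
            if self.verbose:
                print(f'Validation loss improved to {val_loss:.6f}; checkpoint saved.')
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True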