def Objective(trial):
    """Optuna objective: train a ViT on 28x28 single-channel images with
    trial-suggested hyperparameters and return the final validation accuracy.

    Relies on module-level names: ViT, train(), test(), device,
    train_loader, test_loader, epochs, gamma, wandb, optuna.

    Args:
        trial: optuna.trial.Trial used to sample hyperparameters and to
            report intermediate values for pruning.

    Returns:
        The validation accuracy after the last epoch.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner decides to stop
            this trial early based on intermediate val_acc reports.
    """
    # --- hyperparameter search space ---
    dim = trial.suggest_categorical('dim', [32, 64, 128])
    patch_size = 7  # fixed: 28 is divisible by 7, giving a 4x4 patch grid
    depth = trial.suggest_categorical('depth', [8, 16, 32])
    heads = trial.suggest_categorical('heads', [8, 16, 32])
    mlp_dim = trial.suggest_categorical('mlp_dim', [128, 512, 1024])
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)

    print('dim:', dim, 'mlp_dim:', mlp_dim, 'depth:', depth, 'heads:', heads)

    model = ViT(
        dim=dim,
        image_size=28,
        patch_size=patch_size,
        num_classes=10,
        depth=depth,      # number of transformer blocks
        heads=heads,      # number of attention heads
        mlp_dim=mlp_dim,
        channels=1,
    )
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    # Resolve the optimizer class by its name from torch.optim.
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
    # Decay the learning rate every epoch; `gamma` is a module-level global.
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    for epoch in range(1, epochs + 1):
        train(model, criterion, device, train_loader, optimizer, epoch)
        val_acc = test(model, device, test_loader)
        scheduler.step()

        # Report the intermediate value so the pruner can cut losing trials.
        trial.report(val_acc, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
        wandb.log({'val_acc': val_acc})

    return val_acc
# NOTE(review): fragment of a training loop — the enclosing per-epoch and
# per-batch `for` headers are not visible in this chunk, and the original
# line was whitespace-mangled onto one physical line. The dedent points
# below are inferred from the conventional train/validate pattern
# (backward/step inside the batch loop, validation once per epoch) —
# confirm against the full file.

        # --- per-batch training step (inside the batch loop, presumably) ---
        loss.backward()
        optimizer.step()
        # Dividing each batch's value by len(train_loader) accumulates
        # to the epoch mean across the whole loop.
        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        # NOTE(review): accumulates the `loss` tensor itself (not
        # loss.item()), which keeps each batch's autograd graph alive —
        # flag for review.
        epoch_loss += loss / len(train_loader)

    # --- per-epoch validation pass, gradients disabled ---
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)
            val_output = model(data)
            val_loss = criterion(val_output, label)
            acc = (val_output.argmax(dim=1) == label).float().mean()
            # Same running-mean trick as the training accumulators above.
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
    # Checkpoint the full model on CPU (zero-padded epoch in the filename),
    # then move it back to the training device for the next epoch.
    torch.save(model.to('cpu'), f'./models/model-H-tiff-{epoch:08}.pth')
    model = model.to(device)
    # writer.add_scalar("epoch_loss", epoch_loss, epoch)
    # writer.add_scalar("epoch_accuracy", epoch_accuracy, epoch)
    # writer.add_scalar("epoch_val_loss", epoch_loss, epoch)
    # writer.add_scalar("epoch_val_accuracy", epoch_accuracy, epoch)