import os
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# Project-level helpers (PatchCNN, CrossEntropyLoss2d, get_dataloader,
# train_model) and the TRAIN_* paths are imported from elsewhere in the repo.


def run():
    batch_size = 16  # 64
    patch_size = 448
    num_workers = 32
    num_layers = [3, 4, 6, 3]  # res34
    # num_layers = [2, 2, 2, 2]  # res18
    dropout_rate = 0.5

    if not os.path.exists(TRAIN_SAVE_PATH):
        os.mkdir(TRAIN_SAVE_PATH)

    device_ids = [0]
    logger = logging.getLogger()
    logging.basicConfig(filename=os.path.join(TRAIN_SAVE_PATH, 'training.log'),
                        level=logging.DEBUG)
    logger.setLevel(logging.DEBUG)
    logging.info("Start Training")

    # Primary device; DataParallel scatters batches to the remaining ids.
    device = torch.device('cuda:{}'.format(device_ids[0])
                          if torch.cuda.device_count() > 0 else 'cpu')

    model_ft = PatchCNN(layers=num_layers, dropout_rate=dropout_rate)
    model_ft = nn.DataParallel(model_ft, device_ids, dim=0)
    model_ft.to(device)

    criterion = CrossEntropyLoss2d()

    # Observe that all parameters are being optimized
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.01, momentum=0.9)

    # Decay LR by a factor of 0.1 every 20 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=20, gamma=0.1)

    dataloaders = {
        'train': get_dataloader(batch_size=batch_size, patch_size=patch_size,
                                root_dir=TRAIN_PATH, num_workers=num_workers),
        'val': get_dataloader(batch_size=batch_size, patch_size=patch_size,
                              root_dir=TRAIN_TEST_PATH, num_workers=num_workers),
    }

    model_ft = train_model(model_ft, dataloaders, criterion, optimizer_ft,
                           exp_lr_scheduler, num_epochs=100, save_epoch=10,
                           display_size=100, save_path=TRAIN_SAVE_PATH)

    # Note: model_ft is a DataParallel wrapper, so the saved state-dict keys
    # carry a 'module.' prefix.
    torch.save(model_ft.cpu().state_dict(),
               os.path.join(TRAIN_SAVE_PATH, 'PatchCNN_best.pth'))
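
# CrossEntropyLoss2d is a project helper used by both setups in this file.
# The sketch below is an assumption of the usual 2D-segmentation wrapper
# (log-softmax over the channel axis followed by NLLLoss on dense label
# maps), not the repo's actual code; the name CrossEntropyLoss2dSketch is
# illustrative.
import torch.nn.functional as F


class CrossEntropyLoss2dSketch(nn.Module):
    """Hypothetical stand-in: pixelwise cross entropy on (N, C, H, W) logits."""

    def __init__(self, weight=None):
        super(CrossEntropyLoss2dSketch, self).__init__()
        self.nll_loss = nn.NLLLoss(weight=weight)

    def forward(self, inputs, targets):
        # inputs: raw logits (N, C, H, W); targets: class indices (N, H, W)
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)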
# A second, config-driven training setup (the opening of the model
# constructor call is truncated in the source; only its trailing keyword
# arguments survive):
                 num_dilated_convs=0, dropout_min=0, dropout_max=0,
                 block_type=config.block_type, padding=1, kernel_size=3,
                 group_norm=0)
    # model = ResU(num_classes=config.n_classes,
    #              num_blocks=config.num_blocks,
    #              num_channels=config.num_channels,
    #              strides=config.strides, block=Bottleneck)

    device_ids = config.device_ids
    # Primary device; DataParallel scatters batches to the remaining ids.
    device = torch.device('cuda:{}'.format(device_ids[0])
                          if torch.cuda.device_count() > 0 else 'cpu')

    model_ft = nn.DataParallel(model, device_ids, dim=0)
    model_ft.to(device)

    criterion = CrossEntropyLoss2d()

    # All parameters are optimized. Note that SGD uses a fixed lr of 0.01,
    # while RMSprop reads the learning rate and weight decay from config.
    if config.optim == 'sgd':
        optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.01, momentum=0.9)
    else:
        optimizer_ft = optim.RMSprop(model_ft.parameters(), lr=config.lr,
                                     weight_decay=config.weight_decay)

    # Decay the LR by a factor of 0.1 at each configured milestone epoch.
    exp_lr_scheduler = lr_scheduler.MultiStepLR(
        optimizer_ft,
        milestones=config.milestones,
        gamma=0.1,
    )
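
# train_model is also a project helper. Below is a minimal sketch of the loop
# it is assumed to run, keyed on the same 'train'/'val' phases as the
# `dataloaders` dict above; checkpointing (save_epoch/save_path) and progress
# logging (display_size) are omitted, and the name train_model_sketch is
# illustrative, not the repo's API.
def train_model_sketch(model, dataloaders, criterion, optimizer, scheduler,
                       device, num_epochs=100):
    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            model.train(phase == 'train')  # toggle dropout/batch-norm mode
            for inputs, targets in dataloaders[phase]:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                # Only build the graph and update weights in the train phase.
                with torch.set_grad_enabled(phase == 'train'):
                    loss = criterion(model(inputs), targets)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
        scheduler.step()  # advance the LR schedule once per epoch
    return model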