def _find_latest_checkpoint(save_dir):
    """Return the path of the highest-numbered ``<int>.ckpt`` file in *save_dir*.

    Returns ``None`` when no such file exists.  Files whose stem is not a
    plain decimal integer are ignored instead of aborting the lookup.
    """
    latest_epoch = -1
    latest_path = None
    for path in glob(f'{save_dir}*.ckpt'):
        # basename/splitext instead of split('/') so the lookup does not
        # depend on the platform's path separator.
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.isdecimal() and int(stem) > latest_epoch:
            latest_epoch = int(stem)
            latest_path = path
    return latest_path


def run(load_last_checkpoint=False):
    """Train ``Net`` on the augmented metadata set and validate after every epoch.

    Checkpoints live in ``{OUTPUT_PATH}/models/`` (the directory is created if
    missing and is handed to ``train`` as ``save_dir``).  Rows of
    ``augmented_meta.csv`` are split into train/validation by ``seriesuid``
    group, so every row of a given series lands in the same split; the first
    ``VAL_PCT`` fraction of the (seed-0 shuffled) groups is used for validation.

    Args:
        load_last_checkpoint: When true, restore model and optimizer state,
            the starting epoch and the recorded loss from the highest-numbered
            ``.ckpt`` file in the checkpoint directory.  If no checkpoint is
            found, a notice is printed and training starts from scratch.
    """
    save_dir = f'{OUTPUT_PATH}/models/'
    os.makedirs(save_dir, exist_ok=True)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    neural_net = Net()
    loss_fn = Loss()
    # Move to the GPU *before* the optimizer is built and its state restored:
    # Optimizer.load_state_dict places restored buffers on the device the
    # parameters occupy at load time, so loading first and calling .cuda()
    # afterwards would leave the momentum buffers on the CPU.
    if use_cuda:
        neural_net = neural_net.cuda()
        loss_fn = loss_fn.cuda()
    optim = torch.optim.SGD(neural_net.parameters(), DEFAULT_LR, momentum=0.9, weight_decay=1e-4)

    starting_epoch = 0
    initial_loss = None
    if load_last_checkpoint:
        latest_model_path = _find_latest_checkpoint(save_dir)
        if latest_model_path is None:
            print('no checkpoint found in:', save_dir, '- training from scratch')
        else:
            print('loading latest model from:', latest_model_path)
            # map_location lets a checkpoint written on a GPU host be resumed
            # on a CPU-only host (and vice versa).
            checkpoint = torch.load(latest_model_path, map_location=device)
            neural_net.load_state_dict(checkpoint['model_state_dict'])
            optim.load_state_dict(checkpoint['optimizer_state_dict'])
            # NOTE(review): training resumes *at* the stored epoch; confirm
            # against how train() records 'epoch' that this is not off by one.
            starting_epoch = checkpoint['epoch']
            initial_loss = checkpoint['loss']

    print(f'''Training from epoch: {starting_epoch} towards: {TOTAL_EPOCHS},
with learning rate starting from: {get_lr(starting_epoch)}, and loss: {initial_loss}''')

    # Row order is shuffled without a fixed seed; the positional indices
    # computed below refer to this shuffled frame, which is also the frame
    # handed to the datasets.
    meta = pd.read_csv(f'{OUTPUT_PATH}/augmented_meta.csv', index_col=0).sample(frac=1).reset_index(drop=True)
    rows_by_series = meta.groupby(['seriesuid']).indices
    series_groups = [list(rows) for rows in rows_by_series.values()]
    # Fixed seed: the permutation applied to the group list is the same on every run.
    random.Random(0).shuffle(series_groups)

    val_split = int(VAL_PCT * len(series_groups))
    val_indices = list(itertools.chain.from_iterable(series_groups[:val_split]))
    train_indices = list(itertools.chain.from_iterable(series_groups[val_split:]))

    ltd = LunaDataSet(train_indices, meta)
    lvd = LunaDataSet(val_indices, meta)
    train_loader = DataLoader(ltd, batch_size=1, shuffle=False)
    val_loader = DataLoader(lvd, batch_size=1, shuffle=False)

    for ep in range(starting_epoch, TOTAL_EPOCHS):
        train(train_loader, neural_net, loss_fn, ep, optim, get_lr, save_dir=save_dir)
        validate(val_loader, neural_net, loss_fn)
def main(opt):
    """Train and evaluate an ``SSD`` model, writing a checkpoint after every epoch.

    If ``<opt.save_folder>/SSD.pth`` already exists, model, optimizer and
    scheduler state are restored from it and training continues from the epoch
    after the stored one.  ``opt.log_path`` is deleted and recreated on every
    call before the ``SummaryWriter`` is opened.

    Args:
        opt: Options object.  Attributes read here: ``batch_size``,
            ``num_workers``, ``lr``, ``momentum``, ``weight_decay``,
            ``multistep``, ``log_path``, ``save_folder``, ``epochs`` and
            ``nms_threshold``.  ``opt.lr`` is overwritten in place with the
            rate scaled by ``num_gpus * batch_size / 32``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # NOTE: torch.distributed initialisation is disabled in this script, so
    # the GPU count used for batch-size / learning-rate scaling is fixed at 1.
    num_gpus = 1

    # Always seed the CPU generator: the model is initialised on the CPU and
    # DataLoader shuffling draws from the CPU RNG even when CUDA is in use.
    torch.manual_seed(123)
    if use_cuda:
        print('Will compute using CUDA')
        torch.cuda.manual_seed(123)

    train_params = {
        "batch_size": opt.batch_size * num_gpus,
        "shuffle": True,
        "drop_last": False,
        "num_workers": opt.num_workers,
        "collate_fn": collate_fn,
    }
    # The evaluation loader uses exactly the same settings as the training loader.
    test_params = dict(train_params)

    dboxes = generate_dboxes()
    model = SSD()
    train_set = OIDataset(SimpleTransformer(dboxes))
    train_loader = DataLoader(train_set, **train_params)
    test_set = OIDataset(SimpleTransformer(dboxes, eval=True), train=False)
    test_loader = DataLoader(test_set, **test_params)

    encoder = Encoder(dboxes)

    # Scale the learning rate linearly with the effective batch size (reference: 32).
    opt.lr = opt.lr * num_gpus * (opt.batch_size / 32)
    criterion = Loss(dboxes)

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=opt.lr,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
        nesterov=True,
    )
    scheduler = MultiStepLR(optimizer=optimizer, milestones=opt.multistep, gamma=0.1)

    if use_cuda:
        model.cuda()
        criterion.cuda()
        model = torch.nn.DataParallel(model)

    # The model is wrapped in DataParallel only on the CUDA path, so `.module`
    # does not exist on CPU runs.  Always (de)serialise the underlying network
    # so checkpoint keys are identical in both cases.
    raw_model = model.module if isinstance(model, torch.nn.DataParallel) else model

    # Start every run with an empty log directory.
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)

    if not os.path.isdir(opt.save_folder):
        os.makedirs(opt.save_folder)
    checkpoint_path = os.path.join(opt.save_folder, "SSD.pth")

    writer = SummaryWriter(opt.log_path)

    if os.path.isfile(checkpoint_path):
        # map_location lets a checkpoint saved on a GPU host be resumed on a
        # CPU-only host.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        first_epoch = checkpoint["epoch"] + 1
        raw_model.load_state_dict(checkpoint["model_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        first_epoch = 0

    for epoch in range(first_epoch, opt.epochs):
        train(model, train_loader, epoch, writer, criterion, optimizer, scheduler)
        evaluate(model, test_loader, encoder, opt.nms_threshold)

        # Overwrite the single checkpoint file with the state after this epoch.
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": raw_model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }
        torch.save(checkpoint, checkpoint_path)