# Training entry point. Assumes the usual imports (os, time, random, logging,
# numpy as np, torch, setproctitle, torch.utils.data.DataLoader) and the
# project-level helpers config, models, BraTS, adjust_learning_rate and
# log_config defined elsewhere in the repository.
def main():
    log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'log',
                           config.experiment + config.data)
    log_file = log_dir + '.txt'
    log_config(log_file)
    logging.info('-------------------------------------------This is all configurations-----------------------------------------')
    logging.info(config)
    logging.info('---------------------------------------------This is a halving line-------------------------------------------')
    logging.info('{}'.format(config.description))

    # Fix random seeds for reproducibility.
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)

    # model = generate_model(config)
    model = getattr(models, config.model_name)()
    # model = getattr(models, config.model_name)(c=4, n=32, channels=128, groups=16,
    #                                            norm='sync_bn', num_classes=4, output_func='softmax')
    model = torch.nn.DataParallel(model).cuda()
    model.train()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config.lr,
                                 weight_decay=config.weight_decay,
                                 amsgrad=config.amsgrad)
    # criterion = getattr(criterions, config.criterion)
    criterion = torch.nn.CrossEntropyLoss()

    checkpoint_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'checkpoint',
                                  config.experiment + config.data)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    resume = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.resume)
    if os.path.isfile(resume) and config.load:
        logging.info('loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume)
        config.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optim_dict'])
        logging.info('Successfully loaded checkpoint {}; training from epoch: {}'
                     .format(config.resume, config.start_epoch))
    else:
        logging.info('re-training!!!')

    train_list = os.path.join(config.root, config.train_dir, config.train_file)
    train_root = os.path.join(config.root, config.train_dir)
    train_set = BraTS(train_list, train_root, config.mode)
    logging.info('Samples for train = {}'.format(len(train_set)))

    num_iters = (len(train_set) * config.end_epoch) // config.batch_size
    num_iters -= (len(train_set) * config.start_epoch) // config.batch_size

    train_loader = DataLoader(dataset=train_set,
                              shuffle=True,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              pin_memory=True)

    start_time = time.time()
    torch.set_grad_enabled(True)

    for epoch in range(config.start_epoch, config.end_epoch):
        # Per-epoch statistics (placeholders, not filled in below).
        loss_epoch = []
        area1 = []
        area2 = []
        area4 = []
        setproctitle.setproctitle('{}:{} {}/{}'.format(config.user, config.model_name,
                                                       epoch + 1, config.end_epoch))
        start_epoch = time.time()

        for i, data in enumerate(train_loader):
            adjust_learning_rate(optimizer, epoch, config.end_epoch, config.lr)
            # warm_up_learning_rate_adjust2(config.lr, epoch, config.warm, config.end_epoch, optimizer)

            data = [t.cuda(non_blocking=True) for t in data]
            x, target = data
            output = model(x)
            # Remap BraTS label 4 (enhancing tumor) to 3 so class indices are contiguous.
            target[target == 4] = 3
            loss = criterion(output, target)

            logging.info('Epoch: {}_Iter:{} loss: {:.5f} ||'.format(epoch, i, loss))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        end_epoch = time.time()

        # Save periodically and again near the end of training.
        if (epoch + 1) % int(config.save_freq) == 0 \
                or (epoch + 1) % int(config.end_epoch - 1) == 0 \
                or (epoch + 1) % int(config.end_epoch - 2) == 0:
            file_name = os.path.join(checkpoint_dir, 'model_epoch_{}.pth'.format(epoch))
            torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict(),
            }, file_name)

        epoch_time_minute = (end_epoch - start_epoch) / 60
        remaining_time_hour = (config.end_epoch - epoch - 1) * epoch_time_minute / 60
        logging.info('Current epoch time consumption: {:.2f} minutes!'.format(epoch_time_minute))
        logging.info('Estimated remaining training time: {:.2f} hours!'.format(remaining_time_hour))

    final_name = os.path.join(checkpoint_dir, 'model_epoch_last.pth')
    torch.save({
        'epoch': config.end_epoch,
        'state_dict': model.state_dict(),
        'optim_dict': optimizer.state_dict(),
    }, final_name)

    end_time = time.time()
    total_time = (end_time - start_time) / 3600
    logging.info('The total training time is {:.2f} hours'.format(total_time))
    logging.info('-----------------------------------The training process finished!------------------------------------')
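# --- Hedged sketch (not from the source): the training loops call
# adjust_learning_rate(optimizer, epoch, end_epoch, lr) without showing its
# definition. A common choice in segmentation codebases is polynomial ("poly")
# decay; the power of 0.9 below is an assumption, not the repository's
# confirmed schedule.
def adjust_learning_rate(optimizer, epoch, max_epoch, init_lr, power=0.9):
    """Polynomially decay the learning rate from init_lr towards 0 over max_epoch epochs."""
    lr = round(init_lr * (1 - epoch / max_epoch) ** power, 8)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr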
# Distributed (DDP) training worker: one process per GPU, NCCL backend.
def main_worker():
    if args.local_rank == 0:
        log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'log',
                               args.experiment + args.date)
        log_file = log_dir + '.txt'
        log_args(log_file)
        logging.info('--------------------------------------This is all configurations----------------------------------')
        for arg in vars(args):
            logging.info('{}={}'.format(arg, getattr(args, arg)))
        logging.info('----------------------------------------This is a halving line----------------------------------')
        logging.info('{}'.format(args.description))

    # Fix random seeds for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    torch.distributed.init_process_group('nccl')
    torch.cuda.set_device(args.local_rank)

    _, model = TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned")
    model.cuda(args.local_rank)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[args.local_rank],
                                                output_device=args.local_rank,
                                                find_unused_parameters=True)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 amsgrad=args.amsgrad)
    criterion = getattr(criterions, args.criterion)

    if args.local_rank == 0:
        checkpoint_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'checkpoint',
                                      args.experiment + args.date)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

    resume = ''  # set to a checkpoint path to resume training
    writer = SummaryWriter()

    if os.path.isfile(resume) and args.load:
        logging.info('loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info('Successfully loaded checkpoint {}; training from epoch: {}'
                     .format(args.resume, args.start_epoch))
    else:
        logging.info('re-training!!!')

    train_list = os.path.join(args.root, args.train_dir, args.train_file)
    train_root = os.path.join(args.root, args.train_dir)
    train_set = BraTS(train_list, train_root, args.mode)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    logging.info('Samples for train = {}'.format(len(train_set)))

    # args.gpu is a comma-separated device string such as '0,1,2,3'.
    num_gpu = (len(args.gpu) + 1) // 2

    train_loader = DataLoader(dataset=train_set,
                              sampler=train_sampler,
                              batch_size=args.batch_size // num_gpu,
                              drop_last=True,
                              num_workers=args.num_workers,
                              pin_memory=True)

    start_time = time.time()
    torch.set_grad_enabled(True)

    for epoch in range(args.start_epoch, args.end_epoch):
        train_sampler.set_epoch(epoch)  # reshuffle shards each epoch
        setproctitle.setproctitle('{}: {}/{}'.format(args.user, epoch + 1, args.end_epoch))
        start_epoch = time.time()

        for i, data in enumerate(train_loader):
            adjust_learning_rate(optimizer, epoch, args.end_epoch, args.lr)

            x, target = data
            x = x.cuda(args.local_rank, non_blocking=True)
            target = target.cuda(args.local_rank, non_blocking=True)

            output = model(x)
            loss, loss1, loss2, loss3 = criterion(output, target)

            # Average the losses across processes for logging.
            reduce_loss = all_reduce_tensor(loss, world_size=num_gpu).data.cpu().numpy()
            reduce_loss1 = all_reduce_tensor(loss1, world_size=num_gpu).data.cpu().numpy()
            reduce_loss2 = all_reduce_tensor(loss2, world_size=num_gpu).data.cpu().numpy()
            reduce_loss3 = all_reduce_tensor(loss3, world_size=num_gpu).data.cpu().numpy()

            if args.local_rank == 0:
                logging.info('Epoch: {}_Iter:{} loss: {:.5f} || 1:{:.4f} | 2:{:.4f} | 3:{:.4f} ||'
                             .format(epoch, i, reduce_loss, reduce_loss1, reduce_loss2, reduce_loss3))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        end_epoch = time.time()

        if args.local_rank == 0:
            # Save periodically and again near the end of training.
            if (epoch + 1) % int(args.save_freq) == 0 \
                    or (epoch + 1) % int(args.end_epoch - 1) == 0 \
                    or (epoch + 1) % int(args.end_epoch - 2) == 0 \
                    or (epoch + 1) % int(args.end_epoch - 3) == 0:
                file_name = os.path.join(checkpoint_dir, 'model_epoch_{}.pth'.format(epoch))
                torch.save({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict(),
                }, file_name)

            writer.add_scalar('lr:', optimizer.param_groups[0]['lr'], epoch)
            writer.add_scalar('loss:', reduce_loss, epoch)
            writer.add_scalar('loss1:', reduce_loss1, epoch)
            writer.add_scalar('loss2:', reduce_loss2, epoch)
            writer.add_scalar('loss3:', reduce_loss3, epoch)

        if args.local_rank == 0:
            epoch_time_minute = (end_epoch - start_epoch) / 60
            remaining_time_hour = (args.end_epoch - epoch - 1) * epoch_time_minute / 60
            logging.info('Current epoch time consumption: {:.2f} minutes!'.format(epoch_time_minute))
            logging.info('Estimated remaining training time: {:.2f} hours!'.format(remaining_time_hour))

    if args.local_rank == 0:
        writer.close()

        final_name = os.path.join(checkpoint_dir, 'model_epoch_last.pth')
        torch.save({
            'epoch': args.end_epoch,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        }, final_name)

        end_time = time.time()
        total_time = (end_time - start_time) / 3600
        logging.info('The total training time is {:.2f} hours'.format(total_time))
        logging.info('----------------------------------The training process finished!-----------------------------------')
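# --- Hedged sketch (not from the source): main_worker() expects args.local_rank
# to be set per process, which torch.distributed.launch supplies as a
# command-line argument, e.g.
#   python -m torch.distributed.launch --nproc_per_node=4 train.py --gpu 0,1,2,3
# The script name and the remaining options are illustrative only.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', default=0, type=int,
                        help='rank of this process on the node; filled in by the launcher')
    # ... the project-specific options used above (lr, batch_size, end_epoch, ...) go here ...
    args = parser.parse_args()

    main_worker()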
# Offline evaluation on a held-out fold of the training data.
def main():
    # Override the run configuration for fold validation (the original code
    # passed the attribute values rather than the attribute names to setattr).
    setattr(config, 'mode', 'train_fold')
    setattr(config, 'valid_file', 'valid_small.txt')

    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)

    model = getattr(models, config.model_name)(c=4, n=32, channels=128, groups=16,
                                               norm='sync_bn', num_classes=4, output_func='softmax')
    model = torch.nn.DataParallel(model).cuda()

    load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'checkpoint',
                             config.experiment + config.test_date, config.test_file)
    if os.path.exists(load_file):
        checkpoint = torch.load(load_file)
        model.load_state_dict(checkpoint['state_dict'])
        config.start_epoch = checkpoint['epoch']
        print('Successfully loaded checkpoint {}'.format(
            os.path.join(config.experiment + config.test_date, config.test_file)))
    else:
        print('There is no resume file to load!')

    valid_list = os.path.join(config.root, config.train_dir, 'valid_small.txt')
    valid_root = os.path.join(config.root, config.train_dir)
    valid_set = BraTS(valid_list, valid_root, mode='train_fold')
    print('Samples for valid = {}'.format(len(valid_set)))

    valid_loader = DataLoader(valid_set,
                              batch_size=1,
                              shuffle=False,
                              num_workers=config.num_workers,
                              pin_memory=True)

    submission = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.output_dir,
                              config.submission, config.experiment + config.test_date)
    visual = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.output_dir,
                          'visual_fold', config.experiment + config.test_date)
    if not os.path.exists(submission):
        os.makedirs(submission)
    if not os.path.exists(visual):
        os.makedirs(visual)

    start_time = time.time()

    with torch.no_grad():
        validate_softmax(valid_loader=valid_loader,
                         model=model,
                         savepath=submission,
                         visual=visual,
                         names=valid_set.names,
                         scoring=False,
                         use_TTA=False,
                         save_format=config.save_format,
                         postprocess=True,
                         snapshot=True)

    end_time = time.time()
    full_test_time = (end_time - start_time) / 60
    average_time = full_test_time / len(valid_set)
    print('Average inference time per case: {:.2f} minutes!'.format(average_time))
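# --- Hedged sketch (not from the source): training remaps BraTS label 4 to 3,
# so the network predicts classes {0, 1, 2, 3}. validate_softmax presumably
# inverts that mapping before writing predictions; the helper below only
# illustrates that step and is not the repository's implementation.
import numpy as np

def softmax_to_brats_labels(probs):
    """Convert a (C, D, H, W) softmax volume over classes {0,1,2,3} to BraTS labels {0,1,2,4}."""
    seg = probs.argmax(axis=0).astype(np.uint8)  # per-voxel class index
    seg[seg == 3] = 4                            # restore the enhancing-tumor label
    return seg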
# Inference on the validation/test split; writes predictions for submission.
def main():
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    _, model = TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned")
    model = torch.nn.DataParallel(model).cuda()

    load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'checkpoint',
                             args.experiment + args.test_date, args.test_file)
    if os.path.exists(load_file):
        checkpoint = torch.load(load_file)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        print('Successfully loaded checkpoint {}'.format(
            os.path.join(args.experiment + args.test_date, args.test_file)))
    else:
        print('There is no resume file to load!')

    valid_list = os.path.join(args.root, args.valid_dir, args.valid_file)
    valid_root = os.path.join(args.root, args.valid_dir)
    valid_set = BraTS(valid_list, valid_root, mode='test')
    print('Samples for valid = {}'.format(len(valid_set)))

    valid_loader = DataLoader(valid_set,
                              batch_size=1,
                              shuffle=False,
                              num_workers=args.num_workers,
                              pin_memory=True)

    submission = os.path.join(os.path.abspath(os.path.dirname(__file__)), args.output_dir,
                              args.submission, args.experiment + args.test_date)
    visual = os.path.join(os.path.abspath(os.path.dirname(__file__)), args.output_dir,
                          args.visual, args.experiment + args.test_date)
    if not os.path.exists(submission):
        os.makedirs(submission)
    if not os.path.exists(visual):
        os.makedirs(visual)

    start_time = time.time()

    with torch.no_grad():
        validate_softmax(valid_loader=valid_loader,
                         model=model,
                         load_file=load_file,
                         multimodel=False,
                         savepath=submission,
                         visual=visual,
                         names=valid_set.names,
                         use_TTA=args.use_TTA,
                         save_format=args.save_format,
                         snapshot=True,
                         postprocess=True)

    end_time = time.time()
    full_test_time = (end_time - start_time) / 60
    average_time = full_test_time / len(valid_set)
    print('Average inference time per case: {:.2f} minutes!'.format(average_time))
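# --- Hedged sketch (not from the source): with use_TTA enabled, validate_softmax
# is expected to average predictions over test-time augmentations. A common
# choice for volumetric segmentation is flipping along each spatial axis and
# averaging the softmax outputs; the flip set below is an assumption, not the
# repository's exact TTA.
import torch

def predict_with_flip_tta(model, x):
    """Average softmax predictions over the identity and single-axis flips of a (N, C, D, H, W) input."""
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)
        for axis in (2, 3, 4):  # D, H, W
            flipped = torch.flip(x, dims=[axis])
            probs += torch.flip(torch.softmax(model(flipped), dim=1), dims=[axis])
    return probs / 4.0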