def main(args):
    train_loader = get_loader(args)
    n_data = len(train_loader.dataset)
    logger.info(f"length of training dataset: {n_data}")

    model, model_ema = build_model(args)
    contrast = MemoryMoCo(128, args.nce_k, args.nce_t).cuda()
    criterion = NCESoftmaxLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.batch_size * dist.get_world_size() / 256 * args.base_learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)

    if args.amp_opt_level != "O0":
        if amp is None:
            logger.warning(f"apex is not installed but amp_opt_level is set to {args.amp_opt_level}, ignoring.\n"
                           "you should install apex from https://github.com/NVIDIA/apex#quick-start first")
            args.amp_opt_level = "O0"
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)
            model_ema = amp.initialize(model_ema, opt_level=args.amp_opt_level)

    model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False)

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume)
        load_checkpoint(args, model, model_ema, contrast, optimizer, scheduler)

    # tensorboard
    if dist.get_rank() == 0:
        summary_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        summary_writer = None

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):
        train_loader.sampler.set_epoch(epoch)

        tic = time.time()
        loss, prob = train_moco(epoch, train_loader, model, model_ema, contrast,
                                criterion, optimizer, scheduler, args)
        logger.info('epoch {}, total time {:.2f}'.format(epoch, time.time() - tic))

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('ins_loss', loss, epoch)
            summary_writer.add_scalar('ins_prob', prob, epoch)
            summary_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if dist.get_rank() == 0:
            # save model
            save_checkpoint(args, epoch, model, model_ema, contrast, scheduler, optimizer)
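
# --- not part of the original file ---
# train_moco() is not shown in this section. In MoCo-style training the
# non-gradient encoder (model_ema above) is refreshed after every optimizer step
# as an exponential moving average of the online encoder. A minimal sketch of
# that update, assuming the module-level `import torch` the code above already
# relies on; the name moment_update and the alpha coefficient (typically ~0.999)
# are illustrative, not taken from this repo:
def moment_update(model, model_ema, alpha):
    """In-place EMA: model_ema = alpha * model_ema + (1 - alpha) * model."""
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), model_ema.parameters()):
            p_ema.data.mul_(alpha).add_(p.detach().data, alpha=1.0 - alpha)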
def main(args):
    train_loader = get_loader(args, 'train5/')
    train_loader_seg = get_loader(args, 'train5_1p_half/')
    n_data = len(train_loader.dataset)
    n_data_seg = len(train_loader_seg.dataset)
    logger.info(f"length of training dataset: {n_data} {n_data_seg}")

    model, model_ema = build_model(args)
    if args.model == 'resnet50':
        contrast = MemoryMoCo(128, 300, args.nce_t).cuda()
    elif args.model == 'vit':
        contrast = MemoryMoCo(128, 200, args.nce_t, s=64, c=768).cuda()
    elif args.model == 'resnet101':
        contrast = MemoryMoCo(128, 300, args.nce_t).cuda()
    criterion = NCESoftmaxLoss().cuda()

    # contrastive optimizer: backbone plus both projection heads
    # optimizer = torch.optim.SGD(model.parameters(),
    # optimizer = torch.optim.SGD([{'params': model.backbone.parameters(), 'lr': args.batch_size / 256 * args.base_learning_rate * 0.8},
    optimizer = torch.optim.SGD(
        [{'params': model.backbone.parameters()},
         {'params': model.mlp.parameters()},
         {'params': model.mlp2.parameters()}],
        lr=args.batch_size / 256 * args.base_learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)

    # segmentation optimizer: shared backbone plus decoder and segmentation head
    optimizer_seg = torch.optim.SGD(
        [{'params': model.backbone.parameters()},
         {'params': model.decoder.parameters()},
         {'params': model.segmentation_head.parameters()}],
        lr=args.batch_size / 256 * args.base_learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    scheduler_seg = get_scheduler(optimizer_seg, len(train_loader_seg), args)

    if args.amp_opt_level != "O0":
        if amp is None:
            logger.warning(
                f"apex is not installed but amp_opt_level is set to {args.amp_opt_level}, ignoring.\n"
                "you should install apex from https://github.com/NVIDIA/apex#quick-start first")
            args.amp_opt_level = "O0"
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)
            # model_ema = amp.initialize(model_ema, opt_level=args.amp_opt_level)

    # model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False)

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume)
        load_checkpoint(args, model, model_ema, contrast, optimizer, scheduler)

    # tensorboard
    summary_writer = SummaryWriter(log_dir=args.output_dir)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):
        # train_loader.sampler.set_epoch(epoch)
        tic = time.time()
        loss, loss_g, loss_d = train_moco(epoch, train_loader, model, model_ema, contrast,
                                          criterion, optimizer, scheduler, args)
        loss_seg = train_seg(epoch, train_loader_seg, model, optimizer_seg, scheduler_seg, args)
        logger.info('epoch {}, total time {:.2f}'.format(epoch, time.time() - tic))

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('ins_loss', loss, epoch)
            summary_writer.add_scalar('ins_loss_g', loss_g, epoch)
            summary_writer.add_scalar('ins_loss_d', loss_d, epoch)
            summary_writer.add_scalar('ins_loss_seg', loss_seg, epoch)
            summary_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        # validate
        if epoch % args.eval_freq == 0:
            dsc = inference(model)
            summary_writer.add_scalar('dice_percase', dsc, epoch)
            logger.info(f'validate result {epoch}: {dsc}')

        # save model
        save_checkpoint(args, epoch, model, model_ema, contrast, optimizer, scheduler)
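
# --- not part of the original file ---
# NCESoftmaxLoss is referenced above but not defined in this section. In
# MoCo-style code it usually reduces to plain cross-entropy over the (1 + K)
# logits returned by the memory bank, with the positive key always placed at
# index 0. A minimal sketch under that assumption (class name chosen here to
# avoid clashing with the repo's own definition):
class NCESoftmaxLossSketch(torch.nn.Module):
    """Cross-entropy over contrastive logits; the positive key sits at index 0."""

    def __init__(self):
        super().__init__()
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, logits):
        # every sample's "correct class" is the positive logit at position 0
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        return self.criterion(logits, labels)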
def main(args):
    global best_acc1

    train_loader, val_loader = get_loader(args)
    logger.info(f"length of training dataset: {len(train_loader.dataset)}")

    model, classifier = build_model(args, num_class=len(train_loader.dataset.classes))
    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(classifier.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)

    if args.amp_opt_level != "O0":
        if amp is None:
            logger.warning(
                f"apex is not installed but amp_opt_level is set to {args.amp_opt_level}, ignoring.\n"
                "you should install apex from https://github.com/NVIDIA/apex#quick-start first")
            args.amp_opt_level = "O0"
        else:
            model = amp.initialize(model, opt_level=args.amp_opt_level)
            classifier, optimizer = amp.initialize(classifier, optimizer, opt_level=args.amp_opt_level)

    classifier = DistributedDataParallel(classifier, device_ids=[args.local_rank], broadcast_buffers=False)

    # the pre-trained encoder stays in eval mode; only the classifier is optimized
    model.eval()
    load_pretrained(args, model)

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), f"no checkpoint found at '{args.resume}'"
        load_checkpoint(args, classifier, optimizer, scheduler)

    if args.eval:
        logger.info("==> testing...")
        validate(val_loader, model, classifier, criterion, args)
        return

    # tensorboard
    if dist.get_rank() == 0:
        summary_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        summary_writer = None

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):
        if isinstance(train_loader.sampler, DistributedSampler):
            train_loader.sampler.set_epoch(epoch)

        tic = time.time()
        train(epoch, train_loader, model, classifier, criterion, optimizer, scheduler, args)
        logger.info(f'epoch {epoch}, total time {time.time() - tic:.2f}')

        logger.info("==> testing...")
        test_acc = validate(val_loader, model, classifier, criterion, args)

        if dist.get_rank() == 0 and epoch % args.save_freq == 0:
            logger.info('==> Saving...')
            save_checkpoint(args, epoch, classifier, test_acc, optimizer, scheduler)

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('test_acc', test_acc, epoch)
            summary_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
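
# --- not part of the original file ---
# Each main() above expects an argparse-style Namespace and, for the distributed
# variants, an already-initialised process group. A minimal launcher sketch;
# parse_option() is a hypothetical stand-in for the repo's own argument parser,
# while the remaining calls are standard PyTorch distributed APIs:
if __name__ == '__main__':
    opt = parse_option()  # hypothetical helper returning the `args` Namespace

    torch.cuda.set_device(opt.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    main(opt)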