def main():
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    data_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    crop_padding = 32
    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    # elif args.aug == 'NULL' and args.dataset == 'cifar':
    #     train_transform = transforms.Compose([
    #         transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    #         transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    #         transforms.RandomGrayscale(p=0.2),
    #         transforms.RandomHorizontalFlip(p=0.5),
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    #     ])
    #
    #     test_transform = transforms.Compose([
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    #     ])
    elif args.aug == 'simple' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            get_color_distortion(1.0),
            transforms.ToTensor(),
            normalize,
        ])
        # TODO: currently follows CMC
        test_transform = transforms.Compose([
            transforms.Resize(image_size + crop_padding),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'simple' and args.dataset == 'cifar':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(p=0.5),
            get_color_distortion(0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    else:
        raise NotImplementedError('augmentation not supported: {}'.format(args.aug))

    # get datasets
    if args.dataset == 'imagenet':
        train_dataset = ImageFolderInstance(data_folder, transform=train_transform, two_crop=args.moco)
        print(len(train_dataset))
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

        test_dataset = datasets.ImageFolder(val_folder, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=256, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)
    elif args.dataset == 'cifar':
        # CIFAR-10 dataset
        if args.contrastive_model == 'simclr':
            train_dataset = CIFAR10Instance_double(root='./data', train=True, download=True,
                                                   transform=train_transform, double=True)
        else:
            train_dataset = CIFAR10Instance(root='./data', train=True, download=True,
                                            transform=train_transform)
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_workers, pin_memory=True,
            sampler=train_sampler, drop_last=True)

        test_dataset = CIFAR10Instance(root='./data', train=False, download=True,
                                       transform=test_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=100, shuffle=False, num_workers=args.num_workers)

    # create model and optimizer
    n_data = len(train_dataset)

    if args.model == 'resnet50':
        model = InsResNet50()
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50()
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50(width=2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50(width=4)
    elif args.model == 'resnet50_cifar':
        model = InsResNet50_cifar()
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50_cifar()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model` to `model_ema`
    if args.contrastive_model == 'moco':
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.contrastive_model == 'moco':
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t, args.softmax).cuda(args.gpu)
    elif args.contrastive_model == 'simclr':
        contrast = None
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t, args.nce_m, args.softmax).cuda(args.gpu)

    if args.softmax:
        criterion = NCESoftmaxLoss()
    elif args.contrastive_model == 'simclr':
        criterion = BatchCriterion(1, args.nce_t, args.batch_size)
    else:
        criterion = NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if args.contrastive_model == 'moco':
        model_ema = model_ema.cuda()

    # exclude BN and bias from weight decay if needed
    weight_decay = args.weight_decay
    if weight_decay and args.filter_weight_decay:
        parameters = add_weight_decay(model, weight_decay, args.filter_weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    optimizer = torch.optim.SGD(parameters,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=weight_decay)
    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)
        if args.contrastive_model == 'moco':
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0, momentum=0, weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema, optimizer_ema, opt_level=args.opt_level)

    if args.LARS:
        optimizer = LARS(optimizer=optimizer, eps=1e-8, trust_coef=0.001)

    # optionally resume from a checkpoint
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if contrast:
                contrast.load_state_dict(checkpoint['contrast'])
            if args.contrastive_model == 'moco':
                model_ema.load_state_dict(checkpoint['model_ema'])
            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])
            print("=> loaded successfully '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):
        print("==> training...")

        time1 = time.time()
        if args.contrastive_model == 'moco':
            loss, prob = train_moco(epoch, train_loader, model, model_ema, contrast, criterion, optimizer, args)
        elif args.contrastive_model == 'simclr':
            print("Train using simclr")
            loss, prob = train_simclr(epoch, train_loader, model, criterion, optimizer, args)
        else:
            print("Train using InsDis")
            loss, prob = train_ins(epoch, train_loader, model, contrast, criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        test_epoch = 2
        if epoch % test_epoch == 0:
            model.eval()
            if args.contrastive_model == 'moco':
                model_ema.eval()
            print('----------Evaluation---------')
            start = time.time()
            if args.dataset == 'cifar':
                acc = kNN(epoch, model, train_loader, test_loader, 200, args.nce_t, n_data,
                          low_dim=128, memory_bank=None)
            print("Evaluation Time: '{}'s".format(time.time() - start))
            # writer.add_scalar('nn_acc', acc, epoch)
            logger.log_value('Test accuracy', acc, epoch)
            # print('accuracy: {}% \t (best acc: {}%)'.format(acc, best_acc))
            print('[Epoch]: {}'.format(epoch))
            print('accuracy: {}%'.format(acc))
            # test_log_file.flush()

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'model': model.state_dict(),
                # 'contrast': contrast.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            if args.contrastive_model == 'moco':
                state['model_ema'] = model_ema.state_dict()
            if args.amp:
                state['amp'] = amp.state_dict()
            save_file = os.path.join(args.model_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
            # help release GPU memory
            del state

        # saving the model
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            # 'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.contrastive_model == 'moco':
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        save_file = os.path.join(args.model_folder, 'current.pth')
        torch.save(state, save_file)
        if epoch % args.save_freq == 0:
            save_file = os.path.join(args.model_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
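
# The transforms above call get_color_distortion(), which is not defined in this
# excerpt. Below is a minimal sketch of what such a helper usually does, following
# the standard SimCLR-style color distortion (color jitter of strength `s` plus
# random grayscale); the exact probabilities used by the original helper are an
# assumption, not taken from this codebase.
from torchvision import transforms

def get_color_distortion_sketch(s=1.0):
    # s controls the strength of the color jitter
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    return transforms.Compose([rnd_color_jitter, rnd_gray])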
if P.lr_scheduler == 'cosine':
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, P.epochs)
elif P.lr_scheduler == 'step_decay':
    milestones = [int(0.5 * P.epochs), int(0.75 * P.epochs)]
    scheduler = lr_scheduler.MultiStepLR(optimizer, gamma=lr_decay_gamma, milestones=milestones)
else:
    raise NotImplementedError()

from training.scheduler import GradualWarmupScheduler
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=10.0,
                                          total_epoch=P.warmup, after_scheduler=scheduler)

if P.resume_path is not None:
    resume = True
    model_state, optim_state, config = load_checkpoint(P.resume_path, mode='last')
    model.load_state_dict(model_state, strict=not P.no_strict)
    optimizer.load_state_dict(optim_state)
    start_epoch = config['epoch']
    best = config['best']
    error = 100.0
else:
    resume = False
    start_epoch = 1
    best = 100.0
    error = 100.0

if P.mode == 'sup_linear' or P.mode == 'sup_CSI_linear':
    assert P.load_path is not None
    checkpoint = torch.load(P.load_path)
    model.load_state_dict(checkpoint, strict=not P.no_strict)

if P.multi_gpu:
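
# GradualWarmupScheduler above comes from this repo's training.scheduler module and
# is not shown here. For reference, the same warmup-then-anneal pattern can be built
# from stock PyTorch schedulers; this is an equivalent, self-contained sketch with
# assumed toy values, not the code used above.
import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

net = torch.nn.Linear(16, 4)                     # toy model
opt = torch.optim.SGD(net.parameters(), lr=0.1)

warmup_epochs, total_epochs = 10, 100
warmup = LinearLR(opt, start_factor=0.1, total_iters=warmup_epochs)   # lr ramps 0.01 -> 0.1
cosine = CosineAnnealingLR(opt, T_max=total_epochs - warmup_epochs)
sched = SequentialLR(opt, schedulers=[warmup, cosine], milestones=[warmup_epochs])

for ep in range(total_epochs):
    # ... one training epoch with opt.step() per batch ...
    sched.step()                                 # step the schedule once per epoch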
base_optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay_lr)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
scheduler = ExponentialLR(optimizer, gamma=args.decay_lr)

# Main training loop
best_loss = np.inf

# Resume training
if args.load_model is not None:
    if os.path.isfile(args.load_model):
        checkpoint = torch.load(args.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        base_optimizer.load_state_dict(checkpoint['base_optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best_loss = checkpoint['val_loss']
        epoch = checkpoint['epoch']
        print('Loading model: {}. Resuming from epoch: {}'.format(args.load_model, epoch))
    else:
        print('Model: {} not found'.format(args.load_model))

for epoch in range(args.epochs):
    v_loss = execute_graph(model, loader, optimizer, scheduler, epoch, use_cuda)

    if v_loss < best_loss:
        best_loss = v_loss
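
# The snippet above is cut off right after the best-loss check, but its resume branch
# expects a checkpoint containing 'model', 'optimizer', 'base_optimizer', 'scheduler',
# 'val_loss' and 'epoch'. A sketch of a matching save helper; the function name, file
# name and call site are assumptions, not part of the original script.
import torch

def save_checkpoint(path, model, optimizer, base_optimizer, scheduler, val_loss, epoch):
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'base_optimizer': base_optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'val_loss': val_loss,
        'epoch': epoch,
    }, path)

# e.g. inside the loop above:
#     if v_loss < best_loss:
#         best_loss = v_loss
#         save_checkpoint('best_model.pth', model, optimizer, base_optimizer,
#                         scheduler, v_loss, epoch)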
def train(cfg, writer, logger):
    # Setup random seeds to a fixed value for reproducibility
    # seed = 1337
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    # np.random.default_rng(seed)

    # Setup augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    logger.info("Using dataset: {}".format(data_path))
    t_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle=cfg.data.time_shuffle,
        # to_tensor=False,
        data_format=cfg.data.format,
        split=cfg.data.train_split,
        norm=cfg.data.norm,
        augments=data_aug,
    )
    v_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle=cfg.data.time_shuffle,
        # to_tensor=False,
        data_format=cfg.data.format,
        split=cfg.data.val_split,
    )

    train_data_len = len(t_loader)
    logger.info(f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}')

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=10,
                                num_workers=cfg.train.n_workers)

    # Setup model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model, 2).to(device)
    input_size = (cfg.model.input_nbr, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=True)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # automatic multi-GPU execution; works well

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in vars(cfg.train.optimizer).items()
                        if k not in ('name', 'wrap')}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer, 'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrap optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    loss_fn = get_loss_function(cfg)
    logger.info(f"Using loss: {str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info("Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume))

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(cfg.train.resume, checkpoint["epoch"]))

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file or '_last_model' in file):
                    # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(osp.join(resume_src_dir, file), resume_dst_dir)
        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    # Setup metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1

            file_a = file_a.to(device)
            file_b = file_b.to(device)
            label = label.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            outputs = model(file_a, file_b)
            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            # record the acc of the minibatch
            pred = outputs.max(1)[1].cpu().numpy()
            runing_metrics_train.update(label.cpu().numpy(), pred, mask.cpu().numpy())
            train_time_meter.update(time.time() - train_start_time)

            if it % cfg.train.print_interval == 0:
                # acc of the samples between print_interval
                score, _ = runing_metrics_train.get_scores()
                train_cls_0_acc, train_cls_1_acc = score['Acc']
                fmt_str = "Iter [{:d}/{:d}] train Loss: {:.4f} Time/Image: {:.4f},\n0:{:.4f}\n1:{:.4f}"
                print_str = fmt_str.format(it, train_iter,
                                           loss.item(),  # loss value as a Python float
                                           train_time_meter.avg / cfg.train.batch_size,
                                           train_cls_0_acc, train_cls_1_acc)
                runing_metrics_train.reset()
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it)
                writer.add_scalars('metrics/train',
                                   {'cls_0': train_cls_0_acc, 'cls_1': train_cls_1_acc}, it)
                # writer.add_scalar('train_metrics/acc/cls_0', train_cls_0_acc, it)
                # writer.add_scalar('train_metrics/acc/cls_1', train_cls_1_acc, it)

            if it % cfg.train.val_interval == 0 or it == train_iter:
                val_start_time = time.time()
                model.eval()  # switch layers such as dropout to eval behavior
                with torch.no_grad():  # disable autograd to save memory
                    for (file_a_val, file_b_val, label_val, mask_val) in valloader:
                        file_a_val = file_a_val.to(device)
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max() returns the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        running_metrics_val.update(label_val.numpy(), pred, mask_val.numpy())

                        label_val = label_val.to(device)
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs, target=label_val, mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                score, _ = running_metrics_val.get_scores()
                val_cls_0_acc, val_cls_1_acc = score['Acc']

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} "
                            f"Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}\n"
                            f"0: {val_cls_0_acc:.4f}\n1: {val_cls_1_acc:.4f}")
                # lr_now = optimizer.param_groups[0]['lr']
                # logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)
                logger.info('0: {:.4f}\n1: {:.4f}'.format(val_cls_0_acc, val_cls_1_acc))
                writer.add_scalars('metrics/val',
                                   {'cls_0': val_cls_0_acc, 'cls_1': val_cls_1_acc}, it)
                # writer.add_scalar('val_metrics/acc/cls_0', val_cls_0_acc, it)
                # writer.add_scalar('val_metrics/acc/cls_1', val_cls_1_acc, it)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # OA = score["Overall_Acc"]
                val_macro_OA = (val_cls_0_acc + val_cls_1_acc) / 2
                if val_macro_OA >= best_macro_OA_now and it > 200:
                    best_macro_OA_now = val_macro_OA
                    best_macro_OA_iter_now = it
                    state = {
                        "epoch": it,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_macro_OA_now": best_macro_OA_now,
                        "best_macro_OA_iter_now": best_macro_OA_iter_now,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(),
                                             "{}_{}_best_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
                    torch.save(state, save_path)

                    logger.info("best OA now = %.8f" % best_macro_OA_now)
                    logger.info("best OA iter now = %d" % best_macro_OA_iter_now)

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)

                model.train()

            train_start_time = time.time()

    logger.info("best OA now = %.8f" % best_macro_OA_now)
    logger.info("best OA iter now = %d" % best_macro_OA_iter_now)

    state = {
        "epoch": it,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_macro_OA_now": best_macro_OA_now,
        "best_macro_OA_iter_now": best_macro_OA_iter_now,
    }
    save_path = os.path.join(writer.file_writer.get_logdir(),
                             "{}_{}_last_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
    torch.save(state, save_path)
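
# averageMeter() and runningScore() above are repo utilities not shown in this
# excerpt. A minimal sketch of the running-average meter pattern the loop relies on;
# the interface (.update(), .avg, .reset()) is inferred from how it is called above,
# and the class name here is hypothetical.
class AverageMeterSketch:
    """Tracks a running average of a scalar (e.g. loss or per-batch time)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count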
def train_and_eval(tag, dataroot, test_ratio=0.0, cv_fold=0, reporter=None,
                   metric='last', save_path=None, only_eval=False):
    if not reporter:
        reporter = lambda **kwargs: 0

    max_epoch = C.get()['epoch']
    trainsampler, trainloader, validloader, testloader_ = get_dataloaders(
        C.get()['dataset'], C.get()['batch'], dataroot, test_ratio, split_idx=cv_fold)

    # create a model & an optimizer
    model = get_model(C.get()['model'], num_class(C.get()['dataset']))

    lb_smooth = C.get()['optimizer'].get('label_smoothing', 0.0)
    if lb_smooth > 0.0:
        criterion = SmoothCrossEntropyLoss(lb_smooth)
    else:
        criterion = nn.CrossEntropyLoss()

    if C.get()['optimizer']['type'] == 'sgd':
        optimizer = optim.SGD(
            model.parameters(),
            lr=C.get()['lr'],
            momentum=C.get()['optimizer'].get('momentum', 0.9),
            weight_decay=C.get()['optimizer']['decay'],
            nesterov=C.get()['optimizer']['nesterov'])
    else:
        raise ValueError('invalid optimizer type=%s' % C.get()['optimizer']['type'])

    if C.get()['optimizer'].get('lars', False):
        from torchlars import LARS
        optimizer = LARS(optimizer)
        logger.info('*** LARS Enabled.')

    lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
    if lr_scheduler_type == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=C.get()['epoch'], eta_min=0.)
    elif lr_scheduler_type == 'resnet':
        scheduler = adjust_learning_rate_resnet(optimizer)
    else:
        raise ValueError('invalid lr_scheduler=%s' % lr_scheduler_type)

    if C.get()['lr_schedule'].get('warmup', None):
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=C.get()['lr_schedule']['warmup']['multiplier'],
            total_epoch=C.get()['lr_schedule']['warmup']['epoch'],
            after_scheduler=scheduler)

    if not tag:
        from RandAugment.metrics import SummaryWriterDummy as SummaryWriter
        logger.warning('tag not provided, no tensorboard log.')
    else:
        from tensorboardX import SummaryWriter
    writers = [SummaryWriter(log_dir='./logs/%s/%s' % (tag, x)) for x in ['train', 'valid', 'test']]

    result = OrderedDict()
    epoch_start = 1
    if save_path and os.path.exists(save_path):
        logger.info('%s file found. loading...' % save_path)
        data = torch.load(save_path)
        if 'model' in data or 'state_dict' in data:
            key = 'model' if 'model' in data else 'state_dict'
            logger.info('checkpoint epoch@%d' % data['epoch'])
            if not isinstance(model, DataParallel):
                model.load_state_dict({k.replace('module.', ''): v for k, v in data[key].items()})
            else:
                model.load_state_dict({k if 'module.' in k else 'module.' + k: v
                                       for k, v in data[key].items()})
            optimizer.load_state_dict(data['optimizer'])
            if data['epoch'] < C.get()['epoch']:
                epoch_start = data['epoch']
            else:
                only_eval = True
        else:
            model.load_state_dict({k: v for k, v in data.items()})
        del data
    else:
        logger.info('"%s" file not found. skip to pretrain weights...' % save_path)
        if only_eval:
            logger.warning('model checkpoint not found. only-evaluation mode is off.')
            only_eval = False

    if only_eval:
        logger.info('evaluation only+')
        model.eval()
        rs = dict()
        rs['train'] = run_epoch(model, trainloader, criterion, None, desc_default='train', epoch=0, writer=writers[0])
        rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid', epoch=0, writer=writers[1])
        rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test', epoch=0, writer=writers[2])
        for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
            if setname not in rs:
                continue
            result['%s_%s' % (key, setname)] = rs[setname][key]
        result['epoch'] = 0
        return result

    # train loop
    best_top1 = 0
    for epoch in range(epoch_start, max_epoch + 1):
        model.train()
        rs = dict()
        rs['train'] = run_epoch(model, trainloader, criterion, optimizer, desc_default='train',
                                epoch=epoch, writer=writers[0], verbose=True, scheduler=scheduler)
        model.eval()

        if math.isnan(rs['train']['loss']):
            raise Exception('train loss is NaN.')

        if epoch % 5 == 0 or epoch == max_epoch:
            rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid',
                                    epoch=epoch, writer=writers[1], verbose=True)
            rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test',
                                   epoch=epoch, writer=writers[2], verbose=True)

            if metric == 'last' or rs[metric]['top1'] > best_top1:
                if metric != 'last':
                    best_top1 = rs[metric]['top1']
                for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
                    result['%s_%s' % (key, setname)] = rs[setname][key]
                result['epoch'] = epoch

                writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch)
                writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch)

                reporter(loss_valid=rs['valid']['loss'], top1_valid=rs['valid']['top1'],
                         loss_test=rs['test']['loss'], top1_test=rs['test']['top1'])

                # save checkpoint
                if save_path:
                    logger.info('save model@%d to %s' % (epoch, save_path))
                    torch.save({
                        'epoch': epoch,
                        'log': {
                            'train': rs['train'].get_dict(),
                            'valid': rs['valid'].get_dict(),
                            'test': rs['test'].get_dict(),
                        },
                        'optimizer': optimizer.state_dict(),
                        'model': model.state_dict()
                    }, save_path)
                    torch.save({
                        'epoch': epoch,
                        'log': {
                            'train': rs['train'].get_dict(),
                            'valid': rs['valid'].get_dict(),
                            'test': rs['test'].get_dict(),
                        },
                        'optimizer': optimizer.state_dict(),
                        'model': model.state_dict()
                    }, save_path.replace('.pth', '_e%d_top1_%.3f_%.3f' %
                                         (epoch, rs['train']['top1'], rs['test']['top1']) + '.pth'))

    del model
    result['top1_test'] = best_top1
    return result
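
# SmoothCrossEntropyLoss above comes from the RandAugment codebase and is not shown
# here. A minimal label-smoothing cross-entropy sketch with the same intent; the class
# name and mean reduction are assumptions, and recent PyTorch can instead use
# nn.CrossEntropyLoss(label_smoothing=...).
import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        n_class = logits.size(-1)
        log_prob = F.log_softmax(logits, dim=-1)
        # mix the one-hot target with a uniform distribution over the other classes
        true_dist = torch.full_like(log_prob, self.smoothing / (n_class - 1))
        true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        return torch.mean(torch.sum(-true_dist * log_prob, dim=-1))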
def train(cfg, writer, logger):
    # Setup augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    logger.info("data path: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split='train',
        split_root=cfg.data.split,
        augments=data_aug,
        logger=logger,
        log=cfg.data.log,
        ENL=cfg.data.ENL,
    )
    v_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        split='val',
        log=cfg.data.log,
        split_root=cfg.data.split,
        logger=logger,
        ENL=cfg.data.ENL,
    )

    train_data_len = len(t_loader)
    logger.info(f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}')

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg.test.batch_size,
                                num_workers=cfg.train.n_workers)

    # Setup model
    device = f'cuda:{cfg.train.gpu[0]}'
    model = get_model(cfg.model).to(device)
    input_size = (cfg.model.in_channels, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=False)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # automatic multi-GPU execution; works well

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in vars(cfg.train.optimizer).items()
                        if k not in ('name', 'wrap')}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer, 'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrap optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    # loss_fn = get_loss_function(cfg)
    # logger.info(f"Using loss: {str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info("Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume))

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(cfg.train.resume, checkpoint["epoch"]))

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file or '_last_model' in file):
                    # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(osp.join(resume_src_dir, file), resume_dst_dir)
        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    data_range = 255
    if cfg.data.log:
        data_range = np.log(data_range)
        # data_range /= 350

    # Setup metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    train_loss_meter = averageMeter()
    val_psnr_meter = averageMeter()
    val_ssim_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        for clean, noisy, _ in trainloader:
            it += 1
            noisy = noisy.to(device, dtype=torch.float32)
            # noisy /= 350

            mask1, mask2 = rand_pool.generate_mask_pair(noisy)
            noisy_sub1 = rand_pool.generate_subimages(noisy, mask1)
            noisy_sub2 = rand_pool.generate_subimages(noisy, mask2)

            # preparing for the regularization term
            with torch.no_grad():
                noisy_denoised = model(noisy)
            noisy_sub1_denoised = rand_pool.generate_subimages(noisy_denoised, mask1)
            noisy_sub2_denoised = rand_pool.generate_subimages(noisy_denoised, mask2)

            # calculating the loss
            optimizer.zero_grad()
            noisy_output = model(noisy_sub1)
            noisy_target = noisy_sub2
            if cfg.train.loss.gamma.const:
                gamma = cfg.train.loss.gamma.base
            else:
                gamma = it / train_iter * cfg.train.loss.gamma.base
            diff = noisy_output - noisy_target
            exp_diff = noisy_sub1_denoised - noisy_sub2_denoised
            loss1 = torch.mean(diff**2)
            loss2 = gamma * torch.mean((diff - exp_diff)**2)
            loss_all = loss1 + loss2
            loss_all.backward()

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            # record the loss of the minibatch
            train_loss_meter.update(loss_all)
            train_time_meter.update(time.time() - train_start_time)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], it)

            if it % 1000 == 0:
                writer.add_histogram('hist/pred', noisy_denoised, it)
                writer.add_histogram('hist/noisy', noisy, it)
                if cfg.data.simulate:
                    writer.add_histogram('hist/clean', clean, it)

            if cfg.data.simulate:
                pass

            # print interval
            if it % cfg.train.print_interval == 0:
                terminal_info = (f"Iter [{it:d}/{train_iter:d}]  "
                                 f"train Loss: {train_loss_meter.avg:.4f}  "
                                 f"Time/Image: {train_time_meter.avg / cfg.train.batch_size:.4f}")
                logger.info(terminal_info)
                writer.add_scalar('loss/train_loss', train_loss_meter.avg, it)
                if cfg.data.simulate:
                    pass
                runing_metrics_train.reset()
                train_time_meter.reset()
                train_loss_meter.reset()

            # val interval
            if it % cfg.train.val_interval == 0 or it == train_iter:
                val_start_time = time.time()
                model.eval()
                with torch.no_grad():
                    for clean, noisy, _ in valloader:
                        # noisy /= 350
                        # clean /= 350
                        noisy = noisy.to(device, dtype=torch.float32)
                        noisy_denoised = model(noisy)

                        if cfg.data.simulate:
                            clean = clean.to(device, dtype=torch.float32)
                            psnr = piq.psnr(clean, noisy_denoised, data_range=data_range)
                            ssim = piq.ssim(clean, noisy_denoised, data_range=data_range)
                            val_psnr_meter.update(psnr)
                            val_ssim_meter.update(ssim)

                        val_loss = torch.mean((noisy_denoised - noisy)**2)
                        val_loss_meter.update(val_loss)

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} "
                            f"Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}")
                val_loss_meter.reset()
                running_metrics_val.reset()

                if cfg.data.simulate:
                    writer.add_scalars('metrics/val',
                                       {'psnr': val_psnr_meter.avg, 'ssim': val_ssim_meter.avg}, it)
                    logger.info(f'psnr: {val_psnr_meter.avg},\tssim: {val_ssim_meter.avg}')
                    val_psnr_meter.reset()
                    val_ssim_meter.reset()

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            # save model
            if it % (train_iter / cfg.train.epoch * 10) == 0:
                ep = int(it / (train_iter / cfg.train.epoch))
                state = {
                    "epoch": it,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                }
                save_path = osp.join(writer.file_writer.get_logdir(), f"{ep}.pkl")
                torch.save(state, save_path)
                logger.info(f'saved model state dict at {save_path}')

            train_start_time = time.time()
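
# piq.psnr / piq.ssim above come from the `piq` package. For reference, PSNR with an
# explicit data_range reduces to the formula below; this sketch only clarifies what
# the data_range argument controls and is not the piq implementation.
import torch

def psnr_sketch(x, y, data_range=255.0):
    """PSNR = 10 * log10(data_range^2 / MSE), averaged over the whole batch."""
    mse = torch.mean((x - y) ** 2)
    return 10.0 * torch.log10(data_range ** 2 / mse)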
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.distributed:
        if args.dist_url == 'env://' and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    model = PixPro(
        encoder=resnet50,
        dim1=args.pcl_dim_1,
        dim2=args.pcl_dim_2,
        momentum=args.encoder_momentum,
        threshold=args.threshold,
        temperature=args.T,
        sharpness=args.sharpness,
        num_linear=args.num_linear,
    )

    # linear lr scaling rule
    args.lr = args.lr_base * args.batch_size / 256

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            # convert batch norm --> sync batch norm
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        raise NotImplementedError('only DDP is supported.')

    base_optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    optimizer = LARS(optimizer=base_optimizer, eps=1e-8)

    writer = SummaryWriter(args.log_dir)

    if args.resume:
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    dataset = PixProDataset(root=args.train_path, args=args)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        train_sampler = None

    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        shuffle=(train_sampler is None),
                        num_workers=args.workers,
                        pin_memory=True,
                        sampler=train_sampler,
                        drop_last=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        adjust_lr(optimizer, epoch, args)
        train(args, epoch, loader, model, optimizer, writer)

        if not args.multiprocessing_distributed or \
                (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
            save_name = '{}.pth.tar'.format(epoch)
            save_name = os.path.join(args.checkpoint_dir, save_name)
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, save_name)
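
# adjust_lr() above is not shown in this excerpt. A common choice for this kind of
# self-supervised pre-training is a half-cosine decay of the linearly scaled base lr
# over all epochs; the sketch below is an assumption about its behavior, not the
# original implementation.
import math

def adjust_lr_sketch(optimizer, epoch, args):
    lr = args.lr * 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr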