def modelDeploy(args, model, optimizer, scheduler, logger):
    """Place the model on GPU(s), enable cuDNN autotuning, and optionally restore state.

    Args:
        args: namespace providing ``num_gpus``, ``resume``, ``finetune`` and
            ``freeze_bn``.
        model: network to deploy (wrapped in DataParallel when GPUs are used).
        optimizer: optimizer whose state may be restored from the checkpoint.
        scheduler: LR scheduler, fast-forwarded to the resumed epoch.
        logger: logger for progress / error reporting.

    Returns:
        (model, trainData): the possibly-wrapped model and the bookkeeping dict
        ``{'epoch', 'loss', 'miou', 'val', 'bestMiou'}``.

    Raises:
        FileNotFoundError: when ``args.resume`` points at a missing checkpoint.
    """
    if args.num_gpus >= 1:
        from torch.nn.parallel import DataParallel
        model = DataParallel(model)
        model = model.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Fresh bookkeeping; replaced below when resuming from a checkpoint.
    trainData = {'epoch': 0, 'loss': [], 'miou': [], 'val': [], 'bestMiou': 0}

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            # model & optimizer
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # stop point: fast-forward the scheduler to the saved epoch
            trainData = checkpoint['trainData']
            for _ in range(trainData['epoch']):
                scheduler.step()
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, trainData['epoch']))
        else:
            logger.error("=> no checkpoint found at '{}'".format(args.resume))
            # BUGFIX: was `assert False, ...` — asserts are stripped under
            # `python -O`, silently continuing with an un-restored model.
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(args.resume))

    if args.finetune:
        if os.path.isfile(args.finetune):
            logger.info("=> finetuning checkpoint '{}'".format(args.finetune))
            state_all = torch.load(args.finetune, map_location='cpu')['model']
            state_clip = {}  # only use backbone parameters
            for k, v in state_all.items():
                state_clip[k] = v
            # strict=False: tolerate head/shape mismatches when finetuning
            model.load_state_dict(state_clip, strict=False)
        else:
            logger.warning("finetune is not a file.")

    if args.freeze_bn:
        logger.warning('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                # Freeze running statistics and stop gradient updates.
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    return model, trainData
def load_reid_model(ckpt='/home/honglongcai/Github/PretrainedModel/model_410.pt'):
    """Build the ReID model, restore pretrained weights, and set eval mode.

    Args:
        ckpt: checkpoint path. Generalized from a hard-coded constant so other
            deployments can reuse this loader; the default preserves the
            original behaviour for existing callers.

    Returns:
        The CUDA-resident, eval-mode, DataParallel-wrapped model.
    """
    model = DataParallel(Model())
    model.load_state_dict(torch.load(ckpt, map_location='cuda'))
    logger.info('Load ReID model from {}'.format(ckpt))
    model = model.cuda()
    model.eval()
    return model
class TestProcess:
    """Run ET-Net inference over the configured test split and save per-image predictions."""

    def __init__(self):
        # Build the network. On GPU the net is wrapped in DataParallel
        # *before* loading, so the checkpoint presumably carries
        # 'module.'-prefixed keys when ARGS['gpu'] is set — TODO confirm.
        self.net = ET_Net()
        if (ARGS['gpu']):
            self.net = DataParallel(module=self.net.cuda())
        self.net.load_state_dict(torch.load(ARGS['weight']))
        self.test_dataset = get_dataset(dataset_name=ARGS['dataset'], part='test')

    def predict(self):
        """Predict every test image patch-wise, stitch the patches, and save the result."""
        start = time.time()
        self.net.eval()
        test_dataloader = DataLoader(self.test_dataset, batch_size=1)  # only support batch size = 1
        os.makedirs(ARGS['prediction_save_folder'], exist_ok=True)
        for items in test_dataloader:
            images, mask, filename = items['image'], items['mask'], items['filename']
            images = images.float()
            mask = mask.long()
            print('image shape:', images.size())
            # Split the full image into overlapping crops for batched inference.
            image_patches, big_h, big_w = get_test_patches(images, ARGS['crop_size'], ARGS['stride_size'])
            test_patch_dataloader = DataLoader(image_patches, batch_size=ARGS['batch_size'], shuffle=False, drop_last=False)
            test_results = []
            print('Number of batches for testing:', len(test_patch_dataloader))
            for patches in test_patch_dataloader:
                if ARGS['gpu']:
                    patches = patches.cuda()
                with torch.no_grad():
                    # First output appears to be the edge branch; only the
                    # second (segmentation) output is kept for stitching.
                    result_patches_edge, result_patches = self.net(patches)
                test_results.append(result_patches.cpu())
            test_results = torch.cat(test_results, dim=0)
            # merge: stitch overlapping patch predictions back to full resolution
            test_results = recompone_overlap(test_results, ARGS['crop_size'], ARGS['stride_size'], big_h, big_w)
            # Channel 1 (presumably the foreground probability — confirm),
            # cropped to the original spatial size and masked to the FOV.
            test_results = test_results[:, 1, :images.size(2), :images.size(3)] * mask
            test_results = Image.fromarray(test_results[0].numpy())
            test_results.save(os.path.join(ARGS['prediction_save_folder'], filename[0]))
            print(f'Finish prediction for {filename[0]}')
        finish = time.time()
        print('Predicting time consumed: {:.2f}s'.format(finish - start))
class BaseInferencer(object):
    """Skeleton for patient-wise model inference.

    Subclasses implement ``__inference__`` for a single patient; ``inference``
    drives the loop over all patients and logs progress.
    """

    def __init__(self, model, images_path, labels_path, patient_ids,
                 sample_shape, checkpoint_restore, inference_dir,
                 use_gpu=False, gpu_ids=None):
        # -- model ---------------------------------------------------------
        self.model = model

        # -- data: one label per image is required -------------------------
        assert len(images_path) == len(labels_path)
        self.images_path = images_path
        self.labels_path = labels_path
        self.patient_ids = patient_ids
        self.length = len(images_path)
        self.sample_shape = sample_shape
        self.inference_dir = inference_dir

        # -- device placement ----------------------------------------------
        self.use_gpu = use_gpu
        gpu_available = use_gpu and torch.cuda.device_count() > 0
        if gpu_available:
            self.model.cuda()
            if gpu_ids is not None:
                if len(gpu_ids) > 1:
                    self.multi_gpu = True
                    self.model = DataParallel(model, gpu_ids)
                else:
                    self.multi_gpu = False
            elif torch.cuda.device_count() > 1:
                self.multi_gpu = True
                self.model = DataParallel(model)
            else:
                self.multi_gpu = False
        else:
            self.multi_gpu = False
            self.model = self.model.cpu()

        # Restore weights into the (possibly wrapped) model.
        self.model.load_state_dict(torch.load(checkpoint_restore))

    def inference(self):
        """Iterate over every patient and delegate the work to ``__inference__``."""
        self.model.eval()
        logging.info('*' * 80)
        logging.info('start inference loop')
        logging.info('%d patients need to be inference ' % self.length)
        for idx in range(self.length):
            logging.info('start inference %d-th patient' % (idx + 1))
            self.__inference__(idx)
        logging.info('*' * 80)
        logging.info('inference patient: %d' % self.length)
        logging.info('inference result saved in: %s' % self.inference_dir)

    def __inference__(self, index):
        """Hook: run inference for one patient. Override in subclasses.

        :rtype: object
        """
        pass
class BaseEvaluation(object):
    """Skeleton for patient-wise model evaluation.

    Subclasses supply ``load_data`` and ``eval_one_patient``; ``eval`` drives
    the loop, aggregates the per-patient metric dicts, and logs their means.
    """

    def __init__(self, model, metrics, images_path, labels_path,
                 sample_shape, checkpoint_restore, use_gpu=False, gpu_ids=None):
        # model settings
        self.model = model

        # metrics settings (must be a plain dict)
        assert type(metrics) == dict
        self.metrics = metrics

        # data settings: one label per image is required
        assert len(images_path) == len(labels_path)
        self.images_path = images_path
        self.labels_path = labels_path
        self.length = len(images_path)
        self.sample_shape = sample_shape

        # device settings
        self.use_gpu = use_gpu
        on_gpu = use_gpu and torch.cuda.device_count() > 0
        if not on_gpu:
            self.multi_gpu = False
            self.model = self.model.cpu()
        else:
            self.model.cuda()
            if gpu_ids is not None:
                self.multi_gpu = len(gpu_ids) > 1
                if self.multi_gpu:
                    self.model = DataParallel(model, gpu_ids)
            else:
                self.multi_gpu = torch.cuda.device_count() > 1
                if self.multi_gpu:
                    self.model = DataParallel(model)

        # Restore weights into the (possibly wrapped) model.
        self.model.load_state_dict(torch.load(checkpoint_restore))

    def load_data(self, index):
        """
        :rtype: image -> nd-array, label -> nd-array
        """
        pass

    def eval_one_patient(self, image, label):
        """
        :rtype: metrics -> dict
        """
        pass

    def eval(self):
        """Evaluate every patient, then log per-patient and mean metrics."""
        self.model.eval()
        logging.info('*' * 80)
        logging.info('start evaluation loop')
        logging.info('%d patients need to be evaluated ' % self.length)
        collected = dict()
        for idx in range(self.length):
            logging.info('start evaluation %d-th patient' % (idx + 1))
            image, label = self.load_data(idx)
            with torch.no_grad():
                per_patient = self.eval_one_patient(image, label)
            for name in per_patient.keys():
                collected.setdefault(name, list()).append(per_patient[name])
            logging.info('evaluation metrics result: %s' % str(per_patient))
        mean_result = {name: np.mean(vals) for name, vals in collected.items()}
        logging.info('*' * 80)
        logging.info('evaluation report: ')
        logging.info('evaluation patient: %d' % self.length)
        logging.info('evaluation metrics %s' % str(mean_result))
class BaseTrainer(object):
    """Generic training-loop skeleton.

    Subclasses implement ``train_epochs`` / ``val_epochs``, each returning a
    ``(loss dict, metrics dict)`` pair. This base handles TensorBoard logging,
    optional validation, LR scheduling, (multi-)GPU placement, weight
    restoration, and periodic checkpointing.
    """

    def __init__(self,
                 epochs,
                 model,
                 train_dataloader,
                 train_loss_func,
                 train_metrics_func,
                 optimizer,
                 log_dir,
                 checkpoint_dir,
                 checkpoint_frequency,
                 checkpoint_restore=None,
                 val_dataloader=None,
                 val_metrics_func=None,
                 lr_scheduler=None,
                 lr_reduce_metric=None,
                 use_gpu=False,
                 gpu_ids=None):
        # train settings
        self.epochs = epochs
        self.model = model
        self.train_dataloader = train_dataloader
        self.train_loss_func = train_loss_func
        self.train_metrics_func = train_metrics_func
        self.optimizer = optimizer
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_frequency = checkpoint_frequency
        self.writer = SummaryWriter(logdir=log_dir)

        # validation settings
        if val_dataloader is not None:
            self.validation = True
            self.val_dataloader = val_dataloader
            self.val_metrics_func = val_metrics_func
        else:
            self.validation = False

        # lr scheduler settings
        if lr_scheduler is not None:
            self.lr_schedule = True
            self.lr_scheduler = lr_scheduler
            if isinstance(lr_scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                # ReduceLROnPlateau needs a monitored quantity on step()
                self.lr_reduce_metric = lr_reduce_metric
        else:
            self.lr_schedule = False

        # multi-gpu settings
        self.use_gpu = use_gpu
        # BUGFIX: the original iterated `range(len(gpu_ids))` unconditionally,
        # raising TypeError (len(None)) whenever gpu_ids kept its default.
        if gpu_ids is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
                str(gpu_id) for gpu_id in gpu_ids)
        if use_gpu and torch.cuda.device_count() > 0:
            self.model.cuda()
            if gpu_ids is not None:
                if len(gpu_ids) > 1:
                    self.multi_gpu = True
                    self.model = DataParallel(model, gpu_ids)
                else:
                    self.multi_gpu = False
            else:
                if torch.cuda.device_count() > 1:
                    self.multi_gpu = True
                    self.model = DataParallel(model)
                else:
                    self.multi_gpu = False
        else:
            self.multi_gpu = False
            self.device = torch.device('cpu')
            self.model = self.model.cpu()

        # checkpoint settings (restore into the possibly-wrapped model)
        if checkpoint_restore is not None:
            self.model.load_state_dict(torch.load(checkpoint_restore))

    def train(self):
        """Run the full loop: train, optionally validate, checkpoint per epoch."""
        for epoch in range(1, self.epochs + 1):
            logging.info('*' * 80)
            logging.info('start epoch %d training loop' % epoch)

            # train
            self.model.train()
            loss, metrics = self.train_epochs(epoch)
            self.writer.add_scalar('train_loss', loss, epoch)
            for key in metrics.keys():
                self.writer.add_scalar(key, metrics[key], epoch)
            if self.lr_schedule:
                if isinstance(self.lr_scheduler,
                              torch.optim.lr_scheduler.ReduceLROnPlateau):
                    # Plateau scheduler monitors one entry of the loss dict.
                    self.lr_scheduler.step(loss[self.lr_reduce_metric])
                else:
                    self.lr_scheduler.step()
            logging.info('train loss result: %s' % str(loss))
            logging.info('train metrics result: %s' % str(metrics))

            # validation
            if self.validation:
                logging.info('validation start ... ')
                self.model.eval()
                loss, metrics = self.val_epochs(epoch)
                self.writer.add_scalar('val_loss', loss, epoch)
                for key in metrics.keys():
                    self.writer.add_scalar(key, metrics[key], epoch)
                logging.info('validation loss result: %s' % str(loss))
                logging.info('validation metrics result: %s' % str(metrics))

            # model checkpoint: unwrap DataParallel so keys stay portable
            if epoch % self.checkpoint_frequency == 0:
                logging.info('saving model...')
                checkpoint_name = 'checkpoint_%d.pth' % epoch
                if self.multi_gpu:
                    torch.save(
                        self.model.module.state_dict(),
                        os.path.join(self.checkpoint_dir, checkpoint_name))
                else:
                    torch.save(
                        self.model.state_dict(),
                        os.path.join(self.checkpoint_dir, checkpoint_name))
                logging.info('model have saved for epoch_%d ' % epoch)
            else:
                logging.info('saving model skipped.')

    def train_epochs(self, epoch) -> (dict, dict):
        """
        :rtype: loss -> dict , metrics -> dict
        """
        pass

    def val_epochs(self, epoch) -> (dict, dict):
        """
        :rtype: loss -> dict , metrics -> dict
        """
        pass
# NOTE(review): this chunk continues an optimizer construction whose opening
# lies outside this view; the trailing param group below gives the IDE
# classifier head its own learning rate.
    'params': IDE.classifier.parameters(),
    'lr': 0.01
}],
               momentum=0.9,
               weight_decay=5e-4,
               nesterov=True)
# Decay LR by a factor of 0.1 every 20 epochs (20 epochs for market and 30 epochs for duke)
# NOTE(review): step_size=10 disagrees with the comment above — confirm the intended decay interval.
scheduler_IDE = lr_scheduler.StepLR(IDE_optimizer, step_size=10, gamma=0.1)

## load checkpoint
ckpt_dir = './checkpoints/espgan_m2d_lam5/'
utils.mkdir(ckpt_dir)
try:
    # Restore both GAN discriminators/generators, the ID encoder, and all
    # their optimizers from the most recent checkpoint in ckpt_dir.
    ckpt = utils.load_checkpoint(ckpt_dir, map_location=torch.device('cpu'))
    start_epoch = ckpt['epoch']
    Da.load_state_dict(ckpt['Da'])
    Db.load_state_dict(ckpt['Db'])
    Ga.load_state_dict(ckpt['Ga'])
    Gb.load_state_dict(ckpt['Gb'])
    IDE.load_state_dict(ckpt['IDE'])
    da_optimizer.load_state_dict(ckpt['da_optimizer'])
    db_optimizer.load_state_dict(ckpt['db_optimizer'])
    ga_optimizer.load_state_dict(ckpt['ga_optimizer'])
    gb_optimizer.load_state_dict(ckpt['gb_optimizer'])
    IDE_optimizer.load_state_dict(ckpt['IDE_optimizer'])
except:  # NOTE(review): bare except treats ANY failure as "no checkpoint" — consider narrowing
    start_epoch = 0
    print('Training form zero')  # NOTE(review): "form" is a typo in the runtime string; left untouched here

## run
class SSLOnlineEvaluator(Callback):  # pragma: no cover
    """Attaches a MLP for fine-tuning using the standard self-supervised protocol.

    Example::

        # your datamodule must have 2 attributes
        dm = DataModule()
        dm.num_classes = ...  # the num of classes in the datamodule
        dm.name = ...  # name of the datamodule (e.g. ImageNet, STL10, CIFAR10)

        # your model must have 1 attribute
        model = Model()
        model.z_dim = ...  # the representation dim

        online_eval = SSLOnlineEvaluator(
            z_dim=model.z_dim
        )
    """

    def __init__(
        self,
        z_dim: int,
        drop_p: float = 0.2,
        hidden_dim: Optional[int] = None,
        num_classes: Optional[int] = None,
        dataset: Optional[str] = None,
    ):
        """
        Args:
            z_dim: Representation dimension
            drop_p: Dropout probability
            hidden_dim: Hidden dimension for the fine-tune MLP
            num_classes: Number of classes; read from the datamodule in setup() when None
            dataset: Dataset name; read from the datamodule in setup() when None
        """
        super().__init__()

        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.drop_p = drop_p

        self.optimizer: Optional[Optimizer] = None
        self.online_evaluator: Optional[SSLEvaluator] = None
        # CLEANUP: the original first assigned num_classes/dataset to None and
        # immediately overwrote them with the constructor args (dead stores);
        # only the meaningful assignments are kept.
        self.num_classes: Optional[int] = num_classes
        self.dataset: Optional[str] = dataset

        self._recovered_callback_state: Optional[Dict[str, Any]] = None

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
        # Fall back to the datamodule for anything not given at construction.
        if self.num_classes is None:
            self.num_classes = trainer.datamodule.num_classes
        if self.dataset is None:
            self.dataset = trainer.datamodule.name

    def on_pretrain_routine_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # must move to device after setup, as during setup, pl_module is still on cpu
        self.online_evaluator = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.num_classes,
            p=self.drop_p,
            n_hidden=self.hidden_dim,
        ).to(pl_module.device)

        # switch for PL compatibility reasons
        accel = (trainer.accelerator_connector if hasattr(
            trainer, "accelerator_connector") else trainer._accelerator_connector)
        if accel.is_distributed:
            if accel.use_ddp:
                from torch.nn.parallel import DistributedDataParallel as DDP

                self.online_evaluator = DDP(self.online_evaluator,
                                            device_ids=[pl_module.device])
            elif accel.use_dp:
                from torch.nn.parallel import DataParallel as DP

                self.online_evaluator = DP(self.online_evaluator,
                                           device_ids=[pl_module.device])
            else:
                rank_zero_warn(
                    "Does not support this type of distributed accelerator. The online evaluator will not sync."
                )

        self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(),
                                          lr=1e-4)

        # Restore MLP + optimizer state captured by on_load_checkpoint.
        if self._recovered_callback_state is not None:
            self.online_evaluator.load_state_dict(
                self._recovered_callback_state["state_dict"])
            self.optimizer.load_state_dict(
                self._recovered_callback_state["optimizer_state"])

    def to_device(self, batch: Sequence,
                  device: Union[str, torch.device]) -> Tuple[Tensor, Tensor]:
        # get the labeled batch (STL10 interleaves unlabeled/labeled splits)
        if self.dataset == "stl10":
            labeled_batch = batch[1]
            batch = labeled_batch

        inputs, y = batch

        # last input is for online eval
        x = inputs[-1]
        x = x.to(device)
        y = y.to(device)

        return x, y

    def shared_step(
        self,
        pl_module: LightningModule,
        batch: Sequence,
    ):
        # Representations come from the frozen encoder, without gradients.
        with torch.no_grad():
            with set_training(pl_module, False):
                x, y = self.to_device(batch, pl_module.device)
                representations = pl_module(x).flatten(start_dim=1)

        # forward pass (gradients flow only through the online MLP)
        mlp_logits = self.online_evaluator(
            representations)  # type: ignore[operator]
        mlp_loss = F.cross_entropy(mlp_logits, y)

        acc = accuracy(mlp_logits.softmax(-1), y)

        return acc, mlp_loss

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        train_acc, mlp_loss = self.shared_step(pl_module, batch)

        # update finetune weights
        mlp_loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        pl_module.log("online_train_acc", train_acc, on_step=True, on_epoch=False)
        pl_module.log("online_train_loss", mlp_loss, on_step=True, on_epoch=False)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        val_acc, mlp_loss = self.shared_step(pl_module, batch)
        pl_module.log("online_val_acc", val_acc, on_step=False,
                      on_epoch=True, sync_dist=True)
        pl_module.log("online_val_loss", mlp_loss, on_step=False,
                      on_epoch=True, sync_dist=True)

    def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule,
                           checkpoint: Dict[str, Any]) -> dict:
        # Persist the online MLP and its optimizer alongside the main model.
        return {
            "state_dict": self.online_evaluator.state_dict(),
            "optimizer_state": self.optimizer.state_dict()
        }

    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule,
                           callback_state: Dict[str, Any]) -> None:
        # Defer restoration until the evaluator exists (pretrain routine start).
        self._recovered_callback_state = callback_state
def main(): global args, best_prec1 args = parser.parse_args() # Read list of training and validation data listfiles_train, labels_train = read_lists(TRAIN_OUT) listfiles_val, labels_val = read_lists(VAL_OUT) listfiles_test, labels_test = read_lists(TEST_OUT) dataset_train = Dataset(listfiles_train, labels_train, subtract_mean=False, V=12) dataset_val = Dataset(listfiles_val, labels_val, subtract_mean=False, V=12) dataset_test = Dataset(listfiles_test, labels_test, subtract_mean=False, V=12) # shuffle data dataset_train.shuffle() dataset_val.shuffle() dataset_test.shuffle() tra_data_size, val_data_size, test_data_size = dataset_train.size( ), dataset_val.size(), dataset_test.size() print 'training size:', tra_data_size print 'validation size:', val_data_size print 'testing size:', test_data_size batch_size = args.b print("batch_size is :" + str(batch_size)) learning_rate = args.lr print("learning_rate is :" + str(learning_rate)) num_cuda = cuda.device_count() print("number of GPUs have been detected:" + str(num_cuda)) # creat model print("model building...") mvcnn = DataParallel(modelnet40_Alex(num_cuda, batch_size)) #mvcnn = modelnet40(num_cuda, batch_size, multi_gpu = False) mvcnn.cuda() # Optionally resume from a checkpoint if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint'{}'".format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] best_prec1 = checkpoint['best_prec1'] mvcnn.load_state_dict(checkpoint['state_dict']) print("=> loaded checkpoint '{}' (epoch {})".format( args.resume, checkpoint['epoch'])) else: print("=> no checkpoint found at '{}'".format(args.resume)) #print(mvcnn) criterion = nn.CrossEntropyLoss().cuda() optimizer = torch.optim.Adadelta(mvcnn.parameters(), weight_decay=1e-4) # evaluate performance only if args.evaluate: print 'testing mode ------------------' validate(dataset_test, mvcnn, criterion, optimizer, batch_size) return print 'training mode ------------------' for 
epoch in xrange(args.start_epoch, args.epochs): print('epoch:', epoch) #adjust_learning_rate(optimizer, epoch) # train for one epoch train(dataset_train, mvcnn, criterion, optimizer, epoch, batch_size) # evaluate on validation set prec1 = validate(dataset_val, mvcnn, criterion, optimizer, batch_size) # remember best prec@1 and save checkpoint is_best = prec1 > best_prec1 best_prec1 = max(prec1, best_prec1) if is_best: save_checkpoint( { 'epoch': epoch + 1, 'state_dict': mvcnn.state_dict(), 'best_prec1': best_prec1, }, is_best, epoch) elif epoch % 5 is 0: save_checkpoint( { 'epoch': epoch + 1, 'state_dict': mvcnn.state_dict(), 'best_prec1': best_prec1, }, is_best, epoch)
# Train or Test
if not args.demo:
    avg_tool = CumulativeAverager()
    vloss, is_best = torch.tensor(float(np.inf)), None
    # BUGFIX: `start` was previously unbound when --load_from pointed at a
    # missing file (the original else-branch only covered load_from is None),
    # so range(start, ...) below raised NameError. Default it up front.
    start = 0
    if args.load_from is not None:
        if os.path.isfile(args.load_from):
            log_str = add_to_log("=> loading checkpoint '{}'".format(args.load_from))
            checkpoint = torch.load(args.load_from)
            start = checkpoint['epoch']
            vloss = checkpoint['best_val_loss']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.lr = checkpoint['learning_rate']
            log_str = add_to_log("=> loaded checkpoint '{}' (epoch {})"
                                 .format(args.load_from, checkpoint['epoch']))
        else:
            log_str = add_to_log("=> no checkpoint found at '{}'".format(args.load_from))

    # Resume (or start) training; the scheduler is stepped on validation loss
    # (ReduceLROnPlateau-style usage).
    for epoch in range(start, args.epochs):
        train(epoch, losstype=losstype)
        val_loss = validate(losstype=losstype).cpu()
        scheduler.step(val_loss)
def _build_model(args, seg_classes, ue_finetune):
    """Instantiate the segmentation network selected by args.model.

    ue_finetune is forwarded as the second positional argument of
    espdnetue_seg2 (False for the first stage, args.finetune for the second).
    Exits the process for unsupported architectures.
    """
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        return espnetv2_seg(args)
    if args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        return espdnet_seg(args)
    if args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        # BUGFIX: the original stage-1 branch had a bare string expression
        # here (missing print), which was a silent no-op.
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        return espdnetue_seg2(args, ue_finetune, fix_pyr_plane_proj=True)
    print_error_message('Arch: {} not yet supported'.format(args.model))
    exit(-1)


def _wrap_for_gpus(model, criterion, nid_loss, num_gpus, use_nid):
    """Move model/criteria onto the GPU(s) and enable cuDNN autotuning.

    A single GPU uses torch's DataParallel (criteria stay unwrapped);
    multiple GPUs use the project's DataParallelModel/Criteria wrappers.
    Returns the (possibly wrapped) (model, criterion, nid_loss).
    """
    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model).cuda()
            criterion = criterion.cuda()
            if use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model).cuda()
            criterion = DataParallelCriteria(criterion).cuda()
            if use_nid:
                nid_loss = DataParallelCriteria(nid_loss).cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True
    return model, criterion, nid_loss


def main(args):
    """Two-stage segmentation training.

    Stage 1 trains on the 13-class CamVid labels; stage 2 rebuilds the model,
    transfers every weight whose name and shape still match, and finetunes on
    the label-converted (5-class) dataset at lr/100 for 50 epochs. CLI-facing
    behaviour (arguments, checkpoint names, TensorBoard tags) is unchanged.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    print('device : ' + device)

    # Get a summary writer for tensorboard
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')

    #
    # Training the model with 13 classes of CamVid dataset
    # TODO: This process should be done only if specified
    #
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=False)  # 13 classes
    args.use_depth = False  # 'use_depth' is always false for camvid

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # Import model
    model = _build_model(args, seg_classes, False)

    # Freeze batch normalization layers?
    if args.freeze_bn:
        freeze_bn_layer(model)

    # Set learning rates: backbone at args.lr, segmentation head at lr*lr_mult
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]

    # Define an optimizer
    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Compute the FLOPs and the number of parameters, and display it
    num_params, flops = show_network_stats(model, crop_size)
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:  # narrowed from a bare except; graph export stays best-effort
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    model, criterion, nid_loss = _wrap_for_gpus(model, criterion, nid_loss,
                                                num_gpus, args.use_nid)

    # Get data loaders for training and validation data
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Get a learning rate scheduler
    lr_scheduler = get_lr_scheduler(args.scheduler)

    write_stats_to_json(num_params, flops)
    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

    #
    # Main training loop of 13 classes
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # Use different training functions for espdnetue
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model, train_loader, optimizer,
                                       criterion, seg_classes, epoch,
                                       device=device,
                                       use_depth=args.use_depth,
                                       add_criterion=nid_loss)
        miou_val, val_loss = val(model, val_loader, criterion, seg_classes,
                                 device=device, use_depth=args.use_depth,
                                 add_criterion=nid_loss)

        # BUGFIX: `iterator.next()` is Python 2 only; use the builtin next().
        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(model,
                                      images=batch_train[0].to(device=device),
                                      labels=batch_train[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='Segmentation/train',
                                      device=device)
        in_training_visualization_img(model,
                                      images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='Segmentation/val',
                                      device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    # Save the pretrained weights
    model_dict = copy.deepcopy(model.state_dict())
    del model
    torch.cuda.empty_cache()

    #
    # Finetuning with 4 classes
    #
    args.ignore_idx = 4
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=True)  # 5 classes

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # Import model (stage 2: finetuning flag forwarded to espdnetue)
    model = _build_model(args, seg_classes, args.finetune)

    if not args.finetune:
        # Transfer stage-1 weights wherever name ('module.' stripped) and
        # shape still match the freshly built model.
        new_model_dict = model.state_dict()
        overlap_dict = {
            k.replace('module.', ''): v
            for k, v in model_dict.items()
            if k.replace('module.', '') in new_model_dict
            and new_model_dict[k.replace('module.', '')].size() == v.size()
        }
        no_overlap_dict = {
            k.replace('module.', ''): v
            for k, v in new_model_dict.items()
            if k.replace('module.', '') not in new_model_dict
            or new_model_dict[k.replace('module.', '')].size() != v.size()
        }
        print(no_overlap_dict.keys())
        new_model_dict.update(overlap_dict)
        model.load_state_dict(new_model_dict)

    # Sanity-check forward pass on CPU with a fixed-size dummy input.
    output = model(torch.ones(1, 3, 288, 480))
    print(output[0].size())
    print(seg_classes)
    print(class_wts.size())

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    # Set learning rates (two orders of magnitude lower for finetuning)
    args.lr /= 100
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]

    # Define an optimizer
    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    model, criterion, nid_loss = _wrap_for_gpus(model, criterion, nid_loss,
                                                num_gpus, args.use_nid)

    # Get data loaders for training and validation data
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Get a learning rate scheduler (finetuning is fixed at 50 epochs)
    args.epochs = 50
    lr_scheduler = get_lr_scheduler(args.scheduler)

    # Compute the FLOPs and the number of parameters, and display it
    num_params, flops = show_network_stats(model, crop_size)
    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}_{}'.format(args.model, seg_classes, args.s,
                                           crop_size[0])

    #
    # Main finetuning loop
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # Use different training functions for espdnetue
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model, train_loader, optimizer,
                                       criterion, seg_classes, epoch,
                                       device=device,
                                       use_depth=args.use_depth,
                                       add_criterion=nid_loss)
        miou_val, val_loss = val(model, val_loader, criterion, seg_classes,
                                 device=device, use_depth=args.use_depth,
                                 add_criterion=nid_loss)

        # BUGFIX: `iterator.next()` is Python 2 only; use the builtin next().
        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(model,
                                      images=batch_train[0].to(device=device),
                                      labels=batch_train[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/train',
                                      device=device)
        in_training_visualization_img(model,
                                      images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/val',
                                      device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('SegmentationConv/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('SegmentationConv/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('SegmentationConv/Loss/train', train_loss, epoch)
        writer.add_scalar('SegmentationConv/Loss/val', val_loss, epoch)
        writer.add_scalar('SegmentationConv/mIOU/train', miou_train, epoch)
        writer.add_scalar('SegmentationConv/mIOU/val', miou_val, epoch)
        writer.add_scalar('SegmentationConv/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('SegmentationConv/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
class GetFeature(object):
    """Extract L2-normalized ReID features with a pretrained model.

    Arguments:
        model_weight_file: path to the pre-trained model weights.
        sys_device_ids: GPU id string for CUDA_VISIBLE_DEVICES; '' means CPU.
    """

    def __init__(self, model_weight_file, sys_device_ids=''):
        if len(sys_device_ids) > 0:
            os.environ['CUDA_VISIBLE_DEVICES'] = sys_device_ids
        self.sys_device_ids = sys_device_ids
        self.model = DataParallel(Model())
        # remember the device so both __call__ branches use the same one
        if torch.cuda.is_available() and self.sys_device_ids != '':
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.model.load_state_dict(
            torch.load(model_weight_file, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, photo_path=None, batch_size=1):
        """Get features for a directory of photos or a single image.

        :param photo_path: either a photo directory or a single image file
        :param batch_size: useful only when photo_path is a directory
        :return: (features, names) -- features is a numpy array of shape
            num_images x feature_dim, names is a list of file names
        """
        # input is a directory
        if os.path.isdir(photo_path):
            dataset = Data(photo_path, self._img_process)
            data_loader = DataLoader(dataset, batch_size=batch_size,
                                     num_workers=8)
            features = torch.FloatTensor()
            photos = []
            for batch, (images, names) in enumerate(data_loader):
                images = images.float().to(self.device)
                feature = self.model(images).data.cpu()
                features = torch.cat((features, feature), 0)
                photos = photos + list(names)
                if batch % 10 == 0:
                    print('processing batch: {}'.format(batch))
            features = features.numpy()
            features = features / np.linalg.norm(features, axis=1,
                                                 keepdims=True)
            return features, photos
        # input is a single image
        else:
            photo_name = photo_path.split('/')[-1]
            img = Image.open(photo_path)
            image = self._img_process(img)
            image = np.expand_dims(image, axis=0)
            # bugfix: move the tensor to the model's device; the original
            # left it on CPU and called .numpy() directly, which fails when
            # the model runs on CUDA
            image = torch.from_numpy(image).float().to(self.device)
            feature = self.model(image).data.cpu().numpy()
            feature = feature / np.linalg.norm(feature, axis=1, keepdims=True)
            return feature, [photo_name]

    def _img_process(self, img):
        """Resize to 128x384, normalize with ImageNet statistics, and
        return a float CHW array."""
        img = img.resize((128, 384), resample=3)
        img = np.asarray(img)
        img = img[:, :, :3]  # drop alpha channel if present
        img = img.astype(float)
        img = img / 255
        im_mean = np.array([0.485, 0.456, 0.406])
        im_std = np.array([0.229, 0.224, 0.225])
        img = img - im_mean
        img = img / im_std
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
        return img
def main(args):
    """Run EDVR video super-resolution inference over an evaluation set.

    Loads (optionally resuming from) a checkpoint, runs the model over every
    clip in the eval loader, and writes the upscaled frames as images under
    args.outputs_dir, mirroring the source directory layout.
    """
    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)
    print("===> Loading datasets")
    data_set = EvalDataset(
        args.test_lr,
        n_frames=args.n_frames,
        interval_list=args.interval_list,
    )
    eval_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             num_workers=args.workers)
    #### random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
    #cudnn.deterministic = True
    print("===> Building model")
    #### create model
    model = EDVR_arch.EDVR(nf=args.nf,
                           nframes=args.n_frames,
                           groups=args.groups,
                           front_RBs=args.front_RBs,
                           back_RBs=args.back_RBs,
                           center=args.center,
                           predeblur=args.predeblur,
                           HR_in=args.HR_in,
                           w_TSA=args.w_TSA)
    print("===> Setting GPU")
    # 0 means "use every visible GPU"
    gups = args.gpus if args.gpus != 0 else torch.cuda.device_count()
    device_ids = list(range(gups))
    model = DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    # print(model)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isdir(args.resume):
            # pick the last (lexicographically latest) checkpoint in the dir
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                # prefix bare keys with `module.` so they match the
                # DataParallel-wrapped model
                namekey = 'module.' + k
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)

    #### evaluation
    print("===> Eval")
    model.eval()
    # drop_last-style total: only full batches are counted by the bar
    with tqdm(total=(len(data_set) - len(data_set) % args.batch_size)) as t:
        for data in eval_loader:
            data_x = data['LRs'].cuda()
            names = data['files']
            with torch.no_grad():
                outputs = model(data_x).data.float().cpu()
            # model emits [0, 1] floats; convert to 8-bit pixel range
            outputs = outputs * 255.
            outputs = outputs.clamp_(0, 255).numpy()
            for img, file in zip(outputs, names):
                # RGB -> BGR and CHW -> HWC for cv2.imwrite
                img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))
                img = img.round()
                arr = file.split('/')
                dst_dir = os.path.join(args.outputs_dir, arr[-2])
                if not os.path.exists(dst_dir):
                    os.makedirs(dst_dir)
                dst_name = os.path.join(dst_dir, arr[-1])
                cv2.imwrite(dst_name, img)
            t.update(len(names))
class Im2latex(BaseAgent):
    """Agent that trains, validates, and runs inference for an
    image-to-LaTeX sequence model."""

    def __init__(self, cfg):
        super().__init__(cfg)
        self.device = get_device()
        cfg.device = self.device
        self.cfg = cfg

        # dataset: the training split defines the vocabulary reused by
        # the validation split
        train_dataset = Im2LatexDataset(cfg, mode="train")
        self.id2token = train_dataset.id2token
        self.token2id = train_dataset.token2id
        collate = custom_collate(self.token2id, cfg.max_len)
        self.train_loader = DataLoader(train_dataset,
                                       batch_size=cfg.bs,
                                       shuffle=cfg.data_shuffle,
                                       num_workers=cfg.num_w,
                                       collate_fn=collate,
                                       drop_last=True)
        if cfg.valid_img_path != "":
            valid_dataset = Im2LatexDataset(cfg, mode="valid",
                                            vocab={
                                                'id2token': self.id2token,
                                                'token2id': self.token2id
                                            })
            # beam search multiplies per-sample memory, so shrink the batch
            self.valid_loader = DataLoader(
                valid_dataset,
                batch_size=cfg.bs // cfg.beam_search_k,
                shuffle=cfg.data_shuffle,
                num_workers=cfg.num_w,
                collate_fn=collate,
                drop_last=True)

        # define models
        self.model = Im2LatexModel(cfg)

        # weight initialization: zero biases, Kaiming-normal weights;
        # parameters that reject kaiming init (e.g. 1-d batchnorm weights)
        # fall back to a constant fill of 1
        for name, param in self.model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    torch.nn.init.constant_(param, 0.0)
                elif 'weight' in name:
                    torch.nn.init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue
        self.model = DataParallel(self.model)

        # criterion / optimizer / scheduler
        self.criterion = cal_loss
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=cfg.lr,
                                          betas=(cfg.adam_beta_1,
                                                 cfg.adam_beta_2))
        milestones = cfg.milestones
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones, gamma=cfg.gamma, verbose=True)

        # initialize counters (1-based epoch/iteration)
        self.current_epoch = 1
        self.current_iteration = 1
        self.best_metric = 100
        self.best_info = ''

        # set the manual seed for torch
        torch.cuda.manual_seed_all(self.cfg.seed)
        if self.cfg.cuda:
            self.model = self.model.to(self.device)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
        else:
            self.logger.info("Program will run on *****CPU*****\n")

        # resume from checkpoint when one exists; otherwise start fresh
        self.exp_dir = os.path.join('./experiments', cfg.exp_name)
        self.load_checkpoint(cfg.checkpoint_filename)

        # Summary Writer
        self.summary_writer = SummaryWriter(
            log_dir=os.path.join(self.exp_dir, 'summaries'))

    def load_checkpoint(self, file_name):
        """
        Latest checkpoint loader
        :param file_name: name of the checkpoint file
        :return:
        """
        try:
            self.logger.info("Loading checkpoint '{}'".format(file_name))
            checkpoint = torch.load(file_name, map_location=self.device)
            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            # strict=False tolerates architecture changes between runs
            self.model.load_state_dict(checkpoint['model'], strict=False)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            info = "Checkpoint loaded successfully from "
            self.logger.info(
                info + "'{}' at (epoch {}) at (iteration {})\n".format(
                    file_name, checkpoint['epoch'], checkpoint['iteration']))
        except OSError:
            self.logger.info("Checkpoint not found in '{}'.".format(file_name))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, file_name="checkpoint.pth", is_best=False):
        """
        Checkpoint saver
        :param file_name: name of the checkpoint file
        :param is_best: boolean flag to indicate whether current
            checkpoint's accuracy is the best so far
        :return:
        """
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'model': self.model.state_dict(),
            'vocab': self.id2token,
            'optimizer': self.optimizer.state_dict()
        }
        # save the state
        checkpoint_dir = os.path.join(self.exp_dir, 'checkpoints')
        # robustness: make sure the target directory exists
        os.makedirs(checkpoint_dir, exist_ok=True)
        if is_best:
            torch.save(state, os.path.join(checkpoint_dir, 'best.pt'))
            self.best_info = 'best: e{}_i{}'.format(self.current_epoch,
                                                    self.current_iteration)
        else:
            file_name = "e{}-i{}.pt".format(self.current_epoch,
                                            self.current_iteration)
            torch.save(state, os.path.join(checkpoint_dir, file_name))

    def run(self):
        """
        The main operator
        :return:
        """
        try:
            if self.cfg.mode == 'train':
                self.train()
            elif self.cfg.mode == 'predict':
                self.predict()
        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        """
        Main training loop
        :return:
        """
        for _ in range(self.current_epoch, self.cfg.epochs + 1):
            self.train_one_epoch()
            self.save_checkpoint()
            self.scheduler.step()
            self.current_epoch += 1
            if self.cfg.valid_img_path:
                self.validate()

    def train_one_epoch(self):
        """
        One epoch of training
        :return: the last logged average perplexity of this epoch
        """
        tqdm_bar = tqdm(enumerate(self.train_loader, 1),
                        total=len(self.train_loader))
        self.model.train()
        last_avg_perplexity, avg_perplexity = 0, 0
        for i, (imgs, tgt) in tqdm_bar:
            imgs = imgs.float().to(self.device)
            tgt = tgt.long().to(self.device)
            # [B, MAXLEN, VOCABSIZE]
            logits = self.model(imgs, tgt, is_train=True)
            loss = self.criterion(logits, tgt)
            avg_perplexity += loss.item()
            # bugfix: clear stale gradients before backprop; the original
            # never called zero_grad(), so gradients accumulated across
            # iterations
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.cfg.grad_clip)
            self.optimizer.step()
            self.current_iteration += 1
            # logging
            if i % self.cfg.log_freq == 0:
                avg_perplexity = avg_perplexity / self.cfg.log_freq
                self.summary_writer.add_scalar(
                    'perplexity/train',
                    avg_perplexity,
                    global_step=self.current_iteration)
                # bugfix: get_last_lr() returns a list; log the scalar
                self.summary_writer.add_scalar(
                    'lr',
                    self.scheduler.get_last_lr()[0],
                    global_step=self.current_iteration)
                tqdm_bar.set_description(
                    "e{} | avg_perplexity: {:.3f}".format(
                        self.current_epoch, avg_perplexity))
                # save if best
                if avg_perplexity < self.best_metric:
                    self.save_checkpoint(is_best=True)
                    self.best_metric = avg_perplexity
                last_avg_perplexity = avg_perplexity
                avg_perplexity = 0
                # log a sample prediction/target pair (token id 2 is
                # treated as padding here -- TODO confirm against vocab)
                mask = (tgt[0] != 2)
                pred = str(logits[0].argmax(1)[mask].cpu().detach().tolist())
                gt = str(tgt[0][mask].cpu().tolist())
                self.summary_writer.add_text(
                    'example/train',
                    pred + ' \n' + gt,
                    global_step=self.current_iteration)
        return last_avg_perplexity

    def validate(self):
        """
        One cycle of model validation
        :return:
        """
        tqdm_bar = tqdm(enumerate(self.valid_loader, 1),
                        total=len(self.valid_loader))
        self.model.eval()
        acc = 0
        with torch.no_grad():
            for i, (imgs, tgt) in tqdm_bar:
                imgs = imgs.to(self.device).float()
                tgt = tgt.to(self.device).long()
                logits = self.model(
                    imgs, is_train=False).long()  # [B, MAXLEN, VOCABSIZE]
                # exact-match accuracy over whole sequences
                acc += torch.all(tgt == logits, dim=1).sum() / imgs.size(0)
                tqdm_bar.set_description('acc {:.4f}'.format(acc / i))
                if i % self.cfg.log_freq == 0:
                    self.summary_writer.add_scalar(
                        'accuracy/valid',
                        acc.item() / i,
                        global_step=self.current_iteration)

    def predict(self):
        """
        Get predictions for every .jpg/.png under cfg.test_img_path
        and print the decoded token sequences.
        :return:
        """
        from torchvision import transforms
        from pathlib import Path
        from PIL import Image
        from time import time
        self.model.eval()
        transform = transforms.ToTensor()
        image_path = Path(self.cfg.test_img_path)
        t = time()
        with torch.no_grad():
            images = []
            imgPath = list(image_path.glob('*.jpg')) + list(
                image_path.glob('*.png'))
            for i, img in enumerate(imgPath):
                print(i, ':', img)
                img = Image.open(img)
                img = transform(img)
                images.append(img)
            images = torch.stack(images, dim=0)
            out = self.model(images)  # [B, max_len, vocab_size]
            for i, output in enumerate(out):
                # token id 1 is skipped when decoding (padding/blank --
                # TODO confirm against vocab)
                print(
                    i, ' '.join([
                        self.id2token[tok.item()] for tok in output
                        if tok.item() != 1
                    ]))
        print(time() - t)

    def finalize(self):
        """
        Finalizes all the operations of the 2 Main classes of the process,
        the operator and the data loader
        :return:
        """
        print(self.best_info)
class SRGANModel(BaseModel):
    """SRGAN model: generator G, discriminator D, optional perceptual
    network F and optional text-recognition network T used as an extra
    loss during training."""

    def __init__(self, opt, dataset=None):
        super(SRGANModel, self).__init__(opt)
        # NOTE(review): cri_text is only assigned when `dataset` is truthy;
        # if BaseModel does not define it, later `if self.cri_text:` checks
        # raise AttributeError when dataset is None -- confirm.
        if dataset:
            self.cri_text = True
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)
            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None
            # G feature (perceptual) loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    pass  # do not need to use DistributedDataParallel for netF
                else:
                    self.netF = DataParallel(self.netF)
            if self.cri_text:
                # frozen text-recognition network used as a recognition loss
                from lib.models.model_builder import ModelBuilder
                self.netT = ModelBuilder(
                    arch="ResNet_ASTER",
                    rec_num_classes=dataset.rec_num_classes,
                    sDim=512,
                    attDim=512,
                    max_len_labels=100,
                    eos=dataset.char2id[dataset.EOS],
                    STN_ON=True).to(self.device)
                self.netT = DataParallel(self.netT)
                self.netT.eval()
                from lib.util.serialization import load_checkpoint
                checkpoint = load_checkpoint(train_opt['text_model'])
                self.netT.load_state_dict(checkpoint['state_dict'])

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        """Move one batch onto the device; `ref` falls back to GT."""
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step, text_input=None):
        """One optimization step: update G (respecting D_update_ratio /
        D_init_iters warmup), then update D."""
        # G -- freeze D so the generator step does not touch its grads
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea
            if self.cri_text:
                # recognition loss from the frozen text network
                _, label, length = text_input
                input_dict = {}
                input_dict['images'] = self.fake_H
                input_dict['rec_target'] = label
                input_dict['rec_length'] = length
                output_dict = self.netT(input_dict)
                l_g_total += output_dict['losses']['loss_rec'].mean(dim=0)

            if self.opt['train']['gan_type'] == 'gan':
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                # relativistic average GAN formulation
                pred_d_real = self.netD(self.var_ref).detach()
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            l_g_total.backward()
            self.optimizer_G.step()

        # D -- re-enable its gradients for the discriminator step
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        if self.opt['train']['gan_type'] == 'gan':
            # need to forward and backward separately, since batch norm
            # statistics differ
            # real
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_real.backward()
            # fake
            pred_d_fake = self.netD(
                self.fake_H.detach())  # detach to avoid BP to G
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_fake.backward()
        elif self.opt['train']['gan_type'] == 'ragan':
            # pred_d_real = self.netD(self.var_ref)
            # pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
            # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
            # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
            # l_d_total = (l_d_real + l_d_fake) / 2
            # l_d_total.backward()
            pred_d_fake = self.netD(self.fake_H.detach()).detach()
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True) * 0.5
            l_d_real.backward()
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(
                pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
            l_d_fake.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        """Run G on the current LQ input without gradients."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        """Return the dict of the latest logged loss values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return first-sample LQ / result / (optional) GT tensors on CPU."""
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log structure and parameter counts for G, D and F."""
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        """Load pretrained G (and D when training) if paths are configured."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        """Persist G and D checkpoints for the given iteration."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
class VisualizeProcess:
    """Run a trained ET-Net over the validation split and display, for the
    first sample of each batch, the input image, ground-truth label,
    predicted probability map and predicted edge map."""

    def __init__(self):
        self.net = ET_Net()
        if ARGS['gpu']:
            self.net = DataParallel(module=self.net.cuda())
        self.net.load_state_dict(torch.load(ARGS['weight']))
        self.train_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                         part='train')
        self.val_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                       part='val')

    def visualize(self):
        t_begin = time.time()
        self.net.eval()
        loader = DataLoader(self.val_dataset,
                            batch_size=min(ARGS['batch_size'],
                                           len(self.val_dataset)))
        for sample in loader:
            images = sample['image'].float()
            labels = sample['label'].long()
            edges = sample['edge'].long()
            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()
            print('image shape:', images.size())
            with torch.no_grad():
                outputs_edge, outputs = self.net(images)
            pred = torch.max(outputs, dim=1)[1]
            # IoU of the first sample in the batch
            iou = torch.sum(pred[0] & labels[0]) / (
                torch.sum(pred[0] | labels[0]) + 1e-6)
            # undo the mean subtraction applied during preprocessing
            mean = torch.FloatTensor([123.68, 116.779, 103.939]).reshape(
                (3, 1, 1)) / 255.
            images = images + mean.cuda()
            print('pred min: ', pred[0].min(), ' max: ', pred[0].max())
            print('label min:', labels[0].min(), ' max: ', labels[0].max())
            print('edge min:', edges[0].min(), ' max: ', edges[0].max())
            print('output edge min:', outputs_edge[0].min(), ' max: ',
                  outputs_edge[0].max())
            print('IoU:', iou)
            print('Intersect num:', torch.sum(pred[0] & labels[0]))
            print('Union num:', torch.sum(pred[0] | labels[0]))
            # 2x2 panel: input / label / probability map / edge map
            plt.subplot(221)
            plt.imshow(images[0].cpu().numpy().transpose((1, 2, 0)))
            plt.axis('off')
            plt.subplot(222)
            plt.imshow(labels[0].cpu().numpy(), cmap='gray')
            plt.axis('off')
            plt.subplot(223)
            plt.imshow(outputs[0, 1].cpu().numpy(), cmap='gray')
            plt.axis('off')
            plt.subplot(224)
            plt.imshow(outputs_edge[0, 1].cpu().numpy(), cmap='gray')
            plt.axis('off')
            plt.show()
        t_end = time.time()
        print('validating time consumed: {:.2f}s'.format(t_end - t_begin))
class BaseEngine(object):
    """Skeleton for train/eval engines.

    Subclasses provide dataset/model/optimizer/loss/metric via the
    _make_* hooks; this base wires them together, handles optional
    multi-GPU wrapping, and dumps/loads checkpoints keyed by tag.
    """

    def __init__(self, args):
        self._make_dataset(args)
        self._make_model(args)
        tc.manual_seed(args.seed)
        if args.cuda and tc.cuda.is_available():
            tc.cuda.manual_seed_all(args.seed)
            n_gpus = tc.cuda.device_count()
            if n_gpus > 1:
                # scale the effective batch size across all visible devices
                self.batch_size = args.batch_size * n_gpus
                self.model = DataParallel(self.model)
            else:
                self.batch_size = args.batch_size
            self.model = self.model.cuda()
        else:
            self.batch_size = args.batch_size
        self._make_optimizer(args)
        self._make_loss(args)
        self._make_metric(args)
        self.num_training_samples = args.num_training_samples
        self.tag = args.tag or 'default'
        self.dump_dir = get_dir(args.dump_dir)
        self.train_logger = get_logger('train.{}.{}'.format(
            self.__class__.__name__, self.tag))

    # ---- subclass hooks -------------------------------------------------

    def _make_dataset(self, args):
        raise NotImplementedError

    def _make_model(self, args):
        raise NotImplementedError

    def _make_optimizer(self, args):
        raise NotImplementedError

    def _make_loss(self, args):
        raise NotImplementedError

    def _make_metric(self, args):
        raise NotImplementedError

    # ---- checkpointing --------------------------------------------------

    def dump(self, epoch, model=True, optimizer=True, decayer=True):
        """Serialize the selected components to state_<tag>.pkl."""
        snapshot = {'epoch': epoch}
        if model:
            snapshot['model'] = self.model.state_dict()
        if optimizer:
            snapshot['optimizer'] = self.optimizer.state_dict()
        if decayer and getattr(self, 'decayer', None) is not None:
            snapshot['decayer'] = self.decayer.state_dict()
        tc.save(snapshot,
                os.path.join(self.dump_dir, 'state_{}.pkl'.format(self.tag)))
        self.train_logger.info('Checkpoint {} dumped'.format(self.tag))

    def load(self, model=True, optimizer=True, decayer=True):
        """Restore components from state_<tag>.pkl.

        Returns the stored epoch, or 0 when no checkpoint exists.
        """
        path = os.path.join(self.dump_dir, 'state_{}.pkl'.format(self.tag))
        try:
            snapshot = tc.load(path)
        except FileNotFoundError:
            return 0
        if model and snapshot.get('model') is not None:
            self.model.load_state_dict(snapshot['model'])
        if optimizer and snapshot.get('optimizer') is not None:
            self.optimizer.load_state_dict(snapshot['optimizer'])
        if (decayer and snapshot.get('decayer') is not None
                and getattr(self, 'decayer', None) is not None):
            self.decayer.load_state_dict(snapshot['decayer'])
        return snapshot['epoch']

    # ---- subclass entry points ------------------------------------------

    def eval(self):
        raise NotImplementedError

    def test(self):
        raise NotImplementedError

    def train(self, num_epochs, resume=False):
        raise NotImplementedError
class Trainer():
    """Video-inpainting GAN trainer: generator netG, discriminator netD,
    adversarial + L1 hole/valid losses, TensorBoard summaries and
    epoch-indexed checkpointing."""

    def __init__(self, config, debug=False):
        self.config = config
        self.epoch = 0
        self.iteration = 0
        if debug:
            # short cycles so a debug run exercises save/valid paths quickly
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5
            self.config['trainer']['iterations'] = 5

        # setup data set and data loader
        self.train_dataset = AVEDataset(config['data_loader'], split='train')
        self.train_sampler = None
        self.train_args = config['trainer']
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.train_args['batch_size'],
            shuffle=True,
            num_workers=self.train_args['num_workers'],
            pin_memory=True)

        # set loss functions
        self.adversarial_loss = AdversarialLoss(
            type=self.config['losses']['GAN_LOSS'])
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
        self.l1_loss = nn.L1Loss()

        # setup models including generator and discriminator
        net = importlib.import_module('model.' + config['model'])
        self.netG = net.InpaintGenerator()
        self.netG = self.netG.to(self.config['device'])
        self.netD = net.Discriminator(
            in_channels=3,
            use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
        self.netD = self.netD.to(self.config['device'])
        self.optimG = torch.optim.Adam(self.netG.parameters(),
                                       lr=config['trainer']['lr'],
                                       betas=(self.config['trainer']['beta1'],
                                              self.config['trainer']['beta2']))
        self.optimD = torch.optim.Adam(self.netD.parameters(),
                                       lr=config['trainer']['lr'],
                                       betas=(self.config['trainer']['beta1'],
                                              self.config['trainer']['beta2']))
        # load() must run before DataParallel wrapping so state-dict keys
        # match the unwrapped modules
        self.load()
        if config['distributed']:
            self.netG = DataParallel(self.netG)
            self.netD = DataParallel(self.netD)

        # set summary writer (rank 0 only when distributed)
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        if self.config['global_rank'] == 0 or (not config['distributed']):
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis'))
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen'))

    # get current learning rate
    def get_lr(self):
        return self.optimG.param_groups[0]['lr']

    # learning rate scheduler, step: decay by 10x every `niter` iterations,
    # frozen after `niter_steady`
    def adjust_learning_rate(self):
        decay = 0.1**(
            min(self.iteration, self.config['trainer']['niter_steady']) //
            self.config['trainer']['niter'])
        new_lr = self.config['trainer']['lr'] * decay
        if new_lr != self.get_lr():
            for param_group in self.optimG.param_groups:
                param_group['lr'] = new_lr
            for param_group in self.optimD.param_groups:
                param_group['lr'] = new_lr

    # add summary: accumulate values and flush the 100-iteration average
    def add_summary(self, writer, name, val):
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        if writer is not None and self.iteration % 100 == 0:
            writer.add_scalar(name, self.summary[name] / 100, self.iteration)
            self.summary[name] = 0

    # add image: log input / output / GT grids every 100 iterations
    def add_images(self, writer, input_image, output_image, gt_image):
        if writer is not None and self.iteration % 100 == 0:
            b, t, c, h, w = input_image.size()
            input_image = input_image.view(b * t, c, h, w)
            output_image = output_image.view(b * t, c, h, w)
            gt_image = gt_image.view(b * t, c, h, w)
            # images are in [-1, 1]; rescale to [0, 1] for display
            writer.add_image("input/input_image",
                             make_grid((input_image + 1) / 2, t),
                             self.iteration)
            writer.add_image("output/output_image",
                             make_grid((output_image + 1) / 2, t),
                             self.iteration)
            writer.add_image("output/gt_image",
                             make_grid((gt_image + 1) / 2, t), self.iteration)

    # load netG and netD from the latest checkpoint under save_dir
    def load(self):
        model_path = self.config['save_dir']
        if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
            latest_epoch = open(os.path.join(model_path, 'latest.ckpt'),
                                'r').read().splitlines()[-1]
        else:
            # fall back to the lexicographically latest *.pth file
            ckpts = [
                os.path.basename(i).split('.pth')[0]
                for i in glob.glob(os.path.join(model_path, '*.pth'))
            ]
            ckpts.sort()
            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
        if latest_epoch is not None:
            gen_path = os.path.join(
                model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
            dis_path = os.path.join(
                model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
            opt_path = os.path.join(
                model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
            if self.config['global_rank'] == 0:
                print('Loading model from {}...'.format(gen_path))
            data = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(data['netG'])
            data = torch.load(dis_path, map_location=self.config['device'])
            self.netD.load_state_dict(data['netD'])
            data = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data['optimG'])
            self.optimD.load_state_dict(data['optimD'])
            self.epoch = data['epoch']
            self.iteration = data['iteration']
        else:
            if self.config['global_rank'] == 0:
                print(
                    'Warnning: There is no trained model found. An initialized model will be used.'
                )

    # save parameters every eval_epoch (rank 0 only); unwraps DataParallel
    # so checkpoints stay loadable without the wrapper
    def save(self, it):
        if self.config['global_rank'] == 0:
            gen_path = os.path.join(self.config['save_dir'],
                                    'gen_{}.pth'.format(str(it).zfill(5)))
            dis_path = os.path.join(self.config['save_dir'],
                                    'dis_{}.pth'.format(str(it).zfill(5)))
            opt_path = os.path.join(self.config['save_dir'],
                                    'opt_{}.pth'.format(str(it).zfill(5)))
            print('\nsaving model to {} ...'.format(gen_path))
            if isinstance(self.netG, torch.nn.DataParallel) or isinstance(
                    self.netG, DDP):
                netG = self.netG.module
                netD = self.netD.module
            else:
                netG = self.netG
                netD = self.netD
            torch.save({'netG': netG.state_dict()}, gen_path)
            torch.save({'netD': netD.state_dict()}, dis_path)
            torch.save(
                {
                    'epoch': self.epoch,
                    'iteration': self.iteration,
                    'optimG': self.optimG.state_dict(),
                    'optimD': self.optimD.state_dict()
                }, opt_path)
            # record the latest checkpoint id for load()
            os.system('echo {} > {}'.format(
                str(it).zfill(5),
                os.path.join(self.config['save_dir'], 'latest.ckpt')))

    # train entry: loop over epochs until the iteration budget is spent
    def train(self):
        pbar = range(int(self.train_args['iterations']))
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar,
                        initial=self.iteration,
                        dynamic_ncols=True,
                        smoothing=0.01)
        while True:
            self.epoch += 1
            # if self.config['distributed']:
            #     self.train_sampler.set_epoch(self.epoch)
            self._train_epoch(pbar)
            if self.iteration > self.train_args['iterations']:
                break
        print('\nEnd training....')

    # process input and calculate loss every training epoch
    def _train_epoch(self, pbar):
        device = self.config['device']
        for frames, masks in self.train_loader:
            self.adjust_learning_rate()
            self.iteration += 1
            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()
            # zero-out the masked regions before feeding the generator
            masked_frame = (frames * (1 - masks).float())
            pred_img = self.netG(masked_frame, masks)
            # flatten the temporal dimension into the batch for the losses/D
            frames = frames.view(b * t, c, h, w)
            masks = masks.view(b * t, 1, h, w)
            # composite: known pixels from the input, predicted in the holes
            comp_img = frames * (1. - masks) + masks * pred_img
            gen_loss = 0
            dis_loss = 0

            # discriminator adversarial loss
            real_vid_feat = self.netD(frames)
            fake_vid_feat = self.netD(comp_img.detach())
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
            dis_loss += (dis_real_loss + dis_fake_loss) / 2
            self.add_summary(self.dis_writer, 'loss/dis_vid_fake',
                             dis_fake_loss.item())
            self.add_summary(self.dis_writer, 'loss/dis_vid_real',
                             dis_real_loss.item())
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

            # generator adversarial loss
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
            gan_loss = gan_loss * self.config['losses']['adversarial_weight']
            gen_loss += gan_loss
            self.add_summary(self.gen_writer, 'loss/gan_loss',
                             gan_loss.item())

            # generator l1 loss (normalized by hole / valid area)
            hole_loss = self.l1_loss(pred_img * masks, frames * masks)
            hole_loss = hole_loss / torch.mean(
                masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss
            self.add_summary(self.gen_writer, 'loss/hole_loss',
                             hole_loss.item())
            valid_loss = self.l1_loss(pred_img * (1 - masks),
                                      frames * (1 - masks))
            valid_loss = valid_loss / torch.mean(
                1 - masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss
            self.add_summary(self.gen_writer, 'loss/valid_loss',
                             valid_loss.item())

            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            self.add_images(self.gen_writer,
                            masked_frame.cpu().detach(),
                            comp_img.cpu().detach(),
                            frames.cpu().detach())

            # console logs
            if self.config['global_rank'] == 0:
                pbar.update(1)
                pbar.set_description((
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():+.3f}; "
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}"
                ))

            # saving models
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            if self.iteration > self.train_args['iterations']:
                break
def main():
    """Evaluate a list of triplet-network checkpoints on the test dataloaders.

    Despite the argparse description, this routine performs inference only:
    it builds the model from the config file, then for each checkpoint in
    ``inference_list`` runs the test dataloaders through the frozen model and
    prints per-triplet-type zero-loss ratios.
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.freeze()
    viewer = Visualizer(cfg.OUTPUT_DIR)
    # Model
    model = build_model(cfg)
    model = DataParallel(model).cuda()
    if cfg.MODEL.WEIGHT != "":
        model.module.backbone.load_state_dict(torch.load(cfg.MODEL.WEIGHT))
    # freeze backbone
    for key, val in model.module.backbone.named_parameters():
        val.requires_grad = False
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # optimizer / scheduler (no backward pass happens below; kept as in the
    # original training scaffold this script was derived from)
    optimizer = getattr(torch.optim, cfg.SOLVER.OPTIM)(model.parameters(),
                                                       lr=cfg.SOLVER.BASE_LR,
                                                       weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    lr_sche = torch.optim.lr_scheduler.MultiStepLR(optimizer, cfg.SOLVER.STEPS,
                                                   gamma=cfg.SOLVER.GAMMA)
    # dataset
    datasets = make_dataset(cfg, is_train=False)
    dataloaders = make_dataloaders(cfg, datasets, False)
    iter_epoch = (cfg.SOLVER.MAX_ITER) // len(dataloaders[0]) + 1
    if not os.path.exists(cfg.OUTPUT_DIR):
        os.mkdir(cfg.OUTPUT_DIR)
    ite = 0
    batch_it = [i * cfg.SOLVER.IMS_PER_BATCH for i in range(1, 4)]
    # start time
    start = time.time()
    inference_list = ['resnet18_14.pth', 'resnet18_13.pth', 'resnet18_12.pth',
                      'resnet18_11.pth', 'resnet18_10.pth']
    for inference_weight in inference_list:
        # NOTE(review): `resume_dir` is not defined in this function — it must
        # come from module scope; confirm it is set before main() is called.
        model.load_state_dict(torch.load(os.path.join(resume_dir, inference_weight)))
        model.eval()
        total_count = 0
        one_count = 0
        two_count = 0
        three_count = 0
        one_number = 0
        two_number = 0
        three_number = 0
        for dataloader in dataloaders:
            for imgs, labels, types in tqdm.tqdm(dataloader, desc="dataloader:"):
                types = np.asarray(types)
                lr_sche.step()
                data_time.update(time.time() - start)
                # Each sample yields three images (a triplet); stack them into
                # one batch so a single forward pass embeds all of them.
                inputs = torch.cat([imgs[0].cuda(), imgs[1].cuda(), imgs[2].cuda()],
                                   dim=0)
                with torch.no_grad():
                    features = model(inputs)
                acc, batch_loss = loss_opts.batch_triple_loss_acc(
                    features, labels, types, size_average=True)
                print(batch_loss)
                # BUGFIX: removed a stray `xxx` token that sat between the
                # print above and this accumulation, making the module fail
                # to even parse/run.
                total_count += batch_loss.shape[0] - acc
                ONE_CLASS = (batch_loss[np.nonzero(types == 'ONE_CLASS_TRIPLET')[0]])
                TWO_CLASS = (batch_loss[np.nonzero(types == 'TWO_CLASS_TRIPLET')[0]])
                THREE_CLASS = (batch_loss[np.nonzero(types == 'THREE_CLASS_TRIPLET')[0]])
                # shape[0] - nonzero-count = number of zero-loss entries for
                # each triplet category.
                one_count += ONE_CLASS.shape[0] - torch.nonzero(ONE_CLASS).shape[0]
                two_count += TWO_CLASS.shape[0] - torch.nonzero(TWO_CLASS).shape[0]
                three_count += THREE_CLASS.shape[0] - torch.nonzero(THREE_CLASS).shape[0]
                one_number += ONE_CLASS.shape[0]
                two_number += TWO_CLASS.shape[0]
                three_number += THREE_CLASS.shape[0]
                # viewer.line("train/loss",loss.item()*100,ite)
        print(inference_weight,
              total_count / (one_number + two_number + three_number),
              one_count / one_number,
              two_count / two_number,
              three_count / three_number)
class CalculateMetricProcess:
    """Patch-based inference over the 'metric' split, feeding calc_metrics."""

    def __init__(self):
        # Build the network; wrap in DataParallel when a GPU is requested.
        self.net = ET_Net()
        if ARGS['gpu']:
            self.net = DataParallel(module=self.net.cuda())
        self.net.load_state_dict(torch.load(ARGS['weight']))
        self.metric_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                          part='metric')

    def predict(self):
        """Predict every metric-split image patch-wise and report metrics."""
        start = time.time()
        self.net.eval()
        loader = DataLoader(self.metric_dataset,
                            batch_size=1)  # only support batch size = 1
        os.makedirs(ARGS['prediction_save_folder'], exist_ok=True)
        truths, predictions = [], []
        for items in loader:
            images = items['image'].float()
            labels = items['label']
            mask = items['mask']
            print('image shape:', images.size())
            # Split the full image into overlapping patches for inference.
            patches, big_h, big_w = get_test_patches(images, ARGS['crop_size'],
                                                     ARGS['stride_size'])
            patch_loader = DataLoader(patches,
                                      batch_size=ARGS['batch_size'],
                                      shuffle=False,
                                      drop_last=False)
            outputs = []
            print('Number of batches for testing:', len(patch_loader))
            for batch in patch_loader:
                if ARGS['gpu']:
                    batch = batch.cuda()
                with torch.no_grad():
                    _edge_out, seg_out = self.net(batch)
                outputs.append(seg_out.cpu())
            merged = torch.cat(outputs, dim=0)
            # merge the patch predictions back into one full-size map
            merged = recompone_overlap(merged, ARGS['crop_size'],
                                       ARGS['stride_size'], big_h, big_w)
            # channel 1 = foreground probability; crop back to original size
            merged = merged[:, 1, :images.size(2), :images.size(3)]
            # Only pixels inside the mask count toward the metrics.
            predictions.append(merged[mask == 1].reshape(-1))
            truths.append(labels[mask == 1].reshape(-1))
        calc_metrics(torch.cat(predictions).numpy(), torch.cat(truths).numpy())
        finish = time.time()
        print('Calculating metric time consumed: {:.2f}s'.format(finish - start))
def main(args):
    """Train the EDVR video super-resolution model.

    Builds the dataset and dataloader, constructs EDVR wrapped in
    DataParallel, optionally resumes from a checkpoint (file or directory),
    then runs the epoch loop and saves a checkpoint
    ({'state_dict', 'epoch', 'lr'}) after every epoch.
    """
    print("===> Loading datasets")
    data_set = DatasetLoader(args.data_lr,
                             args.data_hr,
                             size_w=args.size_w,
                             size_h=args.size_h,
                             scale=args.scale,
                             n_frames=args.n_frames,
                             interval_list=args.interval_list,
                             border_mode=args.border_mode,
                             random_reverse=args.random_reverse)
    train_loader = DataLoader(data_set,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=False,
                              drop_last=True)
    #### random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
    #cudnn.deterministic = True
    print("===> Building model")
    #### create model
    model = EDVR_arch.EDVR(nf=args.nf,
                           nframes=args.n_frames,
                           groups=args.groups,
                           front_RBs=args.front_RBs,
                           back_RBs=args.back_RBs,
                           center=args.center,
                           predeblur=args.predeblur,
                           HR_in=args.HR_in,
                           w_TSA=args.w_TSA)
    criterion = CharbonnierLoss()
    print("===> Setting GPU")
    # args.gpus == 0 means "use every visible GPU".
    gups = args.gpus if args.gpus != 0 else torch.cuda.device_count()
    device_ids = list(range(gups))
    model = DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    criterion = criterion.cuda()
    # print(model)
    start_epoch = args.start_epoch
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isdir(args.resume):
            # If a directory was given, resume from the last .pth inside it
            # (lexicographic order).
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch'] + 1
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                # Checkpoints store the bare module's weights (saved via
                # model.module.state_dict()); prepend 'module.' so the keys
                # match this DataParallel-wrapped model.
                namekey = 'module.' + k
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)
            # If the checkpoint carries an lr, prefer it over the CLI value.
            args.lr = checkpoint.get('lr', args.lr)
            # An explicit --start_epoch overrides the epoch from the checkpoint.
            start_epoch = args.start_epoch if args.start_epoch != 0 else start_epoch
    # A positive use_current_lr overrides any previously chosen lr.
    args.lr = args.use_current_lr if args.use_current_lr > 0 else args.lr
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 betas=(args.beta1, args.beta2),
                                 eps=1e-8)
    #### training
    print("===> Training")
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(optimizer, epoch)
        if args.use_tqdm == 1:
            losses, psnrs = one_epoch_train_tqdm(
                model, optimizer, criterion, len(data_set), train_loader,
                epoch, args.epochs, args.batch_size,
                optimizer.param_groups[0]["lr"])
        else:
            losses, psnrs = one_epoch_train_logger(
                model, optimizer, criterion, len(data_set), train_loader,
                epoch, args.epochs, args.batch_size,
                optimizer.param_groups[0]["lr"])
        # save model
        # if epoch %9 != 0:
        #     continue
        model_out_path = os.path.join(
            args.checkpoint, "model_epoch_%04d_edvr_loss_%.3f_psnr_%.3f.pth" %
            (epoch, losses.avg, psnrs.avg))
        if not os.path.exists(args.checkpoint):
            os.makedirs(args.checkpoint)
        torch.save(
            {
                'state_dict': model.module.state_dict(),
                "epoch": epoch,
                'lr': optimizer.param_groups[0]["lr"]
            }, model_out_path)
def create_and_test_triplet_network(batch_triplet_indices_loader,
                                    experiment_name,
                                    path_to_emb_net,
                                    unseen_triplets,
                                    dataset_name,
                                    model_name,
                                    logger,
                                    test_n,
                                    n,
                                    dim,
                                    layers,
                                    learning_rate=5e-2,
                                    epochs=20,
                                    hl_size=100):
    """
    Description: Constructs the OENN network, defines an optimizer and trains the
    network on the data w.r.t triplet loss.
    :param model_name: network architecture name passed to define_model
    :param dataset_name: dataset identifier, used only for logging
    :param test_n: # points in the test set
    :param path_to_emb_net: path to the checkpoint of the pre-trained (train) embedding net
    :param batch_triplet_indices_loader: Data loader object. Gives triplet indices in batches.
    :param n: # points
    :param dim: # features/ dimensions
    :param layers: # layers
    :param learning_rate: learning rate of optimizer.
    :param epochs: # epochs
    :param hl_size: # width of the hidden layer
    :param unseen_triplets: triplets held out from training, used for final evaluation
    :param logger: # for logging
    :return: unseen triplet error of the learned test embedding
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    digits = int(math.ceil(math.log2(n)))

    # Define train model — frozen; it provides fixed embeddings for the
    # positive/negative items while only the test net is trained.
    emb_net_train = define_model(model_name=model_name, digits=digits,
                                 hl_size=hl_size, dim=dim, layers=layers)
    emb_net_train = emb_net_train.to(device)
    for param in emb_net_train.parameters():
        param.requires_grad = False
    if torch.cuda.device_count() > 1:
        emb_net_train = DataParallel(emb_net_train)
        print('multi-gpu')

    checkpoint = torch.load(path_to_emb_net)['model_state_dict']
    key_word = list(checkpoint.keys())[0].split('.')[0]
    if key_word == 'module':
        # Checkpoint was saved from a DataParallel model: strip 'module.'.
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        emb_net_train.load_state_dict(new_state_dict)
    else:
        emb_net_train.load_state_dict(checkpoint)
    emb_net_train.eval()

    # Define test model — the network actually optimized below.
    emb_net_test = define_model(model_name=model_name, digits=digits,
                                hl_size=hl_size, dim=dim, layers=layers)
    emb_net_test = emb_net_test.to(device)
    if torch.cuda.device_count() > 1:
        emb_net_test = DataParallel(emb_net_test)
        print('multi-gpu')

    # Optimizer
    optimizer = torch.optim.Adam(emb_net_test.parameters(), lr=learning_rate)
    criterion = nn.TripletMarginLoss(margin=1, p=2)
    criterion = criterion.to(device)

    logger.info('#### Dataset Selection #### \n')
    # BUGFIX: was logger.info('dataset:', dataset_name) — the extra positional
    # argument is treated as a %-format arg with no placeholder, which makes
    # the logging machinery raise a formatting error. Use a real placeholder.
    logger.info('dataset: %s', dataset_name)
    logger.info('#### Network and learning parameters #### \n')
    logger.info('------------------------------------------ \n')
    logger.info('Model Name: ' + model_name + '\n')
    logger.info('Number of hidden layers: ' + str(layers) + '\n')
    logger.info('Hidden layer width: ' + str(hl_size) + '\n')
    logger.info('Embedding dimension: ' + str(dim) + '\n')
    logger.info('Learning rate: ' + str(learning_rate) + '\n')
    logger.info('Number of epochs: ' + str(epochs) + '\n')

    logger.info(' #### Training begins #### \n')
    logger.info('---------------------------\n')
    # (Removed a duplicate recomputation of `digits` and a dead pre-loop
    # `trip` assignment built from bin_array[unseen_triplets]; `trip` was
    # always overwritten inside the batch loop before use.)
    bin_array = data_utils.get_binary_array(n, digits)

    # Training begins
    train_time = 0
    for ep in range(epochs):
        # Epoch is one pass over the dataset
        epoch_loss = 0
        for batch_ind, trips in enumerate(batch_triplet_indices_loader):
            sys.stdout.flush()
            trip = trips.squeeze().to(device).float()

            # Training time
            begin_train_time = time.time()
            # Forward pass: anchor through the trainable test net, the
            # positive/negative through the frozen train net.
            embedded_a = emb_net_test(trip[:, :digits])
            embedded_p = emb_net_train(trip[:, digits:2 * digits])
            embedded_n = emb_net_train(trip[:, 2 * digits:])
            # Compute loss
            loss = criterion(embedded_a, embedded_p, embedded_n).to(device)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # End of training
            end_train_time = time.time()

            if batch_ind % 50 == 0:
                logger.info('Epoch: ' + str(ep) + ' Mini batch: ' +
                            str(batch_ind) + '/' +
                            str(len(batch_triplet_indices_loader)) +
                            ' Loss: ' + str(loss.item()))
            sys.stdout.flush()  # Prints faster to the out file
            epoch_loss += loss.item()
            train_time = train_time + end_train_time - begin_train_time

        # Log
        logger.info('Epoch ' + str(ep) + ' - Average Epoch Loss: ' +
                    str(epoch_loss / len(batch_triplet_indices_loader)) +
                    ' Training time ' + str(train_time))
        sys.stdout.flush()  # Prints faster to the out file

    # Saving the results
    logger.info('Saving the models and the results')
    sys.stdout.flush()  # Prints faster to the out file
    os.makedirs('test_checkpoints', mode=0o777, exist_ok=True)
    model_path = 'test_checkpoints/' + \
                 experiment_name + \
                 '.pt'
    torch.save(
        {
            'epochs': ep,
            'model_state_dict': emb_net_test.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # NOTE(review): key is spelled 'loss:' (trailing colon) — kept
            # as-is for compatibility with existing checkpoint readers;
            # confirm before renaming.
            'loss:': epoch_loss,
        }, model_path)

    # Compute the embedding of the data points.
    # NOTE(review): .cuda() below assumes CUDA is available even though the
    # rest of the function uses `device`; confirm this never runs CPU-only.
    bin_array_test = data_utils.get_binary_array(test_n, digits)
    test_embeddings = emb_net_test(
        torch.Tensor(bin_array_test).cuda().float()).cpu().detach().numpy()
    train_embeddings = emb_net_train(
        torch.Tensor(bin_array).cuda().float()).cpu().detach().numpy()
    unseen_triplet_error, _ = data_utils.triplet_error_unseen(
        test_embeddings, train_embeddings, unseen_triplets)
    logger.info('Unseen triplet error is ' + str(unseen_triplet_error))
    return unseen_triplet_error
class TrainValProcess():
    """Training / validation driver for ET_Net segmentation.

    Loads or initialises the network, then alternates train/validate epochs
    and periodically saves checkpoints (see train_val()).
    """

    def __init__(self):
        self.net = ET_Net()
        if (ARGS['weight']):
            # Resume from a full checkpoint when one is provided...
            self.net.load_state_dict(torch.load(ARGS['weight']))
        else:
            # ...otherwise initialise only the encoder from pretrained weights.
            self.net.load_encoder_weight()
        if (ARGS['gpu']):
            self.net = DataParallel(module=self.net.cuda())
        self.train_dataset = get_dataset(dataset_name=ARGS['dataset'], part='train')
        self.val_dataset = get_dataset(dataset_name=ARGS['dataset'], part='val')
        self.optimizer = Adam(self.net.parameters(), lr=ARGS['lr'])
        # Use / to get an approximate result, // to get an accurate result
        total_iters = len(
            self.train_dataset) // ARGS['batch_size'] * ARGS['num_epochs']
        # Polynomial learning-rate decay toward 0 over the whole run.
        self.lr_scheduler = LambdaLR(
            self.optimizer,
            lr_lambda=lambda iter: (1 - iter / total_iters)**ARGS['scheduler_power'])
        self.writer = SummaryWriter()

    def train(self, epoch):
        """Run one training epoch; logs per-batch stats and the epoch loss."""
        start = time.time()
        self.net.train()
        # NOTE(review): shuffle=False during training looks unintentional —
        # confirm whether the dataset is pre-shuffled before changing it.
        train_dataloader = DataLoader(self.train_dataset,
                                      batch_size=ARGS['batch_size'],
                                      shuffle=False)
        epoch_loss = 0.
        for batch_index, items in enumerate(train_dataloader):
            images, labels, edges = items['image'], items['label'], items[
                'edge']
            images = images.float()
            labels = labels.long()
            edges = edges.long()
            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()
            self.optimizer.zero_grad()
            outputs_edge, outputs = self.net(images)
            loss_edge = lovasz_softmax(outputs_edge, edges)  # Lovasz-Softmax loss
            loss_seg = lovasz_softmax(outputs, labels)
            # Weighted combination of segmentation and edge losses.
            loss = ARGS['combine_alpha'] * loss_seg + (
                1 - ARGS['combine_alpha']) * loss_edge
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()
            n_iter = (epoch - 1) * len(train_dataloader) + batch_index + 1
            pred = torch.max(outputs, dim=1)[1]
            iou = torch.sum(pred & labels) / (torch.sum(pred | labels) + 1e-6)
            print(
                'Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tL_edge: {:0.4f}\tL_seg: {:0.4f}\tL_all: {:0.4f}\tIoU: {:0.4f}\tLR: {:0.4f}'
                .format(loss_edge.item(),
                        loss_seg.item(),
                        loss.item(),
                        iou.item(),
                        self.optimizer.param_groups[0]['lr'],
                        epoch=epoch,
                        trained_samples=batch_index * ARGS['batch_size'],
                        total_samples=len(train_dataloader.dataset)))
            epoch_loss += loss.item()  # update training loss for each iteration
            # self.writer.add_scalar('Train/loss', loss.item(), n_iter)
        # Histogram every parameter once per epoch (global_step = epoch).
        for name, param in self.net.named_parameters():
            layer, attr = os.path.splitext(name)
            attr = attr[1:]
            self.writer.add_histogram("{}/{}".format(layer, attr), param, epoch)
        epoch_loss /= len(train_dataloader)
        self.writer.add_scalar('Train/loss', epoch_loss, epoch)
        finish = time.time()
        print('epoch {} training time consumed: {:.2f}s'.format(
            epoch, finish - start))

    def validate(self, epoch):
        """Run one validation pass; logs the average validation loss."""
        start = time.time()
        self.net.eval()
        val_batch_size = min(ARGS['batch_size'], len(self.val_dataset))
        val_dataloader = DataLoader(self.val_dataset,
                                    batch_size=val_batch_size)
        epoch_loss = 0.
        for batch_index, items in enumerate(val_dataloader):
            images, labels, edges = items['image'], items['label'], items[
                'edge']
            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()
            print('image shape:', images.size())
            with torch.no_grad():
                outputs_edge, outputs = self.net(images)
            loss_edge = lovasz_softmax(outputs_edge, edges)  # Lovasz-Softmax loss
            loss_seg = lovasz_softmax(outputs, labels)
            loss = ARGS['combine_alpha'] * loss_seg + (
                1 - ARGS['combine_alpha']) * loss_edge
            pred = torch.max(outputs, dim=1)[1]
            iou = torch.sum(pred & labels) / (torch.sum(pred | labels) + 1e-6)
            print(
                'Validating Epoch: {epoch} [{val_samples}/{total_samples}]\tLoss: {:0.4f}\tIoU: {:0.4f}'
                .format(loss.item(),
                        iou.item(),
                        epoch=epoch,
                        val_samples=batch_index * val_batch_size,
                        total_samples=len(val_dataloader.dataset)))
            # BUGFIX: accumulate the Python float, not the tensor, so
            # epoch_loss is a plain number for the scalar log below.
            epoch_loss += loss.item()
        epoch_loss /= len(val_dataloader)
        self.writer.add_scalar('Val/loss', epoch_loss, epoch)
        finish = time.time()
        # BUGFIX: this message previously said "training"; it times validation.
        print('epoch {} validation time consumed: {:.2f}s'.format(
            epoch, finish - start))

    def train_val(self):
        """Alternate training and validation; save every epoch_save epochs."""
        print('Begin training and validating:')
        for epoch in range(ARGS['num_epochs']):
            self.train(epoch)
            self.validate(epoch)
            # BUGFIX: removed a no-op `self.net.state_dict()` call whose
            # return value was discarded.
            print(f'Finish training and validating epoch #{epoch+1}')
            if (epoch + 1) % ARGS['epoch_save'] == 0:
                os.makedirs(ARGS['weight_save_folder'], exist_ok=True)
                torch.save(
                    self.net.state_dict(),
                    os.path.join(ARGS['weight_save_folder'],
                                 f'epoch_{epoch+1}.pth'))
                print(f'Model saved for epoch #{epoch+1}.')
        print('Finish training and validating.')