def create_ema_model(model):
    ema_model = Res_Deeplab(num_classes=num_classes)
    for param in ema_model.parameters():
        param.detach_()
    mp = list(model.parameters())
    mcp = list(ema_model.parameters())
    n = len(mp)
    for i in range(0, n):
        mcp[i].data[:] = mp[i].data[:].clone()
    if len(gpus) > 1:
        if use_sync_batchnorm:
            ema_model = convert_model(ema_model)
            ema_model = DataParallelWithCallback(ema_model, device_ids=gpus)
        else:
            ema_model = torch.nn.DataParallel(ema_model, device_ids=gpus)
    return ema_model
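# NOTE: update_ema_variables is called from the training loops below but its
# definition is not included in this file. A minimal sketch of the mean-teacher
# update it is assumed to perform (exponential moving average of the student
# weights); the ramp-up of alpha_teacher early in training is an assumption,
# not something confirmed by this file:
def update_ema_variables(ema_model, model, alpha_teacher, iteration):
    # Use a smaller decay for the first iterations (assumed; common mean-teacher practice).
    alpha_teacher = min(1 - 1 / (iteration + 1), alpha_teacher)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # theta_ema <- alpha * theta_ema + (1 - alpha) * theta_student
        ema_param.data[:] = alpha_teacher * ema_param.data[:] + (1 - alpha_teacher) * param.data[:]
    return ema_model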
class BaseTrainer:
    def __init__(self, model, loss, resume, config, train_loader, val_loader=None, train_logger=None):
        self.model = model
        self.loss = loss
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_logger = train_logger
        self.logger = logging.getLogger(self.__class__.__name__)
        self.do_validation = self.config['trainer']['val']
        self.start_epoch = 1
        self.improved = False

        # SETTING THE DEVICE
        self.device, availble_gpus = self._get_available_devices(self.config['n_gpu'])
        self.model.loss = loss
        if config["use_synch_bn"]:
            self.model = convert_model(self.model)
            self.model = DataParallelWithCallback(self.model, device_ids=availble_gpus)
        else:
            self.model = torch.nn.DataParallel(self.model, device_ids=availble_gpus)
        self.model.cuda()

        # CONFIGS
        cfg_trainer = self.config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']

        # OPTIMIZER
        if self.config['optimizer']['differential_lr']:
            if isinstance(self.model, torch.nn.DataParallel):
                trainable_params = [
                    {'params': filter(lambda p: p.requires_grad, self.model.module.get_decoder_params())},
                    {'params': filter(lambda p: p.requires_grad, self.model.module.get_backbone_params()),
                     'lr': config['optimizer']['args']['lr'] / 10},
                ]
            else:
                trainable_params = [
                    {'params': filter(lambda p: p.requires_grad, self.model.get_decoder_params())},
                    {'params': filter(lambda p: p.requires_grad, self.model.get_backbone_params()),
                     'lr': config['optimizer']['args']['lr'] / 10},
                ]
        else:
            trainable_params = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = get_instance(torch.optim, 'optimizer', config, trainable_params)
        self.lr_scheduler = getattr(utils.lr_scheduler, config['lr_scheduler']['type'])(
            self.optimizer, self.epochs, len(train_loader))

        # MONITORING
        self.monitor = cfg_trainer.get('monitor', 'off')
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = -math.inf if self.mnt_mode == 'max' else math.inf
            self.early_stoping = cfg_trainer.get('early_stop', math.inf)

        # CHECKPOINTS & TENSORBOARD
        start_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
        self.checkpoint_dir = os.path.join(cfg_trainer['save_dir'], self.config['name'], start_time)
        helpers.dir_exists(self.checkpoint_dir)
        config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
        with open(config_save_path, 'w') as handle:
            json.dump(self.config, handle, indent=4, sort_keys=True)

        writer_dir = os.path.join(cfg_trainer['log_dir'], self.config['name'], start_time)
        self.writer = tensorboard.SummaryWriter(writer_dir)

        if resume:
            self._resume_checkpoint(resume)

    def _get_available_devices(self, n_gpu):
        sys_gpu = torch.cuda.device_count()
        if sys_gpu == 0:
            self.logger.warning('No GPUs detected, using the CPU')
            n_gpu = 0
        elif n_gpu > sys_gpu:
            self.logger.warning(f'Nbr of GPU requested is {n_gpu} but only {sys_gpu} are available')
            n_gpu = sys_gpu
        device = torch.device('cuda' if n_gpu > 0 else 'cpu')
        self.logger.info(f'Detected GPUs: {sys_gpu} Requested: {n_gpu}')
        available_gpus = list(range(n_gpu))
        return device, available_gpus

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            # RUN TRAIN (AND VAL)
            results = self._train_epoch(epoch)
            self.lr_scheduler.step()
            if self.do_validation and epoch % self.config['trainer']['val_per_epochs'] == 0:
                results = self._valid_epoch(epoch)

                # LOGGING INFO
                self.logger.info(f'\n ## Info for epoch {epoch} ## ')
                for k, v in results.items():
                    self.logger.info(f' {str(k):15s}: {v}')

            if self.train_logger is not None:
                log = {'epoch': epoch, **results}
                self.train_logger.add_entry(log)

            # CHECKING IF THIS IS THE BEST MODEL (ONLY FOR VAL)
            if self.mnt_mode != 'off' and epoch % self.config['trainer']['val_per_epochs'] == 0:
                try:
                    if self.mnt_mode == 'min':
                        self.improved = (log[self.mnt_metric] < self.mnt_best)
                    else:
                        self.improved = (log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        f'The metric being tracked ({self.mnt_metric}) has not been calculated. Training stops.')
                    break

                if self.improved:
                    self.mnt_best = log[self.mnt_metric]
                    self.not_improved_count = 0
                else:
                    self.not_improved_count += 1

                if self.not_improved_count > self.early_stoping:
                    self.logger.info(f'\nPerformance didn\'t improve for {self.early_stoping} epochs')
                    self.logger.warning('Training Stopped')
                    break

            # SAVE CHECKPOINT
            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=self.improved)

    def _save_checkpoint(self, epoch, save_best=False):
        state = {
            'arch': type(self.model).__name__,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = os.path.join(self.checkpoint_dir, f'checkpoint-epoch{epoch}.pth')
        self.logger.info(f'\nSaving a checkpoint: {filename} ...')
        torch.save(state, filename)

        if save_best:
            filename = os.path.join(self.checkpoint_dir, 'best_model.pth')
            torch.save(state, filename)
            self.logger.info("Saving current best: best_model.pth")

    def _resume_checkpoint(self, resume_path):
        self.logger.info(f'Loading checkpoint : {resume_path}')
        checkpoint = torch.load(resume_path)

        # Load last run info, the model params, the optimizer and the loggers
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        self.not_improved_count = 0

        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning('Warning! Current model is not the same as the one in the checkpoint')
        self.model.load_state_dict(checkpoint['state_dict'], strict=False)

        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning('Warning! Current optimizer is not the same as the one in the checkpoint')
        self.optimizer.load_state_dict(checkpoint['optimizer'])

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _valid_epoch(self, epoch):
        raise NotImplementedError

    def _eval_metrics(self, output, target):
        raise NotImplementedError
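# NOTE: get_instance, used above to build the optimizer from the config, is not
# defined in this file. A minimal sketch of the config-driven factory it is
# assumed to be (pytorch-template style); the exact signature in this codebase
# may differ:
def get_instance(module, name, config, *args):
    # e.g. get_instance(torch.optim, 'optimizer', config, params) builds
    # torch.optim.SGD(params, **config['optimizer']['args']) when
    # config['optimizer'] == {'type': 'SGD', 'args': {...}}.
    return getattr(module, config[name]['type'])(*args, **config[name]['args'])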
def main():
    torch.cuda.empty_cache()
    print(config)
    best_mIoU = 0

    if consistency_loss == 'CE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(
                CrossEntropyLoss2dPixelWiseWeighted(ignore_index=ignore_label), device_ids=gpus).cuda()
        else:
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(ignore_index=ignore_label).cuda()
    elif consistency_loss == 'MSE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(MSELoss2d(), device_ids=gpus).cuda()
        else:
            unlabeled_loss = MSELoss2d().cuda()

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    if restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)

    # Copy loaded parameters to model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Initiate ema-model
    if train_unlabeled:
        ema_model = create_ema_model(model)
        ema_model.train()
        ema_model = ema_model.cuda()
    else:
        ema_model = None

    if len(gpus) > 1:
        if use_sync_batchnorm:
            model = convert_model(model)
            model = DataParallelWithCallback(model, device_ids=gpus)
        else:
            model = torch.nn.DataParallel(model, device_ids=gpus)
    model.train()
    model.cuda()

    cudnn.benchmark = True

    if dataset == 'pascal_voc':
        data_loader = get_loader(dataset)
        data_path = get_data_path(dataset)
        train_dataset = data_loader(data_path, crop_size=input_size, scale=random_scale, mirror=random_flip)
    elif dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        if random_crop:
            data_aug = Compose([RandomCrop_city(input_size)])
        else:
            data_aug = None
        train_dataset = data_loader(data_path, is_transform=True, augmentations=data_aug, img_size=input_size)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)

    if split_id is not None:
        train_ids = pickle.load(open(split_id, 'rb'))
        print('loading train ids from {}'.format(split_id))
    else:
        np.random.seed(random_seed)
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                  num_workers=num_workers, pin_memory=True)
    trainloader_iter = iter(trainloader)

    if train_unlabeled:
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset, batch_size=batch_size,
                                             sampler=train_remain_sampler, num_workers=1, pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

    # Optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(config['training']['learning_rate'])

    if optimizer_type == 'SGD':
        if len(gpus) > 1:
            optimizer = optim.SGD(model.module.optim_parameters(learning_rate_object),
                                  lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
        else:
            optimizer = optim.SGD(model.optim_parameters(learning_rate_object),
                                  lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True)

    start_iteration = 0
    if args.resume:
        start_iteration, model, optimizer, ema_model = _resume_checkpoint(args.resume, model, optimizer, ema_model)

    accumulated_loss_l = []
    if train_unlabeled:
        accumulated_loss_u = []

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    with open(checkpoint_dir + '/config.json', 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    pickle.dump(train_ids, open(os.path.join(checkpoint_dir, 'train_split.pkl'), 'wb'))

    epochs_since_start = 0
    for i_iter in range(start_iteration, num_iterations):
        model.train()
        loss_l_value = 0
        if train_unlabeled:
            loss_u_value = 0
        optimizer.zero_grad()

        if lr_schedule:
            adjust_learning_rate(optimizer, i_iter)

        # Training loss for labeled data only
        try:
            batch = next(trainloader_iter)
            if batch[0].shape[0] != batch_size:
                batch = next(trainloader_iter)
        except:
            epochs_since_start = epochs_since_start + 1
            print('Epochs since start: ', epochs_since_start)
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        weak_parameters = {"flip": 0}

        images, labels, _, _, _ = batch
        images = images.cuda()
        labels = labels.cuda()

        images, labels = weakTransform(weak_parameters, data=images, target=labels)

        intermediary_var = model(images)
        pred = interp(intermediary_var)
        L_l = loss_calc(pred, labels)

        if train_unlabeled:
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size:
                    batch_remain = next(trainloader_remain_iter)
            except:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            images_remain, _, _, _, _ = batch_remain
            images_remain = images_remain.cuda()

            inputs_u_w, _ = weakTransform(weak_parameters, data=images_remain)
            logits_u_w = interp(ema_model(inputs_u_w))
            logits_u_w, _ = weakTransform(getWeakInverseTransformParameters(weak_parameters),
                                          data=logits_u_w.detach())

            softmax_u_w = torch.softmax(logits_u_w.detach(), dim=1)
            max_probs, argmax_u_w = torch.max(softmax_u_w, dim=1)

            if mix_mask == "class":
                for image_i in range(batch_size):
                    classes = torch.unique(argmax_u_w[image_i])
                    classes = classes[classes != ignore_label]
                    nclasses = classes.shape[0]
                    classes = (classes[torch.Tensor(
                        np.random.choice(nclasses, int((nclasses - nclasses % 2) / 2),
                                         replace=False)).long()]).cuda()
                    if image_i == 0:
                        MixMask = transformmasks.generate_class_mask(argmax_u_w[image_i], classes).unsqueeze(0).cuda()
                    else:
                        MixMask = torch.cat((MixMask, transformmasks.generate_class_mask(
                            argmax_u_w[image_i], classes).unsqueeze(0).cuda()))

            elif mix_mask == 'cut':
                img_size = inputs_u_w.shape[2:4]
                for image_i in range(batch_size):
                    if image_i == 0:
                        MixMask = torch.from_numpy(
                            transformmasks.generate_cutout_mask(img_size)).unsqueeze(0).cuda().float()
                    else:
                        MixMask = torch.cat((MixMask, torch.from_numpy(
                            transformmasks.generate_cutout_mask(img_size)).unsqueeze(0).cuda().float()))

            elif mix_mask == "cow":
                img_size = inputs_u_w.shape[2:4]
                sigma_min = 8
                sigma_max = 32
                p_min = 0.5
                p_max = 0.5
                for image_i in range(batch_size):
                    sigma = np.exp(np.random.uniform(np.log(sigma_min), np.log(sigma_max)))  # Random sigma
                    p = np.random.uniform(p_min, p_max)  # Random p
                    if image_i == 0:
                        MixMask = torch.from_numpy(transformmasks.generate_cow_mask(
                            img_size, sigma, p, seed=None)).unsqueeze(0).cuda().float()
                    else:
                        MixMask = torch.cat((MixMask, torch.from_numpy(transformmasks.generate_cow_mask(
                            img_size, sigma, p, seed=None)).unsqueeze(0).cuda().float()))

            elif mix_mask == None:
                MixMask = torch.ones((inputs_u_w.shape)).cuda()

            strong_parameters = {"Mix": MixMask}
            if random_flip:
                strong_parameters["flip"] = random.randint(0, 1)
            else:
                strong_parameters["flip"] = 0
            if color_jitter:
                strong_parameters["ColorJitter"] = random.uniform(0, 1)
            else:
                strong_parameters["ColorJitter"] = 0
            if gaussian_blur:
                strong_parameters["GaussianBlur"] = random.uniform(0, 1)
            else:
                strong_parameters["GaussianBlur"] = 0

            inputs_u_s, _ = strongTransform(strong_parameters, data=images_remain)
            logits_u_s = interp(model(inputs_u_s))

            softmax_u_w_mixed, _ = strongTransform(strong_parameters, data=softmax_u_w)
            max_probs, pseudo_label = torch.max(softmax_u_w_mixed, dim=1)

            if pixel_weight == "threshold_uniform":
                unlabeled_weight = torch.sum(max_probs.ge(0.968).long() == 1).item() / np.size(
                    np.array(pseudo_label.cpu()))
                pixelWiseWeight = unlabeled_weight * torch.ones(max_probs.shape).cuda()
            elif pixel_weight == "threshold":
                pixelWiseWeight = max_probs.ge(0.968).long().cuda()
            elif pixel_weight == 'sigmoid':
                max_iter = 10000
                pixelWiseWeight = sigmoid_ramp_up(i_iter, max_iter) * torch.ones(max_probs.shape).cuda()
            elif pixel_weight == False:
                pixelWiseWeight = torch.ones(max_probs.shape).cuda()

            if consistency_loss == 'CE':
                L_u = consistency_weight * unlabeled_loss(logits_u_s, pseudo_label, pixelWiseWeight)
            elif consistency_loss == 'MSE':
                unlabeled_weight = torch.sum(max_probs.ge(0.968).long() == 1).item() / np.size(
                    np.array(pseudo_label.cpu()))
                L_u = consistency_weight * unlabeled_weight * unlabeled_loss(logits_u_s, softmax_u_w_mixed)

            loss = L_l + L_u
        else:
            loss = L_l

        if len(gpus) > 1:
            loss = loss.mean()
            loss_l_value += L_l.mean().item()
            if train_unlabeled:
                loss_u_value += L_u.mean().item()
        else:
            loss_l_value += L_l.item()
            if train_unlabeled:
                loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()

        # update Mean teacher network
        if ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model=ema_model, model=model,
                                             alpha_teacher=alpha_teacher, iteration=i_iter)

        if train_unlabeled:
            print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f}'.format(
                i_iter, num_iterations, loss_l_value, loss_u_value))
        else:
            print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}'.format(i_iter, num_iterations, loss_l_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config, ema_model)

        if use_tensorboard:
            if 'tensorboard_writer' not in locals():
                tensorboard_writer = tensorboard.SummaryWriter(log_dir, flush_secs=30)

            accumulated_loss_l.append(loss_l_value)
            if train_unlabeled:
                accumulated_loss_u.append(loss_u_value)
            if i_iter % log_per_iter == 0 and i_iter != 0:
                tensorboard_writer.add_scalar('Training/Supervised loss', np.mean(accumulated_loss_l), i_iter)
                accumulated_loss_l = []
                if train_unlabeled:
                    tensorboard_writer.add_scalar('Training/Unsupervised loss', np.mean(accumulated_loss_u), i_iter)
                    accumulated_loss_u = []

        if i_iter % val_per_iter == 0 and i_iter != 0:
            model.eval()
            mIoU, eval_loss = evaluate(model, dataset, ignore_label=ignore_label,
                                       input_size=(512, 1024), save_dir=checkpoint_dir)
            model.train()

            if mIoU > best_mIoU and save_best_model:
                best_mIoU = mIoU
                _save_checkpoint(i_iter, model, optimizer, config, ema_model, save_best=True)

            if use_tensorboard:
                tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
                tensorboard_writer.add_scalar('Validation/Loss', eval_loss, i_iter)

        if save_unlabeled_images and train_unlabeled and i_iter % save_checkpoint_every == 0:
            # Saves two mixed images and the corresponding prediction
            save_image(inputs_u_s[0].cpu(), i_iter, 'input1', palette.CityScpates_palette)
            save_image(inputs_u_s[1].cpu(), i_iter, 'input2', palette.CityScpates_palette)
            _, pred_u_s = torch.max(logits_u_s, dim=1)
            save_image(pred_u_s[0].cpu(), i_iter, 'pred1', palette.CityScpates_palette)
            save_image(pred_u_s[1].cpu(), i_iter, 'pred2', palette.CityScpates_palette)

    _save_checkpoint(num_iterations, model, optimizer, config, ema_model)

    model.eval()
    mIoU, val_loss = evaluate(model, dataset, ignore_label=ignore_label,
                              input_size=(512, 1024), save_dir=checkpoint_dir)
    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, ema_model, save_best=True)

    if use_tensorboard:
        tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
        tensorboard_writer.add_scalar('Validation/Loss', val_loss, i_iter)

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
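# NOTE: generate_class_mask comes from transformmasks and is not shown in this
# file. A minimal sketch of the ClassMix-style binary mask it is assumed to
# build (1 where the pixel belongs to one of the randomly selected classes,
# 0 elsewhere):
import torch

def generate_class_mask(pred, classes):
    # pred: (H, W) tensor of class ids; classes: 1-D tensor of the selected class ids.
    mask = torch.zeros_like(pred, dtype=torch.float)
    for c in classes:
        mask = mask + pred.eq(c).float()
    return mask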
def main():
    print(config)
    best_mIoU = 0

    if consistency_loss == 'MSE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(MSELoss2d(), device_ids=gpus).cuda()
        else:
            unlabeled_loss = MSELoss2d().cuda()
    elif consistency_loss == 'CE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(
                CrossEntropyLoss2dPixelWiseWeighted(ignore_index=ignore_label), device_ids=gpus).cuda()
        else:
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(ignore_index=ignore_label).cuda()

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    if restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)

    # Copy loaded parameters to model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # init ema-model
    if train_unlabeled:
        ema_model = create_ema_model(model)
        ema_model.train()
        ema_model = ema_model.cuda()
    else:
        ema_model = None

    if len(gpus) > 1:
        if use_sync_batchnorm:
            model = convert_model(model)
            model = DataParallelWithCallback(model, device_ids=gpus)
        else:
            model = torch.nn.DataParallel(model, device_ids=gpus)
    model.train()
    model.cuda()

    cudnn.benchmark = True

    # UNLABELED (ADAPTATION) DATA
    data_loader = get_loader(config['dataset'])
    data_path = get_data_path(config['dataset'])
    data_aug = Compose([RandomHorizontallyFlip()])
    if dataset == 'cityscapes':
        train_dataset = data_loader(data_path, is_transform=True, augmentations=data_aug,
                                    img_size=input_size, img_mean=IMG_MEAN)
    elif dataset == 'multiview':
        # adaptation data
        data_path = '/tmp/tcn_data/texture_multibot_push_left10050/videos/train_adaptation'
        train_dataset = data_loader(data_path, is_transform=True, view_idx=0, number_views=1,
                                    load_seg_mask=False, augmentations=data_aug,
                                    img_size=input_size, img_mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    if labeled_samples is None:
        trainloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                      num_workers=num_workers, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                             num_workers=num_workers, pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)
    else:
        partial_size = labeled_samples
        print('Training on number of samples:', partial_size)
        np.random.seed(random_seed)
        trainloader_remain = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                             num_workers=4, pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

    # SUPERVISED DATA (new loader for domain transfer)
    data_path = '/tmp/tcn_data/texture_multibot_push_left10050/videos/train_adaptation'
    data_aug = Compose([RandomHorizontallyFlip()])
    if dataset == 'multiview':
        train_dataset = data_loader(data_path, is_transform=True, view_idx=0, number_views=1,
                                    load_seg_mask=True, augmentations=data_aug,
                                    img_size=input_size, img_mean=IMG_MEAN)
    else:
        data_loader = get_loader('gta')
        data_path = get_data_path('gta')
        train_dataset = data_loader(data_path, list_path='./data/gta5_list/train.txt',
                                    augmentations=data_aug, img_size=(1280, 720), mean=IMG_MEAN)

    trainloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers, pin_memory=True)
    trainloader_iter = iter(trainloader)
    print('gta size:', len(trainloader))

    # optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(config['training']['learning_rate'])

    if optimizer_type == 'SGD':
        if len(gpus) > 1:
            optimizer = optim.SGD(model.module.optim_parameters(learning_rate_object),
                                  lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
        else:
            optimizer = optim.SGD(model.optim_parameters(learning_rate_object),
                                  lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    elif optimizer_type == 'Adam':
        # Adam takes no momentum argument
        if len(gpus) > 1:
            optimizer = optim.Adam(model.module.optim_parameters(learning_rate_object),
                                   lr=learning_rate, weight_decay=weight_decay)
        else:
            optimizer = optim.Adam(model.optim_parameters(learning_rate_object),
                                   lr=learning_rate, weight_decay=weight_decay)
    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True)

    start_iteration = 0
    if args.resume:
        start_iteration, model, optimizer, ema_model = _resume_checkpoint(args.resume, model, optimizer, ema_model)

    accumulated_loss_l = []
    accumulated_loss_u = []

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    with open(checkpoint_dir + '/config.json', 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=True)

    epochs_since_start = 0
    for i_iter in range(start_iteration, num_iterations):
        model.train()
        loss_u_value = 0
        loss_l_value = 0
        optimizer.zero_grad()

        if lr_schedule:
            adjust_learning_rate(optimizer, i_iter)

        # training loss for labeled data only
        try:
            batch = next(trainloader_iter)
            if batch[0].shape[0] != batch_size:
                batch = next(trainloader_iter)
        except:
            epochs_since_start = epochs_since_start + 1
            print('Epochs since start: ', epochs_since_start)
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        weak_parameters = {"flip": 0}

        images, labels, _, _ = batch
        images = images.cuda()
        labels = labels.cuda().long()

        pred = interp(model(images))
        L_l = loss_calc(pred, labels)  # Cross entropy loss for labeled data

        if train_unlabeled:
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size:
                    batch_remain = next(trainloader_remain_iter)
            except:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            images_remain, *_ = batch_remain
            images_remain = images_remain.cuda()

            inputs_u_w, _ = weakTransform(weak_parameters, data=images_remain)
            logits_u_w = interp(ema_model(inputs_u_w))
            logits_u_w, _ = weakTransform(getWeakInverseTransformParameters(weak_parameters),
                                          data=logits_u_w.detach())

            pseudo_label = torch.softmax(logits_u_w.detach(), dim=1)
            max_probs, targets_u_w = torch.max(pseudo_label, dim=1)

            if mix_mask == "class":
                # assumes batch_size == 2: one mask per image in the pair
                for image_i in range(batch_size):
                    classes = torch.unique(labels[image_i])
                    nclasses = classes.shape[0]
                    classes = (classes[torch.Tensor(
                        np.random.choice(nclasses, int((nclasses + nclasses % 2) / 2),
                                         replace=False)).long()]).cuda()
                    if image_i == 0:
                        MixMask0 = transformmasks.generate_class_mask(labels[image_i], classes).unsqueeze(0).cuda()
                    else:
                        MixMask1 = transformmasks.generate_class_mask(labels[image_i], classes).unsqueeze(0).cuda()

            elif mix_mask == None:
                MixMask = torch.ones((inputs_u_w.shape))

            strong_parameters = {"Mix": MixMask0}
            if random_flip:
                strong_parameters["flip"] = random.randint(0, 1)
            else:
                strong_parameters["flip"] = 0
            if color_jitter:
                strong_parameters["ColorJitter"] = random.uniform(0, 1)
            else:
                strong_parameters["ColorJitter"] = 0
            if gaussian_blur:
                strong_parameters["GaussianBlur"] = random.uniform(0, 1)
            else:
                strong_parameters["GaussianBlur"] = 0

            inputs_u_s0, _ = strongTransform(strong_parameters,
                                             data=torch.cat((images[0].unsqueeze(0), images_remain[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            inputs_u_s1, _ = strongTransform(strong_parameters,
                                             data=torch.cat((images[1].unsqueeze(0), images_remain[1].unsqueeze(0))))
            inputs_u_s = torch.cat((inputs_u_s0, inputs_u_s1))
            logits_u_s = interp(model(inputs_u_s))

            strong_parameters["Mix"] = MixMask0
            _, targets_u0 = strongTransform(strong_parameters,
                                            target=torch.cat((labels[0].unsqueeze(0), targets_u_w[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            _, targets_u1 = strongTransform(strong_parameters,
                                            target=torch.cat((labels[1].unsqueeze(0), targets_u_w[1].unsqueeze(0))))
            targets_u = torch.cat((targets_u0, targets_u1)).long()

            if pixel_weight == "threshold_uniform":
                unlabeled_weight = torch.sum(max_probs.ge(0.968).long() == 1).item() / np.size(
                    np.array(targets_u.cpu()))
                pixelWiseWeight = unlabeled_weight * torch.ones(max_probs.shape).cuda()
            elif pixel_weight == "threshold":
                pixelWiseWeight = max_probs.ge(0.968).float().cuda()
            elif pixel_weight == False:
                pixelWiseWeight = torch.ones(max_probs.shape).cuda()

            onesWeights = torch.ones((pixelWiseWeight.shape)).cuda()
            strong_parameters["Mix"] = MixMask0
            _, pixelWiseWeight0 = strongTransform(strong_parameters,
                                                  target=torch.cat((onesWeights[0].unsqueeze(0),
                                                                    pixelWiseWeight[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            _, pixelWiseWeight1 = strongTransform(strong_parameters,
                                                  target=torch.cat((onesWeights[1].unsqueeze(0),
                                                                    pixelWiseWeight[1].unsqueeze(0))))
            pixelWiseWeight = torch.cat((pixelWiseWeight0, pixelWiseWeight1)).cuda()

            if consistency_loss == 'MSE':
                unlabeled_weight = torch.sum(max_probs.ge(0.968).long() == 1).item() / np.size(
                    np.array(targets_u.cpu()))
                L_u = consistency_weight * unlabeled_weight * unlabeled_loss(logits_u_s, pseudo_label)
            elif consistency_loss == 'CE':
                L_u = consistency_weight * unlabeled_loss(logits_u_s, targets_u, pixelWiseWeight)

            loss = L_l + L_u
        else:
            loss = L_l

        if len(gpus) > 1:
            loss = loss.mean()
            loss_l_value += L_l.mean().item()
            if train_unlabeled:
                loss_u_value += L_u.mean().item()
        else:
            loss_l_value += L_l.item()
            if train_unlabeled:
                loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()

        # update Mean teacher network
        if ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model=ema_model, model=model,
                                             alpha_teacher=alpha_teacher, iteration=i_iter)

        print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f}'.format(
            i_iter, num_iterations, loss_l_value, loss_u_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            if epochs_since_start * len(trainloader) < save_checkpoint_every:
                _save_checkpoint(i_iter, model, optimizer, config, ema_model, overwrite=False)
            else:
                _save_checkpoint(i_iter, model, optimizer, config, ema_model)

        if config['utils']['tensorboard']:
            if 'tensorboard_writer' not in locals():
                tensorboard_writer = tensorboard.SummaryWriter(log_dir, flush_secs=30)

            accumulated_loss_l.append(loss_l_value)
            if train_unlabeled:
                accumulated_loss_u.append(loss_u_value)
            if i_iter % log_per_iter == 0 and i_iter != 0:
                tensorboard_writer.add_scalar('Training/Supervised loss', np.mean(accumulated_loss_l), i_iter)
                accumulated_loss_l = []
                if train_unlabeled:
                    tensorboard_writer.add_scalar('Training/Unsupervised loss', np.mean(accumulated_loss_u), i_iter)
                    accumulated_loss_u = []

        if i_iter % val_per_iter == 0 and i_iter != 0:
            model.eval()
            if dataset == 'cityscapes':
                mIoU, eval_loss = evaluate(model, dataset, ignore_label=250,
                                           input_size=(512, 1024), save_dir=checkpoint_dir)
            elif dataset == 'multiview':
                mIoU, eval_loss = evaluate(model, dataset, ignore_label=255,
                                           input_size=(300, 300), save_dir=checkpoint_dir)
            else:
                print('error, unknown dataset: {}'.format(dataset))
            model.train()

            if mIoU > best_mIoU and save_best_model:
                best_mIoU = mIoU
                _save_checkpoint(i_iter, model, optimizer, config, ema_model, save_best=True)

            if config['utils']['tensorboard']:
                tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
                tensorboard_writer.add_scalar('Validation/Loss', eval_loss, i_iter)
            print('iter {}, mIoU: {}'.format(i_iter, mIoU))

        if save_unlabeled_images and train_unlabeled and i_iter % save_checkpoint_every == 0:
            # Saves two mixed images and the corresponding prediction
            save_image(inputs_u_s[0].cpu(), i_iter, 'input1', palette.CityScpates_palette)
            save_image(inputs_u_s[1].cpu(), i_iter, 'input2', palette.CityScpates_palette)
            _, pred_u_s = torch.max(logits_u_s, dim=1)
            save_image(pred_u_s[0].cpu(), i_iter, 'pred1', palette.CityScpates_palette)
            save_image(pred_u_s[1].cpu(), i_iter, 'pred2', palette.CityScpates_palette)

    _save_checkpoint(num_iterations, model, optimizer, config, ema_model)

    model.eval()
    if dataset == 'cityscapes':
        mIoU, val_loss = evaluate(model, dataset, ignore_label=250,
                                  input_size=(512, 1024), save_dir=checkpoint_dir)
    elif dataset == 'multiview':
        mIoU, val_loss = evaluate(model, dataset, ignore_label=255,
                                  input_size=(300, 300), save_dir=checkpoint_dir)
    else:
        print('error, unknown dataset: {}'.format(dataset))
    model.train()

    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, ema_model, save_best=True)

    if config['utils']['tensorboard']:
        tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
        tensorboard_writer.add_scalar('Validation/Loss', val_loss, i_iter)

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
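# NOTE: CrossEntropyLoss2dPixelWiseWeighted, used as the unlabeled consistency
# loss above, is imported from elsewhere. A minimal sketch of what it is assumed
# to compute: per-pixel cross entropy scaled by the pixelWiseWeight map before
# averaging. The exact reduction in this codebase may differ.
import torch
import torch.nn as nn

class CrossEntropyLoss2dPixelWiseWeighted(nn.Module):
    def __init__(self, ignore_index=255):
        super().__init__()
        self.ce = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none')

    def forward(self, output, target, pixel_weights):
        # output: (N, C, H, W) logits; target: (N, H, W) labels; pixel_weights: (N, H, W)
        loss = self.ce(output, target)            # per-pixel loss map
        return torch.mean(loss * pixel_weights)   # weighted mean over all pixels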
class BaseTrainer:
    def __init__(self, model, loss, resume, config, train_loader, val_loader=None, train_logger=None):
        self.model = model
        self.loss = loss
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_logger = train_logger
        self.logger = logging.getLogger(self.__class__.__name__)
        self.do_validation = self.config["trainer"]["val"]
        self.start_epoch = 1
        self.improved = False

        # SETTING THE DEVICE
        self.device, availble_gpus = self._get_available_devices(self.config["n_gpu"])
        if config["use_synch_bn"]:
            self.model = convert_model(self.model)
            self.model = DataParallelWithCallback(self.model, device_ids=availble_gpus)
        else:
            self.model = torch.nn.DataParallel(self.model, device_ids=availble_gpus)
        self.model.to(self.device)

        # CONFIGS
        cfg_trainer = self.config["trainer"]
        self.epochs = cfg_trainer["epochs"]
        self.save_period = cfg_trainer["save_period"]

        # OPTIMIZER
        if self.config["optimizer"]["differential_lr"]:
            if isinstance(self.model, torch.nn.DataParallel):
                trainable_params = [
                    {"params": filter(lambda p: p.requires_grad, self.model.module.get_decoder_params())},
                    {"params": filter(lambda p: p.requires_grad, self.model.module.get_backbone_params()),
                     "lr": config["optimizer"]["args"]["lr"] / 10},
                ]
            else:
                trainable_params = [
                    {"params": filter(lambda p: p.requires_grad, self.model.get_decoder_params())},
                    {"params": filter(lambda p: p.requires_grad, self.model.get_backbone_params()),
                     "lr": config["optimizer"]["args"]["lr"] / 10},
                ]
        else:
            trainable_params = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = get_instance(torch.optim, "optimizer", config, trainable_params)
        self.lr_scheduler = getattr(utils.lr_scheduler, config["lr_scheduler"]["type"])(
            self.optimizer, self.epochs, len(train_loader))

        # MONITORING
        self.monitor = cfg_trainer.get("monitor", "off")
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]
            self.mnt_best = -math.inf if self.mnt_mode == "max" else math.inf
            self.early_stoping = cfg_trainer.get("early_stop", math.inf)

        # CHECKPOINTS & TENSORBOARD
        start_time = datetime.datetime.now().strftime("%m-%d_%H-%M")
        self.checkpoint_dir = os.path.join(cfg_trainer["save_dir"], self.config["name"], start_time)
        helpers.dir_exists(self.checkpoint_dir)
        config_save_path = os.path.join(self.checkpoint_dir, "config.json")
        with open(config_save_path, "w") as handle:
            json.dump(self.config, handle, indent=4, sort_keys=True)

        writer_dir = os.path.join(cfg_trainer["log_dir"], self.config["name"], start_time)
        self.writer = tensorboard.SummaryWriter(writer_dir)

        if resume:
            self._resume_checkpoint(resume)

    def _get_available_devices(self, n_gpu):
        sys_gpu = torch.cuda.device_count()
        if sys_gpu == 0:
            self.logger.warning("No GPUs detected, using the CPU")
            n_gpu = 0
        elif n_gpu > sys_gpu:
            self.logger.warning(f"Nbr of GPU requested is {n_gpu} but only {sys_gpu} are available")
            n_gpu = sys_gpu
        device = torch.device("cuda:0" if n_gpu > 0 else "cpu")
        self.logger.info(f"Detected GPUs: {sys_gpu} Requested: {n_gpu}")
        available_gpus = list(range(n_gpu))
        return device, available_gpus

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            # RUN TRAIN (AND VAL)
            results = self._train_epoch(epoch)
            if self.do_validation and epoch % self.config["trainer"]["val_per_epochs"] == 0:
                results = self._valid_epoch(epoch)

                # LOGGING INFO
                self.logger.info(f"\n ## Info for epoch {epoch} ## ")
                for k, v in results.items():
                    self.logger.info(f" {str(k):15s}: {v}")

            if self.train_logger is not None:
                log = {"epoch": epoch, **results}
                self.train_logger.add_entry(log)

            # CHECKING IF THIS IS THE BEST MODEL (ONLY FOR VAL)
            if self.mnt_mode != "off" and epoch % self.config["trainer"]["val_per_epochs"] == 0:
                try:
                    if self.mnt_mode == "min":
                        self.improved = log[self.mnt_metric] < self.mnt_best
                    else:
                        self.improved = log[self.mnt_metric] > self.mnt_best
                except KeyError:
                    self.logger.warning(
                        f"The metric being tracked ({self.mnt_metric}) has not been calculated. Training stops.")
                    break

                if self.improved:
                    self.mnt_best = log[self.mnt_metric]
                    self.not_improved_count = 0
                else:
                    self.not_improved_count += 1

                if self.not_improved_count > self.early_stoping:
                    self.logger.info(f"\nPerformance didn't improve for {self.early_stoping} epochs")
                    self.logger.warning("Training Stopped")
                    break

            # SAVE CHECKPOINT
            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=self.improved)

    def _save_checkpoint(self, epoch, save_best=False):
        state = {
            "arch": type(self.model).__name__,
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "monitor_best": self.mnt_best,
            "config": self.config,
        }
        filename = os.path.join(self.checkpoint_dir, f"checkpoint-epoch{epoch}.pth")
        self.logger.info(f"\nSaving a checkpoint: {filename} ...")
        torch.save(state, filename)

        if save_best:
            filename = os.path.join(self.checkpoint_dir, "best_model.pth")
            torch.save(state, filename)
            self.logger.info("Saving current best: best_model.pth")

    def _resume_checkpoint(self, resume_path):
        self.logger.info(f"Loading checkpoint : {resume_path}")
        checkpoint = torch.load(resume_path)

        # Load last run info, the model params, the optimizer and the loggers
        self.start_epoch = checkpoint["epoch"] + 1
        self.mnt_best = checkpoint["monitor_best"]
        self.not_improved_count = 0

        if checkpoint["config"]["arch"] != self.config["arch"]:
            self.logger.warning("Warning! Current model is not the same as the one in the checkpoint")
        self.model.load_state_dict(checkpoint["state_dict"])

        if checkpoint["config"]["optimizer"]["type"] != self.config["optimizer"]["type"]:
            self.logger.warning("Warning! Current optimizer is not the same as the one in the checkpoint")
        self.optimizer.load_state_dict(checkpoint["optimizer"])

        self.logger.info(f"Checkpoint <{resume_path}> (epoch {self.start_epoch}) was loaded")

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _valid_epoch(self, epoch):
        raise NotImplementedError

    def _eval_metrics(self, output, target):
        raise NotImplementedError