import torch

# Project-local helpers assumed to be importable from the surrounding codebase:
# ModuleFloatShadow, Lookahead, copy_params, copy_params_grad, is_batch_updating.


class OptimizerWrapper(object):
    """Bundles an optimizer, an LR scheduler, optional gradient clipping,
    optional Lookahead wrapping and optional float32 "shadow" master weights
    behind a single step/zero_grad interface."""

    def __init__(self,
                 model,
                 optimizer_class,
                 optimizer_params,
                 scheduler_class,
                 scheduler_params,
                 clip_grad=None,
                 optimizer_state_dict=None,
                 use_shadow_weights=False,
                 writer=None,
                 experiment_name=None,
                 lookahead=False,
                 lookahead_params=None):
        if use_shadow_weights:
            # Keep a float32 master copy of the parameters and optimize that copy.
            model = ModuleFloatShadow(model)
            self._original_parameters = list(model.original_parameters())
        self.parameters = [p for p in model.parameters() if p.requires_grad]
        if lookahead:
            self.base_optimizer = optimizer_class(self.parameters, **optimizer_params)
            self.optimizer = Lookahead(self.base_optimizer, **lookahead_params)
        else:
            self.optimizer = optimizer_class(self.parameters, **optimizer_params)
        self.lookahead = lookahead
        if optimizer_state_dict is not None:
            self.load_state_dict(optimizer_state_dict)
        self.scheduler = scheduler_class(self.optimizer, **scheduler_params)
        self.use_shadow_weights = use_shadow_weights
        self.clip_grad = clip_grad if clip_grad is not None else 0
        self.writer = writer
        self.experiment_name = experiment_name
        self.it = 0  # global batch counter, used for TensorBoard logging

    def state_dict(self):
        """Returns the state of the optimizer as a :class:`dict`."""
        return self.optimizer.optimizer.state_dict() if self.lookahead \
            else self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        optimizer_state_dict = state_dict['state']
        if self.lookahead:
            self.optimizer.optimizer.__setstate__(optimizer_state_dict)
        else:
            self.optimizer.__setstate__(optimizer_state_dict)

    def zero_grad(self):
        """Clears the gradients of all optimized :class:`Variable` s."""
        self.optimizer.zero_grad()
        if self.use_shadow_weights:
            for p in self._original_parameters:
                if p.grad is not None:
                    p.grad.detach().zero_()

    def optimizer_step(self, closure=None):
        """Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        if self.clip_grad > 1e-12:
            if self.use_shadow_weights:
                torch.nn.utils.clip_grad_norm_(self._original_parameters, self.clip_grad)
            else:
                torch.nn.utils.clip_grad_norm_(self.parameters, self.clip_grad)
        if self.use_shadow_weights:
            # Copy gradients from the working weights onto the float32 master copy.
            copy_params_grad(self.parameters, self._original_parameters)
        self.optimizer.step(closure)
        if self.use_shadow_weights:
            # Copy the updated master values back to the working weights.
            copy_params(self._original_parameters, self.parameters)

    def scheduler_step(self, epoch=None):
        """Performs a single lr update step."""
        self.scheduler.step()

    def batch_step(self, closure=None):
        """Optimizer update after every batch; also logs the current lr and
        advances the scheduler if it updates per batch."""
        self.optimizer_step(closure)
        if self.writer is not None:
            self.writer.add_scalars(
                'params/lr',
                {self.experiment_name: self.optimizer.param_groups[0]['lr']},
                self.it)
        self.it += 1
        if is_batch_updating(self.scheduler):
            self.scheduler_step()

    def epoch_step(self):
        """Scheduler update at the end of an epoch (for epoch-wise schedulers)."""
        if not is_batch_updating(self.scheduler):
            self.scheduler_step()
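# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original code): a minimal view of
# how OptimizerWrapper is meant to be driven from a training loop, assuming a
# plain AdamW optimizer and StepLR scheduler from torch. The `model`, `batches`
# and `compute_loss` names are hypothetical placeholders.
def _optimizer_wrapper_usage_sketch(model, batches, compute_loss):
    wrapper = OptimizerWrapper(
        model,
        optimizer_class=torch.optim.AdamW,
        optimizer_params={'lr': 1e-3, 'weight_decay': 1e-4},
        scheduler_class=torch.optim.lr_scheduler.StepLR,
        scheduler_params={'step_size': 10, 'gamma': 0.5},
        clip_grad=1.0)
    for batch in batches:
        wrapper.zero_grad()
        loss = compute_loss(model, batch)
        loss.backward()
        wrapper.batch_step()  # optimizer update (+ scheduler, if it steps per batch)
    wrapper.epoch_step()      # scheduler update for epoch-wise schedulers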
import os

import numpy as np
import torch
import torch.nn as nn
import tqdm

# Project-local modules assumed available: loss (LossAll), collater,
# Lookahead, CyclicCosAnnealingLR, LearningRateWarmUP.


class TrainModule(object):
    def __init__(self, dataset, num_classes, model, decoder, down_ratio):
        torch.manual_seed(317)
        self.dataset = dataset
        self.dataset_phase = {
            'dota': ['train', 'valid'],
            'hrsc': ['train', 'test']
        }
        self.num_classes = num_classes
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.decoder = decoder
        self.down_ratio = down_ratio

    def save_model(self, path, epoch, model, optimizer):
        # Unwrap DataParallel so the checkpoint can be loaded on a single GPU.
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': state_dict,
                'optimizer_state_dict': optimizer.state_dict(),
                # 'loss': loss
            }, path)

    def load_model(self, model, optimizer, resume, strict=True):
        checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
        print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
        state_dict = checkpoint['model_state_dict']
        model.load_state_dict(state_dict, strict=strict)
        return model

    def train_network(self, args):
        optimizer = torch.optim.AdamW(self.model.parameters(), args.init_lr)
        self.optimizer = Lookahead(optimizer)
        # Cosine annealing restarts at the milestone epochs, preceded by a short
        # warm-up; both schedulers act on the inner AdamW optimizer.
        milestones = [5 + x * 80 for x in range(5)]
        # print(f'milestones:{milestones}')
        # self.scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96, last_epoch=-1)
        scheduler_c = CyclicCosAnnealingLR(optimizer,
                                           milestones=milestones,
                                           eta_min=5e-5)
        self.scheduler = LearningRateWarmUP(optimizer=optimizer,
                                            target_iteration=5,
                                            target_lr=0.003,
                                            after_scheduler=scheduler_c)
        save_path = 'weights_' + args.dataset
        start_epoch = 1
        best_loss = 1000

        # try:
        #     self.model, _, _ = self.load_model(self.model, self.optimizer, args.resume)
        # except:
        #     print('load pretrained model failed')
        # self.model = self.load_model(self.model, self.optimizer, args.resume)

        if not os.path.exists(save_path):
            os.mkdir(save_path)
        if args.ngpus > 1:
            if torch.cuda.device_count() > 1:
                print("Let's use", torch.cuda.device_count(), "GPUs!")
                # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
                self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        criterion = loss.LossAll()
        print('Setting up data...')
        dataset_module = self.dataset[args.dataset]
        dsets = {
            x: dataset_module(data_dir=args.data_dir,
                              phase=x,
                              input_h=args.input_h,
                              input_w=args.input_w,
                              down_ratio=self.down_ratio)
            for x in self.dataset_phase[args.dataset]
        }
        # Note: the loaders below assume the evaluation phase is named 'valid'
        # (as for 'dota'); for 'hrsc' the phase list uses 'test' instead.
        dsets_loader = {}
        dsets_loader['train'] = torch.utils.data.DataLoader(
            dsets['train'],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=collater)
        dsets_loader['valid'] = torch.utils.data.DataLoader(
            dsets['valid'],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=collater)

        print('Starting training...')
        train_loss = []
        valid_loss = []
        ap_list = []
        for epoch in range(start_epoch, args.num_epoch + 1):
            print('-' * 10)
            print('Epoch: {}/{} '.format(epoch, args.num_epoch))
            epoch_loss = self.run_epoch(phase='train',
                                        data_loader=dsets_loader['train'],
                                        criterion=criterion)
            train_loss.append(epoch_loss)
            epoch_loss = self.run_epoch(phase='valid',
                                        data_loader=dsets_loader['valid'],
                                        criterion=criterion)
            valid_loss.append(epoch_loss)
            self.scheduler.step(epoch)

            np.savetxt(os.path.join(save_path, 'train_loss.txt'),
                       train_loss, fmt='%.6f')
            np.savetxt(os.path.join(save_path, 'valid_loss.txt'),
                       valid_loss, fmt='%.6f')

            # if epoch % 5 == 0 or epoch > 20:
            #     self.save_model(os.path.join(save_path, 'model_{}.pth'.format(epoch)),
            #                     epoch, self.model, self.optimizer)
            # Keep a checkpoint whenever the validation loss improves, plus the
            # most recent epoch as model_last.pth.
            if epoch_loss < best_loss:
                self.save_model(
                    os.path.join(save_path, 'model_{}.pth'.format(epoch)),
                    epoch, self.model, self.optimizer)
                print(f'found optimal model, {best_loss}==>{epoch_loss}')
                best_loss = epoch_loss
            self.save_model(os.path.join(save_path, 'model_last.pth'),
                            epoch, self.model, self.optimizer)

    def run_epoch(self, phase, data_loader, criterion):
        if phase == 'train':
            self.model.train()
        else:
            self.model.eval()
        running_loss = 0.
        for data_dict in tqdm.tqdm(data_loader):
            for name in data_dict:
                data_dict[name] = data_dict[name].to(device=self.device,
                                                     non_blocking=True)
            if phase == 'train':
                self.optimizer.zero_grad()
                with torch.enable_grad():
                    pr_decs = self.model(data_dict['input'])
                    loss = criterion(pr_decs, data_dict)
                    loss.backward()
                    self.optimizer.step()
            else:
                with torch.no_grad():
                    pr_decs = self.model(data_dict['input'])
                    loss = criterion(pr_decs, data_dict)
            running_loss += loss.item()
        epoch_loss = running_loss / len(data_loader)
        print('{} loss: {}'.format(phase, epoch_loss))
        return epoch_loss
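# --------------------------------------------------------------------------
# Illustrative entry-point sketch (not part of the original code): it wires
# TrainModule.train_network to the command-line arguments the method reads.
# All default values are illustrative, and the dataset/model/decoder names
# (DOTA, HRSC, build_model, build_decoder) are hypothetical placeholders for
# whatever the surrounding project actually provides.
def _parse_args_sketch():
    import argparse
    parser = argparse.ArgumentParser(description='Train an oriented object detector')
    parser.add_argument('--dataset', default='dota', choices=['dota', 'hrsc'])
    parser.add_argument('--data_dir', default='datasets/DOTA')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--input_h', type=int, default=608)
    parser.add_argument('--input_w', type=int, default=608)
    parser.add_argument('--init_lr', type=float, default=1.25e-4)
    parser.add_argument('--num_epoch', type=int, default=80)
    parser.add_argument('--ngpus', type=int, default=1)
    parser.add_argument('--resume', default='', help='checkpoint to resume from')
    return parser.parse_args()
#
# if __name__ == '__main__':
#     args = _parse_args_sketch()
#     dataset = {'dota': DOTA, 'hrsc': HRSC}       # hypothetical dataset classes
#     model = build_model(num_classes=15)          # hypothetical model constructor
#     trainer = TrainModule(dataset=dataset, num_classes=15, model=model,
#                           decoder=build_decoder(), down_ratio=4)
#     trainer.train_network(args)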