class OptimizerWrapper(object):
    """Bundle an optimizer and its LR scheduler behind a training-loop API.

    The wrapper builds the optimizer over the model's trainable parameters,
    can wrap it in ``Lookahead``, can route the model through
    ``ModuleFloatShadow``, clips gradient norms before each step and logs the
    learning rate to ``writer`` when one is supplied.
    """

    def __init__(self,
                 model,
                 optimizer_class,
                 optimizer_params,
                 scheduler_class,
                 scheduler_params,
                 clip_grad=None,
                 optimizer_state_dict=None,
                 use_shadow_weights=False,
                 writer=None,
                 experiment_name=None,
                 lookahead=False,
                 lookahead_params=None):
        """Create the optimizer, the optional wrappers and the scheduler.

        Arguments:
            model: module whose parameters with ``requires_grad`` set are
                optimized.
            optimizer_class: callable invoked as
                ``optimizer_class(parameters, **optimizer_params)``.
            optimizer_params (dict): keyword arguments for the optimizer.
            scheduler_class: callable invoked as
                ``scheduler_class(optimizer, **scheduler_params)``.
            scheduler_params (dict): keyword arguments for the scheduler.
            clip_grad (float, optional): maximum gradient norm. ``None`` is
                stored as ``0``, which disables clipping.
            optimizer_state_dict (dict, optional): passed to
                :meth:`load_state_dict` once the optimizer exists.
            use_shadow_weights (bool): wrap ``model`` in
                ``ModuleFloatShadow`` and keep its ``original_parameters()``.
            writer: optional object exposing ``add_scalars``; used to log the
                learning rate in :meth:`batch_step`.
            experiment_name: key under which the learning rate is logged.
            lookahead (bool): wrap the optimizer in ``Lookahead``.
            lookahead_params (dict, optional): keyword arguments for
                ``Lookahead``; ``None`` means no extra arguments.
        """
        if use_shadow_weights:
            model = ModuleFloatShadow(model)
            self._original_parameters = list(model.original_parameters())

        self.parameters = [p for p in model.parameters() if p.requires_grad]
        if lookahead:
            self.base_optimizer = optimizer_class(self.parameters,
                                                  **optimizer_params)
            # ``lookahead_params`` defaults to None, which cannot be splatted.
            self.optimizer = Lookahead(self.base_optimizer,
                                       **(lookahead_params or {}))
        else:
            self.optimizer = optimizer_class(self.parameters,
                                             **optimizer_params)
        # Must be set before load_state_dict(), which branches on it.
        self.lookahead = lookahead
        if optimizer_state_dict is not None:
            self.load_state_dict(optimizer_state_dict)
        self.scheduler = scheduler_class(self.optimizer, **scheduler_params)
        self.use_shadow_weights = use_shadow_weights
        self.clip_grad = clip_grad if clip_grad is not None else 0
        self.writer = writer
        self.experiment_name = experiment_name
        # Step counter used as the x-axis when logging the learning rate.
        self.it = 0

    def state_dict(self):
        """Returns the state of the optimizer as a :class:`dict`.

        With lookahead enabled this is the state of the optimizer held by the
        ``Lookahead`` wrapper (its ``optimizer`` attribute).
        """
        if self.lookahead:
            return self.optimizer.optimizer.state_dict()
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # NOTE(review): only the 'state' entry is handed to ``__setstate__``;
        # no copy is made and 'param_groups' is not applied. The documented
        # restore API is ``Optimizer.load_state_dict`` -- confirm whether this
        # call restores the state as intended before relying on it.
        optimizer_state_dict = state_dict['state']
        if self.lookahead:
            self.optimizer.optimizer.__setstate__(optimizer_state_dict)
        else:
            self.optimizer.__setstate__(optimizer_state_dict)

    def zero_grad(self):
        """Clears the gradients of all optimized :class:`Variable` s."""
        self.optimizer.zero_grad()
        if self.use_shadow_weights:
            # The original parameters are not owned by the optimizer, so
            # their gradients are zeroed here in place.
            for p in self._original_parameters:
                if p.grad is not None:
                    p.grad.detach().zero_()

    def optimizer_step(self, closure=None):
        """Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        # A threshold of (almost) zero means clipping is disabled.
        if self.clip_grad > 1e-12:
            if self.use_shadow_weights:
                clip_targets = self._original_parameters
            else:
                clip_targets = self.parameters
            torch.nn.utils.clip_grad_norm_(clip_targets, self.clip_grad)
        if self.use_shadow_weights:
            # Presumably copies gradients from the original parameters onto
            # the optimized ones -- confirm against ``copy_params_grad``.
            copy_params_grad(self.parameters, self._original_parameters)
        self.optimizer.step(closure)
        if self.use_shadow_weights:
            # Presumably writes the updated values back to the original
            # parameters -- confirm against ``copy_params``.
            copy_params(self._original_parameters, self.parameters)

    def scheduler_step(self, epoch=None):
        """Performs a single lr update step.

        Arguments:
            epoch: accepted for call compatibility; it is not forwarded to
                the scheduler.
        """
        self.scheduler.step()

    def batch_step(self, closure=None):
        """Run one optimizer step, log the LR and step batch-wise schedulers.

        Arguments:
            closure (callable): forwarded to :meth:`optimizer_step`.
        """
        self.optimizer_step(closure)
        if self.writer is not None:
            self.writer.add_scalars(
                'params/lr',
                {self.experiment_name: self.optimizer.param_groups[0]['lr']},
                self.it)
            self.it += 1
        if is_batch_updating(self.scheduler):
            self.scheduler_step()

    def epoch_step(self):
        """Step the scheduler once per epoch unless it updates per batch."""
        if not is_batch_updating(self.scheduler):
            self.scheduler_step()


# Example no. 2 (a separate code sample follows)
class TrainModule(object):
    """Training driver: builds data loaders, runs epochs, saves checkpoints."""

    def __init__(self, dataset, num_classes, model, decoder, down_ratio):
        """Store the training configuration and pick the compute device.

        Arguments:
            dataset (dict): maps a dataset name to a callable that builds the
                dataset for a given phase (see :meth:`train_network`).
            num_classes: stored on the instance.
            model: network to train.
            decoder: stored on the instance.
            down_ratio: forwarded to the dataset constructor.
        """
        # Fixed seed so runs are repeatable.
        torch.manual_seed(317)
        self.dataset = dataset
        # Per dataset: [training phase, held-out phase].
        self.dataset_phase = {
            'dota': ['train', 'valid'],
            'hrsc': ['train', 'test']
        }
        self.num_classes = num_classes
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.decoder = decoder
        self.down_ratio = down_ratio

    def save_model(self, path, epoch, model, optimizer):
        """Write a checkpoint with the epoch, model and optimizer state.

        Arguments:
            path: destination file.
            epoch: epoch number stored under ``'epoch'``.
            model: module to save; a ``DataParallel`` wrapper is unwrapped so
                the saved keys carry no ``module.`` prefix.
            optimizer: object whose ``state_dict()`` is stored.
        """
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': state_dict,
                'optimizer_state_dict': optimizer.state_dict(),
            },
            path)

    def load_model(self, model, optimizer, resume, strict=False):
        """Load model weights from a checkpoint written by :meth:`save_model`.

        Arguments:
            model: module that receives the weights.
            optimizer: unused; the optimizer state stored in the checkpoint
                is not restored.
            resume: checkpoint path.
            strict (bool): forwarded to ``model.load_state_dict``. Defaults
                to ``False`` because loading was always non-strict before
                this parameter was honoured.

        Returns:
            The ``model`` that was passed in.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # The map_location keeps every tensor on the CPU.
        checkpoint = torch.load(resume,
                                map_location=lambda storage, loc: storage)
        print('loaded weights from {}, epoch {}'.format(
            resume, checkpoint['epoch']))
        state_dict = checkpoint['model_state_dict']
        model.load_state_dict(state_dict, strict=strict)
        return model

    def train_network(self, args):
        """Train ``self.model`` and write loss logs and checkpoints.

        Reads ``init_lr``, ``dataset``, ``ngpus``, ``data_dir``, ``input_h``,
        ``input_w``, ``batch_size``, ``num_workers`` and ``num_epoch`` from
        ``args``. Output goes to the directory ``'weights_' + args.dataset``.

        Arguments:
            args: namespace-like object carrying the attributes above.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), args.init_lr)
        self.optimizer = Lookahead(optimizer)
        milestones = [5 + x * 80 for x in range(5)]
        scheduler_c = CyclicCosAnnealingLR(optimizer,
                                           milestones=milestones,
                                           eta_min=5e-5)
        self.scheduler = LearningRateWarmUP(optimizer=optimizer,
                                            target_iteration=5,
                                            target_lr=0.003,
                                            after_scheduler=scheduler_c)

        save_path = 'weights_' + args.dataset
        start_epoch = 1
        # Infinity guarantees the first epoch always counts as an improvement.
        best_loss = float('inf')
        # NOTE: resuming from ``args.resume`` via load_model() is not wired in.

        os.makedirs(save_path, exist_ok=True)
        if args.ngpus > 1 and torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # DataParallel splits each batch along dim 0 across the GPUs.
            self.model = torch.nn.DataParallel(self.model)
        self.model.to(self.device)

        criterion = loss.LossAll()
        print('Setting up data...')

        dataset_module = self.dataset[args.dataset]
        phases = self.dataset_phase[args.dataset]

        dsets = {
            x: dataset_module(data_dir=args.data_dir,
                              phase=x,
                              input_h=args.input_h,
                              input_w=args.input_w,
                              down_ratio=self.down_ratio)
            for x in phases
        }
        # The held-out split is 'valid' for dota but 'test' for hrsc, so it
        # is looked up instead of hard-coding 'valid'.
        eval_phase = next(x for x in phases if x != 'train')

        dsets_loader = {
            x: torch.utils.data.DataLoader(dsets[x],
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.num_workers,
                                           pin_memory=True,
                                           drop_last=True,
                                           collate_fn=collater)
            for x in ('train', eval_phase)
        }
        print('Starting training...')
        train_loss = []
        valid_loss = []
        for epoch in range(start_epoch, args.num_epoch + 1):
            print('-' * 10)
            print('Epoch: {}/{} '.format(epoch, args.num_epoch))
            epoch_loss = self.run_epoch(phase='train',
                                        data_loader=dsets_loader['train'],
                                        criterion=criterion)
            train_loss.append(epoch_loss)
            epoch_loss = self.run_epoch(phase=eval_phase,
                                        data_loader=dsets_loader[eval_phase],
                                        criterion=criterion)
            valid_loss.append(epoch_loss)

            self.scheduler.step(epoch)

            # The full loss history is rewritten after every epoch.
            np.savetxt(os.path.join(save_path, 'train_loss.txt'),
                       train_loss,
                       fmt='%.6f')
            np.savetxt(os.path.join(save_path, 'valid_loss.txt'),
                       valid_loss,
                       fmt='%.6f')

            # ``epoch_loss`` is the held-out loss at this point.
            if epoch_loss < best_loss:
                self.save_model(
                    os.path.join(save_path, 'model_{}.pth'.format(epoch)),
                    epoch, self.model, self.optimizer)
                print(f'find optimal model, {best_loss}==>{epoch_loss}')
                best_loss = epoch_loss

            self.save_model(os.path.join(save_path, 'model_last.pth'), epoch,
                            self.model, self.optimizer)

    def run_epoch(self, phase, data_loader, criterion):
        """Run one pass over ``data_loader`` and return the mean batch loss.

        Arguments:
            phase (str): ``'train'`` enables training mode and parameter
                updates; any other value runs evaluation without gradients.
            data_loader: iterable of dicts of tensors; the ``'input'`` entry
                is fed to the model.
            criterion: callable invoked as ``criterion(predictions, batch)``.

        Returns:
            float: sum of the batch losses divided by ``len(data_loader)``.
        """
        is_train = phase == 'train'
        if is_train:
            self.model.train()
        else:
            self.model.eval()
        running_loss = 0.
        for data_dict in tqdm.tqdm(data_loader):
            for name in data_dict:
                data_dict[name] = data_dict[name].to(device=self.device,
                                                     non_blocking=True)
            if is_train:
                self.optimizer.zero_grad()
                with torch.enable_grad():
                    pr_decs = self.model(data_dict['input'])
                    batch_loss = criterion(pr_decs, data_dict)
                    batch_loss.backward()
                    self.optimizer.step()
            else:
                with torch.no_grad():
                    pr_decs = self.model(data_dict['input'])
                    batch_loss = criterion(pr_decs, data_dict)

            running_loss += batch_loss.item()
        epoch_loss = running_loss / len(data_loader)
        print('{} loss: {}'.format(phase, epoch_loss))
        return epoch_loss