def __init__(self,
                 model,
                 optimizer_class,
                 optimizer_params,
                 scheduler_class,
                 scheduler_params,
                 clip_grad=None,
                 optimizer_state_dict=None,
                 use_shadow_weights=False,
                 writer=None,
                 experiment_name=None,
                 lookahead=False,
                 lookahead_params=None):
        """Build the optimizer (optionally Lookahead-wrapped) and LR scheduler.

        Args:
            model: module whose trainable parameters (``requires_grad``) are
                optimized. When ``use_shadow_weights`` is true it is first
                wrapped in ``ModuleFloatShadow``.
            optimizer_class: optimizer constructor, called as
                ``optimizer_class(params, **optimizer_params)``.
            optimizer_params: keyword arguments for ``optimizer_class``.
            scheduler_class: LR scheduler constructor, called as
                ``scheduler_class(self.optimizer, **scheduler_params)``.
            scheduler_params: keyword arguments for ``scheduler_class``.
            clip_grad: gradient-clipping setting; ``None`` is stored as ``0``.
            optimizer_state_dict: if not ``None``, passed to
                ``self.load_state_dict`` before the scheduler is built.
            use_shadow_weights: wrap ``model`` in ``ModuleFloatShadow`` and keep
                a list of its original parameters in
                ``self._original_parameters``.
            writer: stored unchanged on ``self.writer``.
            experiment_name: stored unchanged on ``self.experiment_name``.
            lookahead: if true, the optimizer built from ``optimizer_class`` is
                kept on ``self.base_optimizer`` and wrapped in ``Lookahead``.
            lookahead_params: keyword arguments for ``Lookahead``; ``None``
                (the default) means "use Lookahead's own defaults".
        """
        if use_shadow_weights:
            model = ModuleFloatShadow(model)
            self._original_parameters = list(model.original_parameters())

        # Only parameters that require gradients are handed to the optimizer.
        self.parameters = [p for p in model.parameters() if p.requires_grad]
        if lookahead:
            self.base_optimizer = optimizer_class(self.parameters,
                                                  **optimizer_params)
            # `**None` raises TypeError, so the documented default of
            # lookahead_params=None must fall back to an empty kwargs dict.
            self.optimizer = Lookahead(self.base_optimizer,
                                       **(lookahead_params or {}))
        else:
            self.optimizer = optimizer_class(self.parameters,
                                             **optimizer_params)
        self.lookahead = lookahead
        if optimizer_state_dict is not None:
            # Restored before the scheduler is constructed on self.optimizer.
            self.load_state_dict(optimizer_state_dict)
        self.scheduler = scheduler_class(self.optimizer, **scheduler_params)
        self.use_shadow_weights = use_shadow_weights
        self.clip_grad = clip_grad if clip_grad is not None else 0
        self.writer = writer
        self.experiment_name = experiment_name
        # Iteration counter.
        self.it = 0
Example #2
0
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    model = model.to(device)

    print('* Creating Dataloaders, batch size = {:d}'.format(bs))
    train_loader, val_loader = get_train_val_loaders(csv_path_train=csv_train,
                                                     csv_path_val=csv_val,
                                                     batch_size=bs,
                                                     mean=mean,
                                                     std=std,
                                                     qualities=True)

    if optimizer_choice == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    elif optimizer_choice == 'look_ahead':
        base_opt = torch.optim.Adam(model.parameters(), lr=lr)
        optimizer = Lookahead(base_opt, k=5, alpha=0.5)  # Initialize Lookahead
    elif optimizer_choice == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        sys.exit('not a valid optimizer choice')
    if load_checkpoint != 'no' and n_classes == 5:
        optimizer.load_state_dict(optimizer_state_dict)
        for i, param_group in enumerate(optimizer.param_groups):
            param_group['lr'] = lr

    print('* Instantiating base loss function {}, lambda={}, exp={}'.format(
        base_loss,
        str(lambd).rstrip('0'),
        str(exp).rstrip('0')))
    if base_loss == 'no':
        train_crit, val_crit = get_cost_sensitive_criterion(
                else:
                    model.load_state_dict(torch.load(
                        glob(
                            J(F.pretrain_model_file, "**", "last",
                              "model.pth"))[0]),
                                          strict=False)

            F.folder_id = "{}_nfold{}-{}".format(folder_id, F.n_fold, i + 1)

            base_opt = None
            lr_scheduler = None
            if F.cos_lr:
                base_opt = torch.optim.AdamW(lr=F.lr,
                                             params=model.parameters(),
                                             weight_decay=F.weight_decay)
                base_opt = Lookahead(base_opt, k=5, alpha=0.5)
                lr_scheduler = LambdaLR(
                    base_opt,
                    lr_lambda=lambda epoch: cosine_lr(
                        epoch, max_epoch=F.epochs, offset=F.cos_offset))
                # lr_scheduler = CosineAnnealingWarmRestarts(base_opt, T_0=F.T_0, T_mult=1)

            T.train(
                F,
                model,
                dl_tr,
                dl_val,
                forward_batch_fun=forward_batch_fun,
                hold_best_model=False,
                stop_cond=lambda sc: sc['val_score'] > F.val_score_limit,
                optimizer=base_opt,
Example #4
0
    def train_network(self, args):
        """Train ``self.model`` on ``args.dataset`` and checkpoint it each epoch.

        Creates an AdamW optimizer wrapped in ``Lookahead`` (stored on
        ``self.optimizer``) and a warm-up followed by a cyclic cosine LR
        schedule (stored on ``self.scheduler``). Then, for each epoch in
        ``1..args.num_epoch``, runs one training and one validation pass via
        ``self.run_epoch``.

        After every epoch:

        - the per-epoch losses are rewritten to ``train_loss.txt`` and
          ``valid_loss.txt`` inside ``'weights_' + args.dataset``;
        - ``model_<epoch>.pth`` is saved when the validation loss improves on
          the best seen so far;
        - ``model_last.pth`` is overwritten.

        Args:
            args: namespace read for ``init_lr``, ``dataset``, ``ngpus``,
                ``data_dir``, ``input_h``, ``input_w``, ``batch_size``,
                ``num_workers`` and ``num_epoch``.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), args.init_lr)
        self.optimizer = Lookahead(optimizer)
        # The LR schedulers are attached to the base AdamW optimizer, not the
        # Lookahead wrapper.
        # NOTE(review): this assumes Lookahead shares param_groups with the
        # base optimizer; confirm against the Lookahead implementation.
        milestones = [5 + x * 80 for x in range(5)]
        scheduler_c = CyclicCosAnnealingLR(optimizer,
                                           milestones=milestones,
                                           eta_min=5e-5)
        self.scheduler = LearningRateWarmUP(optimizer=optimizer,
                                            target_iteration=5,
                                            target_lr=0.003,
                                            after_scheduler=scheduler_c)

        save_path = 'weights_' + args.dataset
        start_epoch = 1
        # +inf guarantees the first validation loss counts as an improvement.
        # A fixed sentinel (previously 1000) silently skipped best-model saves
        # whenever the loss stayed above it.
        best_loss = float('inf')

        # exist_ok=True avoids the exists()/mkdir() check-then-create race.
        os.makedirs(save_path, exist_ok=True)
        if args.ngpus > 1 and torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        criterion = loss.LossAll()
        print('Setting up data...')

        dataset_module = self.dataset[args.dataset]

        dsets = {
            x: dataset_module(data_dir=args.data_dir,
                              phase=x,
                              input_h=args.input_h,
                              input_w=args.input_w,
                              down_ratio=self.down_ratio)
            for x in self.dataset_phase[args.dataset]
        }

        def make_loader(phase, shuffle):
            """Build a DataLoader for ``dsets[phase]`` with shared settings."""
            # drop_last=True is kept for both phases.
            # NOTE(review): run_epoch / collater may assume full batches;
            # confirm before relaxing this.
            return torch.utils.data.DataLoader(dsets[phase],
                                               batch_size=args.batch_size,
                                               shuffle=shuffle,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True,
                                               collate_fn=collater)

        # Validation is not shuffled. With drop_last=True, shuffling would
        # drop a different random subset every epoch, making the validation
        # loss (which drives best-model selection) incomparable across epochs.
        dsets_loader = {
            'train': make_loader('train', shuffle=True),
            'valid': make_loader('valid', shuffle=False),
        }

        print('Starting training...')
        train_loss = []
        valid_loss = []
        for epoch in range(start_epoch, args.num_epoch + 1):
            print('-' * 10)
            print('Epoch: {}/{} '.format(epoch, args.num_epoch))
            train_epoch_loss = self.run_epoch(phase='train',
                                              data_loader=dsets_loader['train'],
                                              criterion=criterion)
            train_loss.append(train_epoch_loss)
            valid_epoch_loss = self.run_epoch(phase='valid',
                                              data_loader=dsets_loader['valid'],
                                              criterion=criterion)
            valid_loss.append(valid_epoch_loss)

            # The scheduler is stepped with the explicit epoch index.
            self.scheduler.step(epoch)

            # Rewrite the full loss histories each epoch so the files stay
            # current even if training is interrupted.
            np.savetxt(os.path.join(save_path, 'train_loss.txt'),
                       train_loss,
                       fmt='%.6f')
            np.savetxt(os.path.join(save_path, 'valid_loss.txt'),
                       valid_loss,
                       fmt='%.6f')

            if valid_epoch_loss < best_loss:
                self.save_model(
                    os.path.join(save_path, 'model_{}.pth'.format(epoch)),
                    epoch, self.model, self.optimizer)
                print(f'find optimal model, {best_loss}==>{valid_epoch_loss}')
                best_loss = valid_epoch_loss

            self.save_model(os.path.join(save_path, 'model_last.pth'), epoch,
                            self.model, self.optimizer)