Example #1
    def __init__(self, configuration, pre_embed=None):
        configuration = deepcopy(configuration)
        self.configuration = deepcopy(configuration)

        configuration['model']['encoder']['pre_embed'] = pre_embed

        encoder_copy = deepcopy(configuration['model']['encoder'])
        self.Pencoder = Encoder.from_params(Params(configuration['model']['encoder'])).to(device)
        self.Qencoder = Encoder.from_params(Params(encoder_copy)).to(device)

        configuration['model']['decoder']['hidden_size'] = self.Pencoder.output_size
        self.decoder = AttnDecoderQA.from_params(Params(configuration['model']['decoder'])).to(device)

        self.bsize = configuration['training']['bsize']

        self.adversary_multi = AdversaryMulti(self.decoder)

        weight_decay = configuration['training'].get('weight_decay', 1e-5)
        self.params = list(self.Pencoder.parameters()) + list(self.Qencoder.parameters()) + list(self.decoder.parameters())
        self.optim = torch.optim.Adam(self.params, weight_decay=weight_decay, amsgrad=True)
        # self.optim = torch.optim.Adagrad(self.params, lr=0.05, weight_decay=weight_decay)
        self.criterion = nn.CrossEntropyLoss()

        import time
        dirname = configuration['training']['exp_dirname']
        basepath = configuration['training'].get('basepath', 'outputs')
        self.time_str = time.ctime().replace(' ', '_')
        self.dirname = os.path.join(basepath, dirname, self.time_str)

        self.swa_settings = configuration['training']['swa']
        if self.swa_settings[0]:
            self.swa_all_optim = SWA(self.optim)
            self.running_norms = []
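Example #1 wraps the joint optimizer as SWA(self.optim) with no swa_start/swa_freq, i.e. torchcontrib's manual mode, where the surrounding code decides when to accumulate the average (update_swa) and when to copy it into the model (swap_swa_sgd); Example #21 below shows those calls. A minimal, self-contained sketch of that manual workflow -- the toy model, data and schedule here are placeholders, not taken from the example:

import torch
import torch.nn as nn
from torchcontrib.optim import SWA

model = nn.Linear(10, 1)
base_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
opt = SWA(base_opt)  # manual mode: no swa_start/swa_freq given

for step in range(100):
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    if step >= 50 and step % 5 == 0:
        opt.update_swa()   # fold the current weights into the running average

opt.swap_swa_sgd()         # replace the model weights with the SWA average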
Example #2
def main(args):
    np.random.seed(432)
    torch.random.manual_seed(432)
    try:
        os.makedirs(args.outpath)
    except OSError:
        pass
    experiment_path = utils.get_new_model_path(args.outpath)
    print(experiment_path)
    train_writer = SummaryWriter(os.path.join(experiment_path, 'train_logs'))
    val_writer = SummaryWriter(os.path.join(experiment_path, 'val_logs'))
    scheduler = cyclical_lr(5, 1e-5, 2e-3)
    trainer = train.Trainer(train_writer, val_writer, scheduler=scheduler)

    train_transform = data.build_preprocessing()
    eval_transform = data.build_preprocessing()

    trainds, evalds = data.build_dataset(args.datadir, None)
    trainds.transform = train_transform
    evalds.transform = eval_transform

    model = models.resnet34()
    base_opt = torch.optim.Adam(model.parameters())
    opt = SWA(base_opt, swa_start=30, swa_freq=10)

    trainloader = DataLoader(trainds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=8,
                             pin_memory=True)
    evalloader = DataLoader(evalds,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=16,
                            pin_memory=True)

    export_path = os.path.join(experiment_path, 'last.pth')

    best_lwlrap = 0

    for epoch in range(args.epochs):
        print('Epoch {} - lr {:.6f}'.format(epoch, scheduler(epoch)))
        trainer.train_epoch(model, opt, trainloader, scheduler(epoch))
        metrics = trainer.eval_epoch(model, evalloader)

        print('Epoch: {} - lwlrap: {:.4f}'.format(epoch, metrics['lwlrap']))

        # save best model
        if metrics['lwlrap'] > best_lwlrap:
            best_lwlrap = metrics['lwlrap']
            torch.save(model.state_dict(), export_path)

    print('Best metrics {:.4f}'.format(best_lwlrap))
    opt.swap_swa_sgd()
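One follow-up worth noting: resnet34 contains BatchNorm layers, and torchcontrib's recipe recomputes their running statistics after the averaged weights have been swapped in; the example also never saves or evaluates the swapped model. A hedged continuation of main() under the assumption that the SWA weights are meant to be used (it reuses opt, model, trainloader, evalloader, trainer and experiment_path from above; the 'swa.pth' filename is made up):

    # after opt.swap_swa_sgd() the model holds the averaged weights
    opt.bn_update(trainloader, model)   # refresh BatchNorm running statistics
    swa_metrics = trainer.eval_epoch(model, evalloader)
    print('SWA lwlrap: {:.4f}'.format(swa_metrics['lwlrap']))
    torch.save(model.state_dict(), os.path.join(experiment_path, 'swa.pth'))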
Example #3
 def set_parameters(self, parameters):
     self.parameters = tuple(parameters)
     self.optimizer = self.optimizer_cls(self.parameters,
                                         **self.optimizer_kwargs)
     if self.swa_start is not None:
         from torchcontrib.optim import SWA
         assert self.swa_freq is not None, self.swa_freq
         assert self.swa_lr is not None, self.swa_lr
         self.optimizer = SWA(self.optimizer,
                              swa_start=self.swa_start,
                              swa_freq=self.swa_freq,
                              swa_lr=self.swa_lr)
Example #4
    def reset_optimizer(self):
        self.base_optimizer = optimizers[self.optimizer_params['type']](self.net.parameters(),
                                                                        **self.optimizer_params[
                                                                            'args'])

        if self.swa_params is not None:
            self.optimizer = SWA(self.base_optimizer, **self.swa_params)
            self.swa = True
            self.averaged_weights = False
        else:
            self.optimizer = self.base_optimizer
            self.swa = False
Example #5
 def configure_optimizers(self):
     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
     # materialize the generator so both parameter groups below can iterate over it
     param_optimizer = list(self.model.named_parameters())
     optimizer_grouped_parameters = [{
         'params': [
             p for n, p in param_optimizer
             if not any(nd in n for nd in no_decay)
         ],
         'weight_decay':
         0.001
     }, {
         'params':
         [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay':
         0.0
     }]
     optimizer = AdamW(optimizer_grouped_parameters,
                       lr=self.global_config.lr)
     lr_scheduler = get_linear_schedule_with_warmup(
         optimizer,
         num_warmup_steps=self.global_config.warmup_steps,
         num_training_steps=self.total_steps(),
     )
     if self.global_config.swa:
         optimizer = SWA(optimizer, self.global_config.swa_start,
                         self.global_config.swa_freq,
                         self.global_config.swa_lr)
     return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]
Example #6
    def build_optimizer(self, name: str,
                        model: torch.nn.Module) -> torch.optim.Optimizer:
        """No bias decay:
        Bag of Tricks for Image Classification with Convolutional Neural Networks
        (https://arxiv.org/pdf/1812.01187.pdf)"""
        weight_p, bias_p = [], []
        for p_name, p in model.named_parameters():
            if 'bias' in p_name:
                bias_p += [p]
            else:
                weight_p += [p]
        parameters = [{
            'params': weight_p,
            'weight_decay': self.weight_decay
        }, {
            'params': bias_p,
            'weight_decay': 0
        }]

        if name == 'Adam':
            # pass the grouped parameters so the no-bias-decay split takes effect
            return torch.optim.Adam(parameters, lr=self.base_lr)
        if name == 'SGD':
            return torch.optim.SGD(parameters, lr=self.base_lr)
        if name == 'SWA':
            """Stochastic Weight Averaging: 
            Averaging Weights Leads to Wider Optima and Better Generalization
            (https://arxiv.org/pdf/1803.05407.pdf)"""
            base_opt = torch.optim.SGD(parameters, lr=self.base_lr)
            return SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=self.base_lr)
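The docstrings above point to the two tricks this builder combines: no-bias-decay parameter groups and stochastic weight averaging. A standalone sketch of the same grouping applied end to end, with a toy module and made-up hyper-parameters:

import torch
import torch.nn as nn
from torchcontrib.optim import SWA

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

# split parameters: weights get weight decay, biases do not
weight_p = [p for n, p in model.named_parameters() if 'bias' not in n]
bias_p = [p for n, p in model.named_parameters() if 'bias' in n]
parameters = [
    {'params': weight_p, 'weight_decay': 1e-4},
    {'params': bias_p, 'weight_decay': 0.0},
]

base_opt = torch.optim.SGD(parameters, lr=0.01, momentum=0.9)
optimizer = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.01)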
Example #7
    def set_model_optimizer(self):
        """
        Set model optimizer based on user parameter selection

        1) Set SGD or Adam optimizer
        2) Set SWA if set (check you have downloaded the library using: pip install torchcontrib)
        3) Print if: Use ZCA preprocessing (sometimes useful for CIFAR10) or debug mode is on or off 
           (to check the model on the test set without taking decisions based on it -- all decisions are taken based on the validation set)
        """
        if self.args.optimizer == 'sgd':
            prRed('... SGD ...')
            optimizer = torch.optim.SGD(self.model.parameters(),
                                        self.args.lr,
                                        momentum=self.args.momentum,
                                        weight_decay=self.args.weight_decay,
                                        nesterov=self.args.nesterov)
        else:
            prRed('... Adam optimizer ...')
            optimizer = torch.optim.Adam(self.model.parameters(),
                                         lr=self.args.lr)

        if self.args.swa:
            prRed('Using SWA!')
            from torchcontrib.optim import SWA
            optimizer = SWA(optimizer)

        self.model_optimizer = optimizer

        if self.args.use_zca:
            prPurple('*Use ZCA preprocessing*')
        if self.args.debug:
            prPurple('*Debug mode on*')
Example #8
 def init_SWA(self, optimizer):
     print("Using SWA")
     opt = SWA(
         optimizer,
         swa_start=self.config_dict["iters_per_epoch"] * 5,
         swa_freq=self.config_dict["iters_per_epoch"] * 2,
         swa_lr=self.config_dict["lr"] * 1e-1,
     )
     return opt
Example #9
    def _set_optimizer_scheduler(self):
        self.log('Optimizer and scheduler initialization started.', direct_out=True)
        def is_backbone(n):
            return 'backbone' in n

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        # use different learning rate for backbone transformer and classifier head
        if self.use_diff_lr:
            backbone_lr, head_lr = self.config.lr*xm.xrt_world_size(), self.config.lr*xm.xrt_world_size()*500
            optimizer_grouped_parameters = [
                # {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
                # {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
                {"params": [p for n, p in param_optimizer if is_backbone(n)], "lr": backbone_lr},
                {"params": [p for n, p in param_optimizer if not is_backbone(n)], "lr": head_lr}
            ]
            self.log(f'Different Learning rate for backbone: {backbone_lr} head:{head_lr}')
        else:
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
                ]
        
        try:
            self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.lr*xm.xrt_world_size())
            # self.optimizer = SGD(optimizer_grouped_parameters, lr=self.config.lr*xm.xrt_world_size(), momentum=0.9)
        except:
            param_g_1 = [p for n, p in param_optimizer if is_backbone(n)]
            param_g_2 = [p for n, p in param_optimizer if not is_backbone(n)]
            param_intersect = list(set(param_g_1) & set(param_g_2))
            self.log(f'intersect: {param_intersect}', direct_out=True)

        if self.use_SWA:
            self.optimizer = SWA(self.optimizer)
        
        if 'num_training_steps' in self.config.scheduler_params:
            num_training_steps = int(self.config.train_lenght / self.config.batch_size / xm.xrt_world_size() * self.config.n_epochs)
            self.log(f'Number of training steps: {num_training_steps}', direct_out=True)
            self.config.scheduler_params['num_training_steps'] = num_training_steps
        
        self.scheduler = self.config.SchedulerClass(self.optimizer, **self.config.scheduler_params)
Example #10
    def make_optimizer(self, max_steps):

        optimizer = OPTIMIZERS[self.config.train.optimizer]
        optimizer = optimizer(self.parameters(),
                              self.config.train.learning_rate,
                              weight_decay=self.config.train.weight_decay)
        self.optimizer = SWA(optimizer,
                             swa_start=int(0.8 * max_steps),
                             swa_freq=100)
        self.scheduler = make_scheduler(self.config.train.scheduler,
                                        max_steps=max_steps)(optimizer)
Example #11
 def configure_optimizers(self):
     optim = next(o for o in dir(torch.optim) if o.lower() == FLAGS.optim.lower()) # "Adam"
     optimizer=getattr(torch.optim, optim)(self.parameters(), lr=FLAGS.learning_rate) # optimizer object
     #optimizer=torch.optim.SGD(self.parameters(), lr=FLAGS.learning_rate,weight_decay=0.00001) 
     optimizer=torch.optim.AdamW(self.parameters(),lr=FLAGS.learning_rate,weight_decay=0.00001)
     if FLAGS.SWA:
         iterations_per_epoch=int(len(train_ds)/FLAGS.batch_size)
         optimizer = SWA(optimizer, swa_start=int(FLAGS.swa_start*iterations_per_epoch), swa_freq=50, swa_lr=FLAGS.learning_rate/10)
     scheduler=torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=FLAGS.lr_milestones, gamma=0.5)
     #scheduler=torch.optim.lr_scheduler.CyclicLR(optimizer,base_lr=FLAGS.learning_rate/2,max_lr=2*FLAGS.learning_rate,step_size_up=2,step_size_down=2,cycle_momentum=False)
     #step_size_up is in epochs! I don't know why the hell
     return [optimizer], [scheduler]
Example #12
@dataclass  # the bare field annotations and __post_init__ below require the dataclass decorator
class ParamOptim:
    params: List[torch.Tensor]
    lr: float = 1e-3
    eps: float = 1e-8
    clip_grad: float = None
    optimizer: Optimizer = AdamW

    def __post_init__(self):
        base_opt = self.optimizer(self.params, lr=self.lr, eps=self.eps)
        self.optim = SWA(base_opt)

    def set_lr(self, lr):
        for pg in self.optim.param_groups:
            pg['lr'] = lr
        return lr

    def step(self, loss):
        self.optim.zero_grad()
        loss.backward()
        if self.clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(self.params, self.clip_grad)
        self.optim.step()
        return loss
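ParamOptim wraps its optimizer as SWA(base_opt) with no swa_start/swa_freq, so nothing is averaged unless the surrounding code calls update_swa() and swap_swa_sgd() explicitly; that part is not shown in the snippet. A hypothetical usage sketch under that assumption (the parameter tensor, toy loss and schedule are made up):

import torch
from torch.optim import AdamW

params = [torch.nn.Parameter(torch.randn(4, 4))]
po = ParamOptim(params=params, lr=1e-3, optimizer=AdamW)

for step in range(200):
    loss = (params[0] ** 2).mean()   # toy objective
    po.step(loss)
    if step >= 100 and step % 10 == 0:
        po.optim.update_swa()        # accumulate the running average

po.optim.swap_swa_sgd()              # copy the averaged values back into params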
Example #13
 def configure_optimizers(self):
     if self.hparams['optimizer_name'] == 'adam':
         opt = torch.optim.Adam(self.parameters(), lr=self.lr)
         return opt
     elif self.hparams['optimizer_name'] == 'rmsprop':
         opt = torch.optim.RMSprop(self.parameters(),
                                   lr=self.hparams['lr'],
                                   momentum=.001)
         return opt
     elif self.hparams['optimizer_name'] == 'swa':
         opt = torch.optim.Adam(self.parameters(), lr=self.hparams['lr'])
         return SWA(opt,
                    swa_start=100,
                    swa_freq=50,
                    swa_lr=self.hparams['lr'])
Example #14
 def _create_optimizer(self, sgd):
     optimizer = AdamW(
         self._model.parameters(),
         lr=getattr(sgd, "pytt_lr", sgd.alpha),
         eps=sgd.eps,
         betas=(sgd.b1, sgd.b2),
         weight_decay=getattr(sgd, "pytt_weight_decay", 0.0),
     )
     if getattr(sgd, "pytt_use_swa", False):
         optimizer = SWA(optimizer,
                         swa_start=1,
                         swa_freq=10,
                         swa_lr=sgd.alpha)
     optimizer.zero_grad()
     return optimizer
Example #15
def make_optimizer(cfg, model):
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(
            params, momentum=cfg.SOLVER.MOMENTUM)
        # wrap with SWA; call optimizer.swap_swa_sgd() after the training loop
        optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)

    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    return optimizer
Example #16
def _dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    data_loaders = [
        build_dataloader(dataset,
                         cfg.data.imgs_per_gpu,
                         cfg.data.workers_per_gpu,
                         dist=True)
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())
    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    if cfg.swa is not None:
        optimizer = SWA(optimizer, cfg.swa['swa_start'], cfg.swa['swa_freq'],
                        cfg.swa['swa_lr'])

    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level)
    # register hooks
    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        val_dataset_cfg = cfg.data.val
        if isinstance(model.module, RPN):
            # TODO: implement recall hooks for other datasets
            runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg))
        else:
            dataset_type = getattr(datasets, val_dataset_cfg.type)
            if issubclass(dataset_type, datasets.CocoDataset):
                runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg))
            else:
                runner.register_hook(DistEvalmAPHook(val_dataset_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
Example #17
    # Fetch the loss
    from model.loss_functions.mmd_loss import loss_function

    loss_fn = loss_function

    # Load the VAE model
    model = VariationalAutoencoder(
        params).cuda() if params.cuda else VariationalAutoencoder(params)

    if args.swa:
        # Use the Adam optimizer
        base_optimizer = optim.Adam(model.parameters(),
                                    lr=1e-3,
                                    eps=params.eps,
                                    betas=(params.betas[0], params.betas[1]),
                                    weight_decay=params.weight_decay)
        optimizer = SWA(base_optimizer, swa_start=10, swa_freq=5, swa_lr=1e-3)

    else:
        # Use the Adam optimizer
        optimizer = optim.Adam(model.parameters(),
                               lr=params.learning_rate,
                               eps=params.eps,
                               betas=(params.betas[0], params.betas[1]),
                               weight_decay=params.weight_decay)

    # Train the model
    train(model, train_dl, args.dataloader, optimizer, loss_fn, params,
          model_dir, args.swa, args.restore_file)
Example #18
    def __init__(self, configuration, pre_embed=None):
        configuration = deepcopy(configuration)
        self.configuration = deepcopy(configuration)

        configuration['model']['encoder']['pre_embed'] = pre_embed
        self.encoder = Encoder.from_params(
            Params(configuration['model']['encoder'])).to(device)

        configuration['model']['decoder'][
            'hidden_size'] = self.encoder.output_size
        self.decoder = AttnDecoder.from_params(
            Params(configuration['model']['decoder'])).to(device)

        self.encoder_params = list(self.encoder.parameters())
        self.attn_params = list([
            v for k, v in self.decoder.named_parameters() if 'attention' in k
        ])
        self.decoder_params = list([
            v for k, v in self.decoder.named_parameters()
            if 'attention' not in k
        ])

        self.bsize = configuration['training']['bsize']

        weight_decay = configuration['training'].get('weight_decay', 1e-5)
        self.encoder_optim = torch.optim.Adam(self.encoder_params,
                                              lr=0.001,
                                              weight_decay=weight_decay,
                                              amsgrad=True)
        self.attn_optim = torch.optim.Adam(self.attn_params,
                                           lr=0.001,
                                           weight_decay=0,
                                           amsgrad=True)
        self.decoder_optim = torch.optim.Adam(self.decoder_params,
                                              lr=0.001,
                                              weight_decay=weight_decay,
                                              amsgrad=True)
        self.adversarymulti = AdversaryMulti(decoder=self.decoder)

        self.all_params = self.encoder_params + self.attn_params + self.decoder_params
        self.all_optim = torch.optim.Adam(self.all_params,
                                          lr=0.001,
                                          weight_decay=weight_decay,
                                          amsgrad=True)
        # self.all_optim = adagrad.Adagrad(self.all_params, weight_decay=weight_decay)

        pos_weight = configuration['training'].get('pos_weight', [1.0] *
                                                   self.decoder.output_size)
        self.pos_weight = torch.Tensor(pos_weight).to(device)
        self.criterion = nn.BCEWithLogitsLoss(reduction='none').to(device)
        self.swa_settings = configuration['training']['swa']

        import time
        dirname = configuration['training']['exp_dirname']
        basepath = configuration['training'].get('basepath', 'outputs')
        self.time_str = time.ctime().replace(' ', '_')
        self.dirname = os.path.join(basepath, dirname, self.time_str)

        self.temperature = configuration['training']['temperature']
        self.train_losses = []

        if self.swa_settings[0]:
            # self.attn_optim = SWA(self.attn_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            # self.decoder_optim = SWA(self.decoder_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            # self.encoder_optim = SWA(self.encoder_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            self.swa_all_optim = SWA(self.all_optim)
            self.running_norms = []
Example #19
def main(args, logger):
    # trn_df = pd.read_csv(f'{MNT_DIR}/inputs/origin/train.csv')
    trn_df = pd.read_pickle(f'{MNT_DIR}/inputs/nes_info/trn_df.pkl')
    trn_df['is_original'] = 1

    gkf = GroupKFold(n_splits=5).split(
        X=trn_df.question_body,
        groups=trn_df.question_body_le,
    )

    histories = {
        'trn_loss': {},
        'val_loss': {},
        'val_metric': {},
        'val_metric_raws': {},
    }
    loaded_fold = -1
    loaded_epoch = -1
    if args.checkpoint:
        histories, loaded_fold, loaded_epoch = load_checkpoint(args.checkpoint)

    fold_best_metrics = []
    fold_best_metrics_raws = []
    for fold, (trn_idx, val_idx) in enumerate(gkf):
        if fold < loaded_fold:
            fold_best_metrics.append(np.max(histories["val_metric"][fold]))
            fold_best_metrics_raws.append(
                histories["val_metric_raws"][fold][np.argmax(
                    histories["val_metric"][fold])])
            continue
        sel_log(
            f' --------------------------- start fold {fold} --------------------------- ',
            logger)
        fold_trn_df = trn_df.iloc[trn_idx]  # .query('is_original == 1')
        fold_trn_df = fold_trn_df.drop(['is_original', 'question_body_le'],
                                       axis=1)
        # use only original row
        fold_val_df = trn_df.iloc[val_idx].query('is_original == 1')
        fold_val_df = fold_val_df.drop(['is_original', 'question_body_le'],
                                       axis=1)
        if args.debug:
            fold_trn_df = fold_trn_df.sample(100, random_state=71)
            fold_val_df = fold_val_df.sample(100, random_state=71)
        temp = pd.Series(
            list(
                itertools.chain.from_iterable(
                    fold_trn_df.question_title.apply(lambda x: x.split(' ')) +
                    fold_trn_df.question_body.apply(lambda x: x.split(' ')) +
                    fold_trn_df.answer.apply(lambda x: x.split(' '))))
        ).value_counts()
        tokens = temp[temp >= 10].index.tolist()
        # tokens = []
        tokens = [
            'CAT_TECHNOLOGY'.casefold(),
            'CAT_STACKOVERFLOW'.casefold(),
            'CAT_CULTURE'.casefold(),
            'CAT_SCIENCE'.casefold(),
            'CAT_LIFE_ARTS'.casefold(),
        ]

        trn_dataset = QUESTDataset(
            df=fold_trn_df,
            mode='train',
            tokens=tokens,
            augment=[],
            tokenizer_type=TOKENIZER_TYPE,
            pretrained_model_name_or_path=TOKENIZER_PRETRAIN,
            do_lower_case=True,
            LABEL_COL=LABEL_COL,
            t_max_len=30,
            q_max_len=239 * 2,
            a_max_len=239 * 0,
            tqa_mode=TQA_MODE,
            TBSEP='[TBSEP]',
            pos_id_type='arange',
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
            rm_zero=RM_ZERO,
        )
        # update token
        trn_sampler = RandomSampler(data_source=trn_dataset)
        trn_loader = DataLoader(trn_dataset,
                                batch_size=BATCH_SIZE,
                                sampler=trn_sampler,
                                num_workers=os.cpu_count(),
                                worker_init_fn=lambda x: np.random.seed(),
                                drop_last=True,
                                pin_memory=True)
        val_dataset = QUESTDataset(
            df=fold_val_df,
            mode='valid',
            tokens=tokens,
            augment=[],
            tokenizer_type=TOKENIZER_TYPE,
            pretrained_model_name_or_path=TOKENIZER_PRETRAIN,
            do_lower_case=True,
            LABEL_COL=LABEL_COL,
            t_max_len=30,
            q_max_len=239 * 2,
            a_max_len=239 * 0,
            tqa_mode=TQA_MODE,
            TBSEP='[TBSEP]',
            pos_id_type='arange',
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
            rm_zero=RM_ZERO,
        )
        val_sampler = RandomSampler(data_source=val_dataset)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                sampler=val_sampler,
                                num_workers=os.cpu_count(),
                                worker_init_fn=lambda x: np.random.seed(),
                                drop_last=False,
                                pin_memory=True)

        fobj = BCEWithLogitsLoss()
        state_dict = BertModel.from_pretrained(MODEL_PRETRAIN).state_dict()
        model = BertModelForBinaryMultiLabelClassifier(
            num_labels=len(LABEL_COL),
            config_path=MODEL_CONFIG_PATH,
            state_dict=state_dict,
            token_size=len(trn_dataset.tokenizer),
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
        )
        # optimizer = optim.Adam(model.parameters(), lr=3e-5)
        optimizer = optim.SGD(model.parameters(), lr=1e-1)
        optimizer = SWA(optimizer, swa_start=2, swa_freq=5, swa_lr=1e-1)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=MAX_EPOCH,
                                                         eta_min=1e-2)

        # load checkpoint model, optim, scheduler
        if args.checkpoint and fold == loaded_fold:
            load_checkpoint(args.checkpoint, model, optimizer, scheduler)

        for epoch in tqdm(list(range(MAX_EPOCH))):
            if fold <= loaded_fold and epoch <= loaded_epoch:
                continue
            if epoch < 1:
                model.freeze_unfreeze_bert(freeze=True, logger=logger)
            else:
                model.freeze_unfreeze_bert(freeze=False, logger=logger)
            model = DataParallel(model)
            model = model.to(DEVICE)
            trn_loss = train_one_epoch(model, fobj, optimizer, trn_loader,
                                       DEVICE)
            if epoch > 2:
                optimizer.swap_swa_sgd()
                optimizer.bn_update(trn_loader, model)
            val_loss, val_metric, val_metric_raws, val_y_preds, val_y_trues, val_qa_ids = test(
                model, fobj, val_loader, DEVICE, mode='valid')
            if epoch > 2:
                optimizer.swap_swa_sgd()

            scheduler.step()
            if fold in histories['trn_loss']:
                histories['trn_loss'][fold].append(trn_loss)
            else:
                histories['trn_loss'][fold] = [
                    trn_loss,
                ]
            if fold in histories['val_loss']:
                histories['val_loss'][fold].append(val_loss)
            else:
                histories['val_loss'][fold] = [
                    val_loss,
                ]
            if fold in histories['val_metric']:
                histories['val_metric'][fold].append(val_metric)
            else:
                histories['val_metric'][fold] = [
                    val_metric,
                ]
            if fold in histories['val_metric_raws']:
                histories['val_metric_raws'][fold].append(val_metric_raws)
            else:
                histories['val_metric_raws'][fold] = [
                    val_metric_raws,
                ]

            logging_val_metric_raws = ''
            for val_metric_raw in val_metric_raws:
                logging_val_metric_raws += f'{float(val_metric_raw):.4f}, '

            sel_log(
                f'fold : {fold} -- epoch : {epoch} -- '
                f'trn_loss : {float(trn_loss.detach().to("cpu").numpy()):.4f} -- '
                f'val_loss : {float(val_loss.detach().to("cpu").numpy()):.4f} -- '
                f'val_metric : {float(val_metric):.4f} -- '
                f'val_metric_raws : {logging_val_metric_raws}', logger)
            model = model.to('cpu')
            model = model.module
            save_checkpoint(f'{MNT_DIR}/checkpoints/{EXP_ID}/{fold}', model,
                            optimizer, scheduler, histories, val_y_preds,
                            val_y_trues, val_qa_ids, fold, epoch, val_loss,
                            val_metric)
        fold_best_metrics.append(np.max(histories["val_metric"][fold]))
        fold_best_metrics_raws.append(
            histories["val_metric_raws"][fold][np.argmax(
                histories["val_metric"][fold])])
        save_and_clean_for_prediction(f'{MNT_DIR}/checkpoints/{EXP_ID}/{fold}',
                                      trn_dataset.tokenizer,
                                      clean=False)
        del model

    # calc training stats
    fold_best_metric_mean = np.mean(fold_best_metrics)
    fold_best_metric_std = np.std(fold_best_metrics)
    fold_stats = f'{EXP_ID} : {fold_best_metric_mean:.4f} +- {fold_best_metric_std:.4f}'
    sel_log(fold_stats, logger)
    send_line_notification(fold_stats)

    fold_best_metrics_raws_mean = np.mean(fold_best_metrics_raws, axis=0)
    fold_raw_stats = ''
    for metric_stats_raw in fold_best_metrics_raws_mean:
        fold_raw_stats += f'{float(metric_stats_raw):.4f},'
    sel_log(fold_raw_stats, logger)
    send_line_notification(fold_raw_stats)

    sel_log('now saving best checkpoints...', logger)
Example #20
def get_default_optimizer():
    base_opt = torch.optim.SGD(model.parameters(), momentum=.9, lr=1e-1)
    optimizer = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
    return optimizer
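Example #20 is essentially the recipe from the torchcontrib README: SGD wrapped in SWA with swa_start, swa_freq and swa_lr, averaged automatically at every step() once swa_start is reached. A hedged sketch of the training loop that typically surrounds it; the model, data and loss below are placeholders, not taken from the example:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchcontrib.optim import SWA

model = nn.Linear(20, 2)
loader = DataLoader(TensorDataset(torch.randn(256, 20),
                                  torch.randint(0, 2, (256,))),
                    batch_size=32)
criterion = nn.CrossEntropyLoss()

base_opt = torch.optim.SGD(model.parameters(), momentum=.9, lr=1e-1)
optimizer = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)

for epoch in range(20):
    for x, y in loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()       # averaging kicks in after swa_start steps

optimizer.swap_swa_sgd()       # load the averaged weights into the model
optimizer.bn_update(loader, model)  # refresh BatchNorm stats (a no-op here: no BN layers)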
Example #21
class Model():
    def __init__(self, configuration, pre_embed=None):
        configuration = deepcopy(configuration)
        self.configuration = deepcopy(configuration)

        configuration['model']['encoder']['pre_embed'] = pre_embed
        self.encoder = Encoder.from_params(
            Params(configuration['model']['encoder'])).to(device)

        configuration['model']['decoder'][
            'hidden_size'] = self.encoder.output_size
        self.decoder = AttnDecoder.from_params(
            Params(configuration['model']['decoder'])).to(device)

        self.encoder_params = list(self.encoder.parameters())
        self.attn_params = list([
            v for k, v in self.decoder.named_parameters() if 'attention' in k
        ])
        self.decoder_params = list([
            v for k, v in self.decoder.named_parameters()
            if 'attention' not in k
        ])

        self.bsize = configuration['training']['bsize']

        weight_decay = configuration['training'].get('weight_decay', 1e-5)
        self.encoder_optim = torch.optim.Adam(self.encoder_params,
                                              lr=0.001,
                                              weight_decay=weight_decay,
                                              amsgrad=True)
        self.attn_optim = torch.optim.Adam(self.attn_params,
                                           lr=0.001,
                                           weight_decay=0,
                                           amsgrad=True)
        self.decoder_optim = torch.optim.Adam(self.decoder_params,
                                              lr=0.001,
                                              weight_decay=weight_decay,
                                              amsgrad=True)
        self.adversarymulti = AdversaryMulti(decoder=self.decoder)

        self.all_params = self.encoder_params + self.attn_params + self.decoder_params
        self.all_optim = torch.optim.Adam(self.all_params,
                                          lr=0.001,
                                          weight_decay=weight_decay,
                                          amsgrad=True)
        # self.all_optim = adagrad.Adagrad(self.all_params, weight_decay=weight_decay)

        pos_weight = configuration['training'].get('pos_weight', [1.0] *
                                                   self.decoder.output_size)
        self.pos_weight = torch.Tensor(pos_weight).to(device)
        self.criterion = nn.BCEWithLogitsLoss(reduction='none').to(device)
        self.swa_settings = configuration['training']['swa']

        import time
        dirname = configuration['training']['exp_dirname']
        basepath = configuration['training'].get('basepath', 'outputs')
        self.time_str = time.ctime().replace(' ', '_')
        self.dirname = os.path.join(basepath, dirname, self.time_str)

        self.temperature = configuration['training']['temperature']
        self.train_losses = []

        if self.swa_settings[0]:
            # self.attn_optim = SWA(self.attn_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            # self.decoder_optim = SWA(self.decoder_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            # self.encoder_optim = SWA(self.encoder_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            self.swa_all_optim = SWA(self.all_optim)
            self.running_norms = []

    @classmethod
    def init_from_config(cls, dirname, **kwargs):
        config = json.load(open(dirname + '/config.json', 'r'))
        config.update(kwargs)
        obj = cls(config)
        obj.load_values(dirname)
        return obj

    def get_param_buffer_norms(self):
        for p in self.swa_all_optim.param_groups[0]['params']:
            param_state = self.swa_all_optim.state[p]
            if 'swa_buffer' not in param_state:
                self.swa_all_optim.update_swa()

        norms = []
        # for p in np.array(self.swa_all_optim.param_groups[0]['params'])[[1, 2, 5, 6, 9]]:
        for p in np.array(self.swa_all_optim.param_groups[0]['params'])[[6,
                                                                         9]]:
            param_state = self.swa_all_optim.state[p]
            buf = np.squeeze(param_state['swa_buffer'].cpu().numpy())
            cur_state = np.squeeze(p.data.cpu().numpy())
            norm = np.linalg.norm(buf - cur_state)
            norms.append(norm)
        if self.swa_settings[3] == 2:
            return np.max(norms)
        return np.mean(norms)

    def total_iter_num(self):
        return self.swa_all_optim.param_groups[0]['step_counter']

    def iter_for_swa_update(self, iter_num):
        return iter_num > self.swa_settings[1] \
               and iter_num % self.swa_settings[2] == 0

    def check_and_update_swa(self):
        if self.iter_for_swa_update(self.total_iter_num()):
            cur_step_diff_norm = self.get_param_buffer_norms()
            if self.swa_settings[3] == 0:
                self.swa_all_optim.update_swa()
                return
            if not self.running_norms:
                running_mean_norm = 0
            else:
                running_mean_norm = np.mean(self.running_norms)

            if cur_step_diff_norm > running_mean_norm:
                self.swa_all_optim.update_swa()
                self.running_norms = [cur_step_diff_norm]
            elif cur_step_diff_norm > 0:
                self.running_norms.append(cur_step_diff_norm)

    def train(self, data_in, target_in, train=True):
        sorting_idx = get_sorting_index_with_noise_from_lengths(
            [len(x) for x in data_in], noise_frac=0.1)
        data = [data_in[i] for i in sorting_idx]
        target = [target_in[i] for i in sorting_idx]

        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)
        loss_total = 0

        batches = list(range(0, N, bsize))
        batches = shuffle(batches)

        for n in tqdm(batches):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_target = target[n:n + bsize]
            batch_target = torch.Tensor(batch_target).to(device)

            if len(batch_target.shape) == 1:  #(B, )
                batch_target = batch_target.unsqueeze(-1)  #(B, 1)

            bce_loss = self.criterion(batch_data.predict / self.temperature,
                                      batch_target)
            weight = batch_target * self.pos_weight + (1 - batch_target)
            bce_loss = (bce_loss * weight).mean(1).sum()

            loss = bce_loss
            self.train_losses.append(bce_loss.detach().cpu().numpy() + 0)

            if hasattr(batch_data, 'reg_loss'):
                loss += batch_data.reg_loss

            if train:
                if self.swa_settings[0]:
                    self.check_and_update_swa()

                    self.swa_all_optim.zero_grad()
                    loss.backward()
                    self.swa_all_optim.step()

                else:
                    # self.encoder_optim.zero_grad()
                    # self.decoder_optim.zero_grad()
                    # self.attn_optim.zero_grad()
                    self.all_optim.zero_grad()
                    loss.backward()
                    # self.encoder_optim.step()
                    # self.decoder_optim.step()
                    # self.attn_optim.step()
                    self.all_optim.step()

            loss_total += float(loss.data.cpu().item())
        if self.swa_settings[0] and self.swa_all_optim.param_groups[0][
                'step_counter'] > self.swa_settings[1]:
            print("\nSWA swapping\n")
            # self.attn_optim.swap_swa_sgd()
            # self.encoder_optim.swap_swa_sgd()
            # self.decoder_optim.swap_swa_sgd()
            self.swa_all_optim.swap_swa_sgd()
            self.running_norms = []

        return loss_total * bsize / N

    def predictor(self, inp_text_permutations):

        text_permutations = [
            dataset_vec.map2idxs(x.split()) for x in inp_text_permutations
        ]
        outputs = []
        bsize = 512
        N = len(text_permutations)
        for n in range(0, N, bsize):
            torch.cuda.empty_cache()
            batch_doc = text_permutations[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_data.predict = torch.sigmoid(batch_data.predict)

            pred = batch_data.predict.cpu().data.numpy()
            for i in range(len(pred)):
                if math.isnan(pred[i][0]):
                    pred[i][0] = 0.5
            outputs.extend(pred)

        ret_val = [[output_i[0], 1 - output_i[0]] for output_i in outputs]
        ret_val = np.array(ret_val)

        return ret_val

    def evaluate(self, data):
        self.encoder.eval()
        self.decoder.eval()
        bsize = self.bsize
        N = len(data)

        outputs = []
        attns = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_data.predict = torch.sigmoid(batch_data.predict /
                                               self.temperature)
            if self.decoder.use_attention:
                attn = batch_data.attn.cpu().data.numpy()
                attns.append(attn)

            predict = batch_data.predict.cpu().data.numpy()
            outputs.append(predict)

        outputs = [x for y in outputs for x in y]
        if self.decoder.use_attention:
            attns = [x for y in attns for x in y]
        return outputs, attns

    def get_lime_explanations(self, data):
        explanations = []
        explainer = LimeTextExplainer(class_names=["A", "B"])
        for data_i in data:
            sentence = ' '.join(dataset_vec.map2words(data_i))
            exp = explainer.explain_instance(text_instance=sentence,
                                             classifier_fn=self.predictor,
                                             num_features=len(data_i),
                                             num_samples=5000).as_list()
            explanations.append(exp)
        return explanations

    def gradient_mem(self, data):
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)

        grads = {'XxE': [], 'XxE[X]': [], 'H': []}

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]

            grads_xxe = []
            grads_xxex = []
            grads_H = []

            for i in range(self.decoder.output_size):
                batch_data = BatchHolder(batch_doc)
                batch_data.keep_grads = True
                batch_data.detach = True

                self.encoder(batch_data)
                self.decoder(batch_data)

                torch.sigmoid(batch_data.predict[:, i]).sum().backward()
                g = batch_data.embedding.grad
                em = batch_data.embedding
                g1 = (g * em).sum(-1)

                grads_xxex.append(g1.cpu().data.numpy())

                g1 = (g * self.encoder.embedding.weight.sum(0)).sum(-1)
                grads_xxe.append(g1.cpu().data.numpy())

                g1 = batch_data.hidden.grad.sum(-1)
                grads_H.append(g1.cpu().data.numpy())

            grads_xxe = np.array(grads_xxe).swapaxes(0, 1)
            grads_xxex = np.array(grads_xxex).swapaxes(0, 1)
            grads_H = np.array(grads_H).swapaxes(0, 1)

            grads['XxE'].append(grads_xxe)
            grads['XxE[X]'].append(grads_xxex)
            grads['H'].append(grads_H)

        for k in grads:
            grads[k] = [x for y in grads[k] for x in y]

        return grads

    def remove_and_run(self, data):
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)

        outputs = []

        for n in tqdm(range(0, N, bsize)):
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)
            po = np.zeros(
                (batch_data.B, batch_data.maxlen, self.decoder.output_size))

            for i in range(1, batch_data.maxlen - 1):
                batch_data = BatchHolder(batch_doc)

                batch_data.seq = torch.cat(
                    [batch_data.seq[:, :i], batch_data.seq[:, i + 1:]], dim=-1)
                batch_data.lengths = batch_data.lengths - 1
                batch_data.masks = torch.cat(
                    [batch_data.masks[:, :i], batch_data.masks[:, i + 1:]],
                    dim=-1)

                self.encoder(batch_data)
                self.decoder(batch_data)

                po[:, i] = torch.sigmoid(batch_data.predict).cpu().data.numpy()

            outputs.append(po)

        outputs = [x for y in outputs for x in y]

        return outputs

    def permute_attn(self, data, num_perm=100):
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)

        permutations = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            batch_perms = np.zeros(
                (batch_data.B, num_perm, self.decoder.output_size))

            self.encoder(batch_data)
            self.decoder(batch_data)

            for i in range(num_perm):
                batch_data.permute = True
                self.decoder(batch_data)
                output = torch.sigmoid(batch_data.predict)
                batch_perms[:, i] = output.cpu().data.numpy()

            permutations.append(batch_perms)

        permutations = [x for y in permutations for x in y]

        return permutations

    def save_values(self,
                    use_dirname=None,
                    save_model=True,
                    append_to_dir_name=''):
        if use_dirname is not None:
            dirname = use_dirname
        else:
            dirname = self.dirname + append_to_dir_name
            self.last_epch_dirname = dirname
        os.makedirs(dirname, exist_ok=True)
        shutil.copy2(file_name, dirname + '/')
        json.dump(self.configuration, open(dirname + '/config.json', 'w'))

        if save_model:
            torch.save(self.encoder.state_dict(), dirname + '/enc.th')
            torch.save(self.decoder.state_dict(), dirname + '/dec.th')

        return dirname

    def load_values(self, dirname):
        self.encoder.load_state_dict(
            torch.load(dirname + '/enc.th', map_location={'cuda:1': 'cuda:0'}))
        self.decoder.load_state_dict(
            torch.load(dirname + '/dec.th', map_location={'cuda:1': 'cuda:0'}))

    def adversarial_multi(self, data):
        self.encoder.eval()
        self.decoder.eval()

        for p in self.encoder.parameters():
            p.requires_grad = False

        for p in self.decoder.parameters():
            p.requires_grad = False

        bsize = self.bsize
        N = len(data)

        adverse_attn = []
        adverse_output = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            self.adversarymulti(batch_data)

            attn_volatile = batch_data.attn_volatile.cpu().data.numpy(
            )  #(B, 10, L)
            predict_volatile = batch_data.predict_volatile.cpu().data.numpy(
            )  #(B, 10, O)

            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]

        return adverse_output, adverse_attn

    def logodds_attention(self, data, logodds_map: Dict):
        self.encoder.eval()
        self.decoder.eval()

        bsize = self.bsize
        N = len(data)

        adverse_attn = []
        adverse_output = []

        logodds = np.zeros((self.encoder.vocab_size, ))
        for k, v in logodds_map.items():
            if v is not None:
                logodds[k] = abs(v)
            else:
                logodds[k] = float('-inf')
        logodds = torch.Tensor(logodds).to(device)

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            attn = batch_data.attn  #(B, L)
            batch_data.attn_logodds = logodds[batch_data.seq]
            self.decoder.get_output_from_logodds(batch_data)

            attn_volatile = batch_data.attn_volatile.cpu().data.numpy(
            )  #(B, L)
            predict_volatile = torch.sigmoid(
                batch_data.predict_volatile).cpu().data.numpy()  #(B, O)

            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]

        return adverse_output, adverse_attn

    def logodds_substitution(self, data, top_logodds_words: Dict):
        self.encoder.eval()
        self.decoder.eval()

        bsize = self.bsize
        N = len(data)

        adverse_X = []
        adverse_attn = []
        adverse_output = []

        words_neg = torch.Tensor(
            top_logodds_words[0][0]).long().cuda().unsqueeze(0)
        words_pos = torch.Tensor(
            top_logodds_words[0][1]).long().cuda().unsqueeze(0)

        words_to_select = torch.cat([words_neg, words_pos], dim=0)  #(2, 5)

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)
            predict_class = (torch.sigmoid(batch_data.predict).squeeze(-1) >
                             0.5) * 1  #(B,)

            attn = batch_data.attn  #(B, L)
            top_val, top_idx = torch.topk(attn, 5, dim=-1)
            subs_words = words_to_select[1 - predict_class.long()]  #(B, 5)

            batch_data.seq.scatter_(1, top_idx, subs_words)

            self.encoder(batch_data)
            self.decoder(batch_data)

            attn_volatile = batch_data.attn.cpu().data.numpy()  #(B, L)
            predict_volatile = torch.sigmoid(
                batch_data.predict).cpu().data.numpy()  #(B, O)
            X_volatile = batch_data.seq.cpu().data.numpy()

            adverse_X.append(X_volatile)
            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_X = [x for y in adverse_X for x in y]
        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]

        return adverse_output, adverse_attn, adverse_X

    def predict(self, batch_data, lengths, masks):
        batch_holder = BatchHolderIndentity(batch_data, lengths, masks)
        self.encoder(batch_holder)
        self.decoder(batch_holder)
        # batch_holder.predict = torch.sigmoid(batch_holder.predict)
        predict = batch_holder.predict
        return predict
Example #22
class HM:
    def __init__(self):

        if args.train is not None:
            self.train_tuple = get_tuple(args.train,
                                         bs=args.batch_size,
                                         shuffle=True,
                                         drop_last=False)

        if args.valid is not None:
            valid_bsize = 2048 if args.multiGPU else 50
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        # Select Model, X is default
        if args.model == "X":
            self.model = ModelX(args)
        elif args.model == "V":
            self.model = ModelV(args)
        elif args.model == "U":
            self.model = ModelU(args)
        elif args.model == "D":
            self.model = ModelD(args)
        elif args.model == 'O':
            self.model = ModelO(args)
        else:
            print(args.model, " is not implemented.")

        # Load pre-trained weights from paths
        if args.loadpre is not None:
            self.model.load(args.loadpre)

        # GPU options
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        self.model = self.model.cuda()

        # Losses and optimizer
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.nllloss = nn.NLLLoss()

        if args.train is not None:
            batch_per_epoch = len(self.train_tuple.loader)
            self.t_total = int(batch_per_epoch * args.epochs // args.acc)
            print("Total Iters: %d" % self.t_total)

        def is_backbone(n):
            if "encoder" in n:
                return True
            elif "embeddings" in n:
                return True
            elif "pooler" in n:
                return True
            print("F: ", n)
            return False

        no_decay = ['bias', 'LayerNorm.weight']

        params = list(self.model.named_parameters())
        if args.reg:
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in params if is_backbone(n)],
                    "lr": args.lr
                },
                {
                    "params": [p for n, p in params if not is_backbone(n)],
                    "lr": args.lr * 500
                },
            ]

            for n, p in self.model.named_parameters():
                print(n)

            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)
        else:
            optimizer_grouped_parameters = [{
                'params':
                [p for n, p in params if not any(nd in n for nd in no_decay)],
                'weight_decay':
                args.wd
            }, {
                'params':
                [p for n, p in params if any(nd in n for nd in no_decay)],
                'weight_decay':
                0.0
            }]

            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)

        if args.train is not None:
            self.scheduler = get_linear_schedule_with_warmup(
                self.optim, self.t_total * 0.1, self.t_total)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

        # SWA Method:
        if args.contrib:
            self.optim = SWA(self.optim,
                             swa_start=self.t_total * 0.75,
                             swa_freq=5,
                             swa_lr=args.lr)

        if args.swa:
            self.swa_model = AveragedModel(self.model)
            self.swa_start = self.t_total * 0.75
            self.swa_scheduler = SWALR(self.optim, swa_lr=args.lr)

    def train(self, train_tuple, eval_tuple):

        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        print("Batches:", len(loader))

        self.optim.zero_grad()

        best_roc = 0.
        ups = 0

        total_loss = 0.

        for epoch in range(args.epochs):

            if args.reg:
                if args.model != "X":
                    print(self.model.model.layer_weights)

            id2ans = {}
            id2prob = {}

            for i, (ids, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):

                if ups == args.midsave:
                    self.save("MID")

                self.model.train()

                if args.swa:
                    self.swa_model.train()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.long(
                ).cuda()

                # Model expects visual feats as tuple of feats & boxes
                logit = self.model(sent, (feats, boxes))

                # Note: LogSoftmax does not change the ordering, so there is nothing wrong with taking it as our prediction.
                # In fact ROC AUC stays exactly the same for logsoftmax / normal softmax, but logsoftmax is better for the loss calculation
                # due to stronger penalization & simpler arithmetic (log(a/b) = log(a) - log(b))
                logit = self.logsoftmax(logit)
                score = logit[:, 1]

                if i < 1:
                    print(logit[0, :].detach())

                # Note: This loss is the same as CrossEntropy (we split it up into logsoftmax & negative log-likelihood loss)
                loss = self.nllloss(logit.view(-1, 2), target.view(-1))

                # Scale the loss by batch size (batches differ in size since we do not "drop_last") and divide by acc for gradient accumulation.
                # Not scaling the loss worsens performance by ~2 abs%.
                loss = loss * logit.size(0) / args.acc
                loss.backward()

                total_loss += loss.detach().item()

                # Acts as argmax - extracting the higher score & the corresponding index (0 or 1)
                _, predict = logit.detach().max(1)
                # Getting labels for accuracy
                for qid, l in zip(ids, predict.cpu().numpy()):
                    id2ans[qid] = l
                # Getting probabilities for Roc auc
                for qid, l in zip(ids, score.detach().cpu().numpy()):
                    id2prob[qid] = l

                if (i + 1) % args.acc == 0:

                    nn.utils.clip_grad_norm_(self.model.parameters(),
                                             args.clip)

                    self.optim.step()

                    if (args.swa) and (ups > self.swa_start):
                        self.swa_model.update_parameters(self.model)
                        self.swa_scheduler.step()
                    else:
                        self.scheduler.step()
                    self.optim.zero_grad()

                    ups += 1

                    # Do Validation in between
                    if ups % 250 == 0:

                        log_str = "\nEpoch(U) %d(%d): Train AC %0.2f RA %0.4f LOSS %0.4f\n" % (
                            epoch, ups, evaluator.evaluate(id2ans) * 100,
                            evaluator.roc_auc(id2prob) * 100, total_loss)

                        # Set loss back to 0 after printing it
                        total_loss = 0.

                        if self.valid_tuple is not None:  # Do Validation
                            acc, roc_auc = self.evaluate(eval_tuple)
                            if roc_auc > best_roc:
                                best_roc = roc_auc
                                best_acc = acc
                                # Only save BEST when no midsave is specified to save space
                                #if args.midsave < 0:
                                #    self.save("BEST")

                            log_str += "\nEpoch(U) %d(%d): DEV AC %0.2f RA %0.4f \n" % (
                                epoch, ups, acc * 100., roc_auc * 100)
                            log_str += "Epoch(U) %d(%d): BEST AC %0.2f RA %0.4f \n" % (
                                epoch, ups, best_acc * 100., best_roc * 100.)

                        print(log_str, end='')

                        with open(self.output + "/log.log", 'a') as f:
                            f.write(log_str)
                            f.flush()

        if (epoch + 1) == args.epochs:
            if args.contrib:
                self.optim.swap_swa_sgd()

        self.save("LAST" + args.train)

    def predict(self, eval_tuple: DataTuple, dump=None, out_csv=True):

        dset, loader, evaluator = eval_tuple
        id2ans = {}
        id2prob = {}

        for i, datum_tuple in enumerate(loader):

            ids, feats, boxes, sent = datum_tuple[:4]

            self.model.eval()

            if args.swa:
                self.swa_model.eval()

            with torch.no_grad():

                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(sent, (feats, boxes))

                # Note: LogSoftmax does not change order, hence there should be nothing wrong with taking it as our prediction
                logit = self.logsoftmax(logit)
                score = logit[:, 1]

                if args.swa:
                    logit = self.swa_model(sent, (feats, boxes))
                    logit = self.logsoftmax(logit)
                    # Use the SWA logits for the probability score as well
                    score = logit[:, 1]

                _, predict = logit.max(1)

                for qid, l in zip(ids, predict.cpu().numpy()):
                    id2ans[qid] = l

                # Getting probas for Roc Auc
                for qid, l in zip(ids, score.cpu().numpy()):
                    id2prob[qid] = l

        if dump is not None:
            if out_csv:
                evaluator.dump_csv(id2ans, id2prob, dump)
            else:
                evaluator.dump_result(id2ans, dump)

        return id2ans, id2prob

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        id2ans, id2prob = self.predict(eval_tuple, dump=dump)

        acc = eval_tuple.evaluator.evaluate(id2ans)
        roc_auc = eval_tuple.evaluator.roc_auc(id2prob)

        return acc, roc_auc

    def save(self, name):
        if args.swa:
            torch.save(self.swa_model.state_dict(),
                       os.path.join(self.output, "%s.pth" % name))
        else:
            torch.save(self.model.state_dict(),
                       os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)

        state_dict = torch.load(path)
        new_state_dict = {}
        for key, value in state_dict.items():
            # n_averaged is a buffer added by AveragedModel that the plain model cannot load, so skip it
            if key.startswith("n_averaged"):
                print("n_averaged:", value)
                continue
            # Keys of SWA models are prefixed with "module.", so strip that prefix
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
            else:
                new_state_dict[key] = value
        state_dict = new_state_dict
        self.model.load_state_dict(state_dict)
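The constructor above wires up two different SWA flavours: torchcontrib's SWA optimizer wrapper behind args.contrib, and the built-in torch.optim.swa_utils pair AveragedModel + SWALR behind args.swa, whose "module."-prefixed keys and n_averaged buffer the load() method above strips again. Below is a minimal, self-contained sketch of that second recipe, assuming a toy model, synthetic data and illustrative hyper-parameters (none of them taken from the example); the update_bn call recomputes BatchNorm statistics for the averaged weights before they are saved or evaluated.

import torch
import torch.nn as nn
from torch.optim.swa_utils import SWALR, AveragedModel, update_bn
from torch.utils.data import DataLoader, TensorDataset

# Toy model and synthetic data (placeholders, not taken from the example above)
model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))
loader = DataLoader(TensorDataset(torch.randn(128, 16), torch.randint(0, 2, (128,))),
                    batch_size=32)
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
swa_model = AveragedModel(model)                # keeps the running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.005)
swa_start = 15                                  # epoch after which averaging begins (illustrative)

for epoch in range(20):
    for x, y in loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)      # fold the current weights into the average
        swa_scheduler.step()                    # anneal towards, then hold, the SWA learning rate
    else:
        scheduler.step()

# Recompute BatchNorm statistics for the averaged weights, then save them; the state dict
# carries a "module." prefix and an n_averaged buffer, cf. the load() method above.
update_bn(loader, swa_model)
torch.save(swa_model.state_dict(), "swa_model.pth")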
Exemple #23
0
def main(args, dst_folder):
    # best_ac only records the best top-1 accuracy on the validation set.
    best_ac = 0.0
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.backends.cudnn.deterministic = True  # make cuDNN deterministic
    torch.manual_seed(args.seed)  # CPU seed
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)  # GPU seed

    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)

    if args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.RandomCrop(32),
            #transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.ColorJitter(brightness= 0.4, contrast= 0.4, saturation= 0.4, hue= 0.1),
            transforms.RandomCrop(32),
            #SVHNPolicy(),
            #AutoAugment(),
            #transforms.RandomHorizontalFlip(),
            
            transforms.ToTensor(),
            #Cutout(n_holes=1,length=20),
            transforms.Normalize(mean, std),
        ])
    else:
        raise ValueError("Wrong value for --DA argument.")


    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, train_noisy_indexes = data_config(args, transform_train, transform_test,  dst_folder)


    if args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes = args.num_classes, dropRatio = args.dropout).to(device)

    elif args.network == "WRN28_2_wn":
        print("Loading WRN28_2...")
        model = WRN28_2_wn(num_classes = args.num_classes, dropout = args.dropout).to(device)

    elif args.network == "PreactResNet18_WNdrop":
        print("Loading preActResNet18_WNdrop...")
        model = PreactResNet18_WNdrop(drop_val = args.dropout, num_classes = args.num_classes).to(device)


    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    milestones = args.M

    if args.swa == 'True':
        # to install it:
        # pip3 install torchcontrib
        # git clone https://github.com/pytorch/contrib.git
        # cd contrib
        # sudo python3 setup.py install
        from torchcontrib.optim import SWA
        #base_optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)

    else:
        #optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)



    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []
    new_labels = []


    exp_path = os.path.join('./', 'noise_models_{0}'.format(args.experiment_name), str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name), str(args.labeled_samples))

    if not os.path.isdir(res_path):
        os.makedirs(res_path)

    if not os.path.isdir(exp_path):
        os.makedirs(exp_path)

    cont = 0

    load = False
    save = True

    if args.initial_epoch != 0:
        initial_epoch = args.initial_epoch
        load = True
        save = False

    if args.dataset_type == 'sym_noise_warmUp':
        load = False
        save = True

    if load:
        if args.loss_term == 'Reg_ep':
            train_type = 'C'
        if args.loss_term == 'MixUp_ep':
            train_type = 'M'
        if args.dropout > 0.0:
            train_type = train_type + 'drop' + str(int(10*args.dropout))
        if args.beta == 0.0:
            train_type = train_type + 'noReg'
        path = './checkpoints/warmUp_{6}_{5}_{0}_{1}_{2}_{3}_S{4}.hdf5'.format(initial_epoch, \
                                                                                args.dataset, \
                                                                                args.labeled_samples, \
                                                                                args.network, \
                                                                                args.seed, \
                                                                                args.Mixup_Alpha, \
                                                                                train_type)

        checkpoint = torch.load(path)
        print("Load model in epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])
        print("Relabeling the unlabeled samples...")
        model.eval()
        initial_rand_relab = args.label_noise
        results = np.zeros((len(train_loader.dataset), args.num_classes), dtype=np.float32)

        for images, images_pslab, labels, soft_labels, index in train_loader:

            images = images.to(device)
            labels = labels.to(device)
            soft_labels = soft_labels.to(device)

            outputs = model(images)
            prob, loss = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
            results[index.detach().numpy().tolist()] = prob.cpu().detach().numpy().tolist()

        train_loader.dataset.update_labels_randRelab(results, train_noisy_indexes, initial_rand_relab)
        print("Start training...")

    for epoch in range(1, args.epoch + 1):
        st = time.time()
        scheduler.step()
        # train for one epoch
        print(args.experiment_name, args.labeled_samples)

        loss_per_epoch, top_5_train_ac, top1_train_acc_original_labels, \
        top1_train_ac, train_time = train_CrossEntropy_partialRelab(\
                                                        args, model, device, \
                                                        train_loader, optimizer, \
                                                        epoch, train_noisy_indexes)


        loss_train_epoch += [loss_per_epoch]

        # test
        if args.validation_exp == "True":
            loss_per_epoch, acc_val_per_epoch_i = validating(args, model, device, test_loader)
        else:
            loss_per_epoch, acc_val_per_epoch_i = testing(args, model, device, test_loader)

        loss_val_epoch += loss_per_epoch
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i



        ####################################################################################################
        #############################               SAVING MODELS                ###########################
        ####################################################################################################

        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        if epoch == 1:
            best_acc_val = acc_val_per_epoch_i[-1]
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
        else:
            if acc_val_per_epoch_i[-1] > best_acc_val:
                best_acc_val = acc_val_per_epoch_i[-1]

                if cont > 0:
                    try:
                        os.remove(os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
                        os.remove(os.path.join(exp_path, snapBest + '.pth'))
                    except OSError:
                        pass
                snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                    epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
                torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
                torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))

        cont += 1

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))


        #### Save models for ensembles:
        if (epoch >= 150) and (epoch%2 == 0) and (args.save_checkpoint == "True"):
            print("Saving model ...")
            out_path = './checkpoints/ENS_{0}_{1}'.format(args.experiment_name, args.labeled_samples)
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            torch.save(model.state_dict(), out_path + "/epoch_{0}.pth".format(epoch))

        ### Saving model to load it again
        # cond = epoch%1 == 0
        if args.dataset_type == 'sym_noise_warmUp':
            if args.loss_term == 'Reg_ep':
                train_type = 'C'
            if args.loss_term == 'MixUp_ep':
                train_type = 'M'
            if args.dropout > 0.0:
                train_type = train_type + 'drop' + str(int(10*args.dropout))
            if args.beta == 0.0:
                train_type = train_type + 'noReg'


            cond = (epoch==args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True
        else:
            cond = (epoch==args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True


        if cond and save:
            print("Saving models...")
            path = './checkpoints/{0}_{1}_{2}_{3}_{4}_S{5}.hdf5'.format(name, epoch, args.dataset, args.labeled_samples, args.network, args.seed)

            save_checkpoint({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer' : optimizer.state_dict(),
                    'loss_train_epoch' : np.asarray(loss_train_epoch),
                    'loss_val_epoch' : np.asarray(loss_val_epoch),
                    'acc_train_per_epoch' : np.asarray(acc_train_per_epoch),
                    'acc_val_per_epoch' : np.asarray(acc_val_per_epoch),
                    'labels': np.asarray(train_loader.dataset.soft_labels)
                }, filename = path)



        ####################################################################################################
        ############################               SAVING METRICS                ###########################
        ####################################################################################################



        # Save losses:
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_train.npy', np.asarray(loss_train_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_val.npy', np.asarray(loss_val_epoch))

        # save accuracies:
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_train.npy',
                np.asarray(acc_train_per_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_val.npy', np.asarray(acc_val_per_epoch))

        # save the new labels
        new_labels.append(train_loader.dataset.labels)
        np.save(res_path + '/' + str(args.labeled_samples) + '_new_labels.npy',
                np.asarray(new_labels))

        #logging.info('Epoch: [{}|{}], train_loss: {:.3f}, top1_train_ac: {:.3f}, top1_val_ac: {:.3f}, train_time: {:.3f}'.format(epoch, args.epoch, loss_per_epoch[-1], top1_train_ac, acc_val_per_epoch_i[-1], time.time() - st))

    # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        if args.validation_exp == "True":
            loss_swa, acc_val_swa = validating(args, model, device, test_loader)
        else:
            loss_swa, acc_val_swa = testing(args, model, device, test_loader)

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val, acc_val_swa[0])
        torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

    # save_fig(dst_folder)
    print('Best ac:%f' % best_acc_val)
    record_result(dst_folder, best_acc_val)
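main() above creates the torchcontrib SWA wrapper in manual mode (no swa_start/swa_freq), so the running average is presumably accumulated elsewhere, e.g. inside train_CrossEntropy_partialRelab via update_swa(); only after the last epoch are swap_swa_sgd() and bn_update() called to move the averaged weights and matching BatchNorm statistics into the model before the final test pass. The sketch below shows the same wrapper in its automatic mode, with a toy model, synthetic data and illustrative hyper-parameters only.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchcontrib.optim import SWA

# Toy model and synthetic data (illustrative placeholders)
model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))
loader = DataLoader(TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,))),
                    batch_size=32)
criterion = nn.CrossEntropyLoss()

base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
# Automatic mode: start averaging after 10 optimisation steps, then snapshot every 5 steps at lr 0.05
optimizer = SWA(base_optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)

for epoch in range(5):
    for x, y in loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()

optimizer.swap_swa_sgd()            # copy the averaged weights into the model
optimizer.bn_update(loader, model)  # recompute BatchNorm running statistics for those weights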
def train(model_name, optim='adam'):
    train_dataset = PretrainDataset(output_shape=config['image_resolution'])
    train_loader = DataLoader(train_dataset,
                              batch_size=config['batch_size'],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)

    val_dataset = IDRND_dataset_CV(fold=0,
                                   mode=config['mode'].replace('train', 'val'),
                                   double_loss_mode=True,
                                   output_shape=config['image_resolution'])
    val_loader = DataLoader(val_dataset,
                            batch_size=config['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=False)

    if model_name == 'EF':
        model = DoubleLossModelTwoHead(base_model=EfficientNet.from_pretrained(
            'efficientnet-b3')).to(device)
        model.load_state_dict(
            torch.load(
                f"../models_weights/pretrained/{model_name}_{4}_2.0090592697255896_1.0.pth"
            ))
    elif model_name == 'EFGAP':
        model = DoubleLossModelTwoHead(
            base_model=EfficientNetGAP.from_pretrained('efficientnet-b3')).to(
                device)
        model.load_state_dict(
            torch.load(
                f"../models_weights/pretrained/{model_name}_{4}_2.3281182915644134_1.0.pth"
            ))

    criterion = FocalLoss(add_weight=False).to(device)
    criterion4class = CrossEntropyLoss().to(device)

    if optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config['learning_rate'],
                                     weight_decay=config['weight_decay'])
    elif optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'],
                                    nesterov=False)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    momentum=0.9,
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'],
                                    nesterov=True)

    steps_per_epoch = len(train_loader) - 15
    swa = SWA(optimizer,
              swa_start=config['swa_start'] * steps_per_epoch,
              swa_freq=int(config['swa_freq'] * steps_per_epoch),
              swa_lr=config['learning_rate'] / 10)
    scheduler = ExponentialLR(swa, gamma=0.9)
    # scheduler = StepLR(swa, step_size=5*steps_per_epoch, gamma=0.5)

    global_step = 0
    for epoch in trange(10):
        if epoch < 5:
            scheduler.step()
            continue
        model.train()
        train_bar = tqdm(train_loader)
        train_bar.set_description_str(desc=f"N epochs - {epoch}")

        for step, batch in enumerate(train_bar):
            global_step += 1
            image = batch['image'].to(device)
            label4class = batch['label0'].to(device)
            label = batch['label1'].to(device)

            output4class, output = model(image)
            loss4class = criterion4class(output4class, label4class)
            loss = criterion(output.squeeze(), label)
            swa.zero_grad()
            total_loss = loss4class * 0.5 + loss * 0.5
            total_loss.backward()
            swa.step()
            train_writer.add_scalar(tag="learning_rate",
                                    scalar_value=scheduler.get_lr()[0],
                                    global_step=global_step)
            train_writer.add_scalar(tag="BinaryLoss",
                                    scalar_value=loss.item(),
                                    global_step=global_step)
            train_writer.add_scalar(tag="SoftMaxLoss",
                                    scalar_value=loss4class.item(),
                                    global_step=global_step)
            train_bar.set_postfix_str(f"Loss = {loss.item()}")
            try:
                train_writer.add_scalar(tag="idrnd_score",
                                        scalar_value=idrnd_score_pytorch(
                                            label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="far_score",
                                        scalar_value=far_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="frr_score",
                                        scalar_value=frr_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="accuracy",
                                        scalar_value=bce_accuracy(
                                            label, output),
                                        global_step=global_step)
            except Exception:
                pass

        if (epoch > config['swa_start']
                and epoch % 2 == 0) or (epoch == config['number_epochs'] - 1):
            swa.swap_swa_sgd()
            swa.bn_update(train_loader, model, device)
            swa.swap_swa_sgd()

        scheduler.step()
        evaluate(model, val_loader, epoch, model_name)
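train() above probes the averaged weights during training by swapping them in (swap_swa_sgd), refreshing the BatchNorm statistics (bn_update) and swapping straight back, so that optimisation continues from the ordinary SGD weights. A compact sketch of that idea, here also scoring the averaged model between the two swaps; evaluate_fn and the loaders are placeholder names, and the torchcontrib SWA wrapper is assumed.

def evaluate_swa(swa_optimizer, model, train_loader, val_loader, evaluate_fn, device=None):
    """Score the current SWA average without disturbing the ongoing SGD run (torchcontrib SWA)."""
    swa_optimizer.swap_swa_sgd()                          # SWA average -> model parameters
    swa_optimizer.bn_update(train_loader, model, device)  # BN statistics for the averaged weights
    metrics = evaluate_fn(model, val_loader)              # evaluate the averaged model
    swa_optimizer.swap_swa_sgd()                          # restore the SGD weights and keep training
    return metrics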
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')
    parser.add_argument('--swa', action='store_true')
    parser.add_argument('--show', action='store_true')
    parser.add_argument('--use-idrid', action='store_true')
    parser.add_argument('--use-messidor', action='store_true')
    parser.add_argument('--use-aptos2015', action='store_true')
    parser.add_argument('--use-aptos2019', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--coarse', action='store_true')
    parser.add_argument('-acc',
                        '--accumulation-steps',
                        type=int,
                        default=1,
                        help='Number of batches to accumulate gradients over')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='data',
                        help='Data directory')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='resnet18_gap',
                        help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=100,
                        help='Epoch to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-f',
                        '--fold',
                        action='append',
                        type=int,
                        default=None)
    parser.add_argument('-ft', '--fine-tune', default=0, type=int)
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--criterion-reg',
                        type=str,
                        default=None,
                        nargs='+',
                        help='Criterion')
    parser.add_argument('--criterion-ord',
                        type=str,
                        default=None,
                        nargs='+',
                        help='Criterion')
    parser.add_argument('--criterion-cls',
                        type=str,
                        default=['ce'],
                        nargs='+',
                        help='Criterion')
    parser.add_argument('-l1',
                        type=float,
                        default=0,
                        help='L1 regularization loss')
    parser.add_argument('-l2',
                        type=float,
                        default=0,
                        help='L2 regularization loss')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument('-p',
                        '--preprocessing',
                        default=None,
                        help='Preprocessing method')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w',
                        '--workers',
                        default=multiprocessing.cpu_count(),
                        type=int,
                        help='Num workers')
    parser.add_argument('-a',
                        '--augmentations',
                        default='medium',
                        type=str,
                        help='')
    parser.add_argument('-tta',
                        '--tta',
                        default=None,
                        type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-t', '--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-s',
                        '--scheduler',
                        default='multistep',
                        type=str,
                        help='')
    parser.add_argument('--size',
                        default=512,
                        type=int,
                        help='Image size for training & inference')
    parser.add_argument('-wd',
                        '--weight-decay',
                        default=0,
                        type=float,
                        help='L2 weight decay')
    parser.add_argument('-wds',
                        '--weight-decay-step',
                        default=None,
                        type=float,
                        help='L2 weight decay step to add after each epoch')
    parser.add_argument('-d',
                        '--dropout',
                        default=0.0,
                        type=float,
                        help='Dropout before head layer')
    parser.add_argument(
        '--warmup',
        default=0,
        type=int,
        help=
        'Number of warmup epochs with 0.1 of the initial LR and a frozen encoder'
    )
    parser.add_argument('-x',
                        '--experiment',
                        default=None,
                        type=str,
                        help='Experiment name (overrides the default checkpoint prefix)')

    args = parser.parse_args()

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    l1 = args.l1
    l2 = args.l2
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    fine_tune = args.fine_tune
    criterion_reg_name = args.criterion_reg
    criterion_cls_name = args.criterion_cls
    criterion_ord_name = args.criterion_ord
    folds = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    use_swa = args.swa
    show_batches = args.show
    scheduler_name = args.scheduler
    verbose = args.verbose
    weight_decay = args.weight_decay
    use_idrid = args.use_idrid
    use_messidor = args.use_messidor
    use_aptos2015 = args.use_aptos2015
    use_aptos2019 = args.use_aptos2019
    warmup = args.warmup
    dropout = args.dropout
    use_unsupervised = False
    experiment = args.experiment
    preprocessing = args.preprocessing
    weight_decay_step = args.weight_decay_step
    coarse_grading = args.coarse
    class_names = get_class_names(coarse_grading)

    assert use_aptos2015 or use_aptos2019 or use_idrid or use_messidor

    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()

    if folds is None or len(folds) == 0:
        folds = [None]

    for fold in folds:
        torch.cuda.empty_cache()
        checkpoint_prefix = f'{model_name}_{args.size}_{augmentations}'

        if preprocessing is not None:
            checkpoint_prefix += f'_{preprocessing}'
        if use_aptos2019:
            checkpoint_prefix += '_aptos2019'
        if use_aptos2015:
            checkpoint_prefix += '_aptos2015'
        if use_messidor:
            checkpoint_prefix += '_messidor'
        if use_idrid:
            checkpoint_prefix += '_idrid'
        if coarse_grading:
            checkpoint_prefix += '_coarse'

        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'

        checkpoint_prefix += f'_{random_name}'

        if experiment is not None:
            checkpoint_prefix = experiment

        directory_prefix = f'{current_time}/{checkpoint_prefix}'
        log_dir = os.path.join('runs', directory_prefix)
        os.makedirs(log_dir, exist_ok=False)

        config_fname = os.path.join(log_dir, f'{checkpoint_prefix}.json')
        with open(config_fname, 'w') as f:
            train_session_args = vars(args)
            f.write(json.dumps(train_session_args, indent=2))

        set_manual_seed(args.seed)
        num_classes = len(class_names)
        model = get_model(model_name, num_classes=num_classes,
                          dropout=dropout).cuda()

        if args.transfer:
            transfer_checkpoint = fs.auto_file(args.transfer)
            print("Transfering weights from model checkpoint",
                  transfer_checkpoint)
            checkpoint = load_checkpoint(transfer_checkpoint)
            pretrained_dict = checkpoint['model_state_dict']

            for name, value in pretrained_dict.items():
                try:
                    model.load_state_dict(collections.OrderedDict([(name,
                                                                    value)]),
                                          strict=False)
                except Exception as e:
                    print(e)

            report_checkpoint(checkpoint)

        if args.checkpoint:
            checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_aptos2019=use_aptos2019,
            use_aptos2015=use_aptos2015,
            use_idrid=use_idrid,
            use_messidor=use_messidor,
            use_unsupervised=False,
            coarse_grading=coarse_grading,
            image_size=image_size,
            augmentation=augmentations,
            preprocessing=preprocessing,
            target_dtype=int,
            fold=fold,
            folds=4)

        train_loader, valid_loader = get_dataloaders(
            train_ds,
            valid_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            train_sizes=train_sizes,
            balance=balance,
            balance_datasets=balance_datasets,
            balance_unlabeled=False)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        print('Datasets         :', data_dir)
        print('  Train size     :', len(train_loader),
              len(train_loader.dataset))
        print('  Valid size     :', len(valid_loader),
              len(valid_loader.dataset))
        print('  Aptos 2019     :', use_aptos2019)
        print('  Aptos 2015     :', use_aptos2015)
        print('  IDRID          :', use_idrid)
        print('  Messidor       :', use_messidor)
        print('Train session    :', directory_prefix)
        print('  FP16 mode      :', fp16)
        print('  Fast mode      :', fast)
        print('  Mixup          :', mixup)
        print('  Balance cls.   :', balance)
        print('  Balance ds.    :', balance_datasets)
        print('  Warmup epoch   :', warmup)
        print('  Train epochs   :', num_epochs)
        print('  Fine-tune ephs :', fine_tune)
        print('  Workers        :', num_workers)
        print('  Fold           :', fold)
        print('  Log dir        :', log_dir)
        print('  Augmentations  :', augmentations)
        print('Model            :', model_name)
        print('  Parameters     :', count_parameters(model))
        print('  Image size     :', image_size)
        print('  Dropout        :', dropout)
        print('  Classes        :', class_names, num_classes)
        print('Optimizer        :', optimizer_name)
        print('  Learning rate  :', learning_rate)
        print('  Batch size     :', batch_size)
        print('  Criterion (cls):', criterion_cls_name)
        print('  Criterion (reg):', criterion_reg_name)
        print('  Criterion (ord):', criterion_ord_name)
        print('  Scheduler      :', scheduler_name)
        print('  Weight decay   :', weight_decay, weight_decay_step)
        print('  L1 reg.        :', l1)
        print('  L2 reg.        :', l2)
        print('  Early stopping :', early_stopping)

        # model training
        callbacks = []
        criterions = {}

        main_metric = 'cls/kappa'
        if criterion_reg_name is not None:
            cb, crits = get_reg_callbacks(criterion_reg_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_ord_name is not None:
            cb, crits = get_ord_callbacks(criterion_ord_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_cls_name is not None:
            cb, crits = get_cls_callbacks(criterion_cls_name,
                                          num_classes=num_classes,
                                          num_epochs=num_epochs,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if l1 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l1,
                                         end_wd=l1,
                                         schedule=None,
                                         prefix='l1',
                                         p=1)
            ]

        if l2 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l2,
                                         end_wd=l2,
                                         schedule=None,
                                         prefix='l2',
                                         p=2)
            ]

        callbacks += [CustomOptimizerCallback()]

        runner = SupervisedRunner(input_key='image')

        # Pretrain/warmup
        if warmup:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer('Adam',
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate * 0.1)

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=None,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'warmup'),
                         num_epochs=warmup,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer

        # Main train
        if num_epochs:
            set_trainable(model.encoder, True, False)

            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate,
                                      weight_decay=weight_decay)

            if use_swa:
                from torchcontrib.optim import SWA
                optimizer = SWA(optimizer,
                                swa_start=len(train_loader),
                                swa_freq=512)

            scheduler = get_scheduler(scheduler_name,
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=num_epochs,
                                      batches_in_epoch=len(train_loader))

            # Callbacks specific to the main stage are added to a copy of the callback list
            main_stage_callbacks = callbacks
            if early_stopping:
                es_callback = EarlyStoppingCallback(early_stopping,
                                                    min_delta=1e-4,
                                                    metric=main_metric,
                                                    minimize=False)
                main_stage_callbacks = callbacks + [es_callback]

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=main_stage_callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'main'),
                         num_epochs=num_epochs,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer, scheduler

            best_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)

            # Restoring best model from checkpoint
            checkpoint = load_checkpoint(best_checkpoint)
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        # Stage 3 - Fine tuning
        if fine_tune:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate)
            scheduler = get_scheduler('multistep',
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=fine_tune,
                                      batches_in_epoch=len(train_loader))

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'finetune'),
                         num_epochs=fine_tune,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            best_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)
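In the Catalyst pipeline above, SWA is bolted on by simply wrapping whatever get_optimizer() returns: averaging starts after roughly one epoch of updates and a snapshot is added every 512 steps, while the wrapped optimizer's learning rate (and the external scheduler) is left untouched because no swa_lr is given. Below is a small sketch of that wrapping step in isolation; the factory-produced optimizer and train_loader stand in for the real objects, and the averaged weights only reach the model once swap_swa_sgd() is eventually called downstream.

from torchcontrib.optim import SWA

def maybe_wrap_with_swa(optimizer, train_loader, use_swa):
    """Optionally wrap an already-configured optimizer with torchcontrib's SWA averaging."""
    if not use_swa:
        return optimizer
    return SWA(optimizer,
               swa_start=len(train_loader),  # begin averaging after ~1 epoch of updates
               swa_freq=512)                 # fold a snapshot into the average every 512 steps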
class Trainer():
    def __init__(self, config_path):
        self.image_config, self.model_config, self.run_config = LoadConfig(
            config_path=config_path).train_config()
        self.device = torch.device(
            'cuda:%d' % self.run_config['device_ids'][0]
            if torch.cuda.is_available() else 'cpu')
        self.model = getModel(self.model_config)
        os.makedirs(self.run_config['model_save_path'], exist_ok=True)
        self.run_config['num_workers'] = self.run_config['num_workers'] * len(
            self.run_config['device_ids'])
        self.train_set = Data(root=self.image_config['image_path'],
                              phase='train',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.valid_set = Data(root=self.image_config['image_path'],
                              phase='valid',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.className = self.valid_set.className
        self.train_loader = DataLoader(
            self.train_set,
            batch_size=self.run_config['batch_size'],
            shuffle=True,
            num_workers=self.run_config['num_workers'],
            pin_memory=True,
            drop_last=False)
        self.valid_loader = DataLoader(
            self.valid_set,
            batch_size=self.run_config['batch_size'],
            shuffle=True,
            num_workers=self.run_config['num_workers'],
            pin_memory=True,
            drop_last=False)
        train_params = self.model.parameters()
        self.optimizer = RAdam(train_params,
                               lr=eval(self.run_config['lr']),
                               weight_decay=eval(
                                   self.run_config['weight_decay']))
        if self.run_config['swa']:
            self.optimizer = SWA(self.optimizer,
                                 swa_start=10,
                                 swa_freq=5,
                                 swa_lr=0.005)
        # Configure the learning-rate adjustment strategy
        self.lr_scheduler = utils.adjustLR.AdjustLr(self.optimizer)
        if self.run_config['use_weight_balance']:
            weight = utils.weight_balance.getWeight(
                self.run_config['weights_file'])
        else:
            weight = None
        self.Criterion = SegmentationLosses(weight=weight,
                                            cuda=True,
                                            device=self.device,
                                            batch_average=False)
        self.metric = utils.metrics.MetricMeter(
            self.model_config['num_classes'])

    @logger.catch  # log any exception to the log file
    def __call__(self):
        # Set up logging
        self.global_name = self.model_config['model_name']
        logger.add(os.path.join(
            self.image_config['image_path'], 'log',
            'log_' + self.global_name + '/train_{time}.log'),
                   format="{time} {level} {message}",
                   level="INFO",
                   encoding='utf-8')
        self.writer = SummaryWriter(logdir=os.path.join(
            self.image_config['image_path'], 'run', 'runs_' +
            self.global_name))
        logger.info("image_config: {} \n model_config: {} \n run_config: {}",
                    self.image_config, self.model_config, self.run_config)
        # Use data parallelism if more than one GPU is available
        if len(self.run_config['device_ids']) > 1:
            self.model = nn.DataParallel(
                self.model, device_ids=self.run_config['device_ids'])
        self.model.to(device=self.device)
        cnt = 0
        # Load a pretrained checkpoint if one is specified
        if self.run_config['pretrain'] != '':
            logger.info("loading pretrain %s" % self.run_config['pretrain'])
            try:
                self.load_checkpoint(use_optimizer=True,
                                     use_epoch=True,
                                     use_miou=True)
            except Exception:
                print('loading model with changed keys!')
                self.load_checkpoint_with_changed(use_optimizer=False,
                                                  use_epoch=False,
                                                  use_miou=False)
        logger.info("start training")

        for epoch in range(self.run_config['start_epoch'],
                           self.run_config['epoch']):
            lr = self.optimizer.param_groups[0]['lr']
            print('epoch=%d, lr=%.8f' % (epoch, lr))
            self.train_epoch(epoch, lr)
            valid_miou = self.valid_epoch(epoch)
            # Choose which learning-rate schedule to apply
            self.lr_scheduler.LambdaLR_(milestone=5,
                                        gamma=0.92).step(epoch=epoch)
            self.save_checkpoint(epoch, valid_miou, 'last_' + self.global_name)
            if valid_miou > self.run_config['best_miou']:
                cnt = 0
                self.save_checkpoint(epoch, valid_miou,
                                     'best_' + self.global_name)
                logger.info("#############   %d saved   ##############" %
                            epoch)
                self.run_config['best_miou'] = valid_miou
            else:
                cnt += 1
                if cnt == self.run_config['early_stop']:
                    logger.info("early stop")
                    break
        self.writer.close()

    def train_epoch(self, epoch, lr):
        self.metric.reset()
        train_loss = 0.0
        train_miou = 0.0
        tbar = tqdm(self.train_loader)
        self.model.train()
        for i, (image, mask, edge) in enumerate(tbar):
            tbar.set_description('train_miou:%.6f' % train_miou)
            tbar.set_postfix({"train_loss": train_loss})
            image = image.to(self.device)
            mask = mask.to(self.device)
            edge = edge.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(image)
            if isinstance(out, tuple):
                aux_out, final_out = out[0], out[1]
            else:
                aux_out, final_out = None, out
            if self.model_config['model_name'] == 'ocrnet':
                aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out, mask)
                cls_loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                mask)
                loss = 0.4 * aux_loss + cls_loss
                loss = loss.mean()
            elif self.model_config['model_name'] == 'hrnet_duc':
                loss_body = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
                loss_edge = self.Criterion.build_loss(mode='dice')(
                    aux_out.squeeze(), edge)
                loss = loss_body + loss_edge
                loss = loss.mean()
            else:
                loss = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
            loss.backward()
            self.optimizer.step()
            if self.run_config['swa']:
                self.optimizer.swap_swa_sgd()
            with torch.no_grad():
                train_loss = ((train_loss * i) + loss.item()) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                train_miou, train_ious = self.metric.miou()
                train_fwiou = self.metric.fw_iou()
                train_accu = self.metric.pixel_accuracy()
                train_fwaccu = self.metric.pixel_accuracy_class()
        logger.info(
            "Epoch:%2d\t lr:%.8f\t Train loss:%.4f\t Train FWiou:%.4f\t Train Miou:%.4f\t Train accu:%.4f\t "
            "Train fwaccu:%.4f" % (epoch, lr, train_loss, train_fwiou,
                                   train_miou, train_accu, train_fwaccu))
        cls = ""
        ious = list()
        ious_dict = OrderedDict()
        for i, c in enumerate(self.className):
            ious_dict[c] = train_ious[i]
            ious.append(ious_dict[c])
            cls += "%s:" % c + "%.4f "
        ious = tuple(ious)
        logger.info(cls % ious)
        # tensorboard
        self.writer.add_scalar("lr", lr, epoch)
        self.writer.add_scalar("loss/train_loss", train_loss, epoch)
        self.writer.add_scalar("miou/train_miou", train_miou, epoch)
        self.writer.add_scalar("fwiou/train_fwiou", train_fwiou, epoch)
        self.writer.add_scalar("accuracy/train_accu", train_accu, epoch)
        self.writer.add_scalar("fwaccuracy/train_fwaccu", train_fwaccu, epoch)
        self.writer.add_scalars("ious/train_ious", ious_dict, epoch)

    def valid_epoch(self, epoch):
        self.metric.reset()
        valid_loss = 0.0
        valid_miou = 0.0
        tbar = tqdm(self.valid_loader)
        self.model.eval()
        with torch.no_grad():
            for i, (image, mask, edge) in enumerate(tbar):
                tbar.set_description('valid_miou:%.6f' % valid_miou)
                tbar.set_postfix({"valid_loss": valid_loss})
                image = image.to(self.device)
                mask = mask.to(self.device)
                edge = edge.to(self.device)
                out = self.model(image)
                if isinstance(out, tuple):
                    aux_out, final_out = out[0], out[1]
                else:
                    aux_out, final_out = None, out
                if self.model_config['model_name'] == 'ocrnet':
                    aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out,
                                                                     mask)
                    cls_loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                    mask)
                    loss = 0.4 * aux_loss + cls_loss
                    loss = loss.mean()
                elif self.model_config['model_name'] == 'hrnet_duc':
                    loss_body = self.Criterion.build_loss(
                        mode=self.run_config['loss_type'])(final_out, mask)
                    loss_edge = self.Criterion.build_loss(mode='dice')(
                        aux_out.squeeze(), edge)
                    loss = loss_body + loss_edge
                    # loss = loss.mean()
                else:
                    loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                mask)
                valid_loss = ((valid_loss * i) + float(loss)) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                valid_miou, valid_ious = self.metric.miou()
                valid_fwiou = self.metric.fw_iou()
                valid_accu = self.metric.pixel_accuracy()
                valid_fwaccu = self.metric.pixel_accuracy_class()
            logger.info(
                "epoch:%d\t valid loss:%.4f\t valid fwiou:%.4f\t valid miou:%.4f valid accu:%.4f\t "
                "valid fwaccu:%.4f\t" % (epoch, valid_loss, valid_fwiou,
                                         valid_miou, valid_accu, valid_fwaccu))
            ious = list()
            cls = ""
            ious_dict = OrderedDict()
            for i, c in enumerate(self.className):
                ious_dict[c] = valid_ious[i]
                ious.append(ious_dict[c])
                cls += "%s:" % c + "%.4f "
            ious = tuple(ious)
            logger.info(cls % ious)
            self.writer.add_scalar("loss/valid_loss", valid_loss, epoch)
            self.writer.add_scalar("miou/valid_miou", valid_miou, epoch)
            self.writer.add_scalar("fwiou/valid_fwiou", valid_fwiou, epoch)
            self.writer.add_scalar("accuracy/valid_accu", valid_accu, epoch)
            self.writer.add_scalar("fwaccuracy/valid_fwaccu", valid_fwaccu,
                                   epoch)
            self.writer.add_scalars("ious/valid_ious", ious_dict, epoch)
        return valid_miou

    def save_checkpoint(self, epoch, best_miou, flag):
        meta = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optim': self.optimizer.state_dict(),
            'bmiou': best_miou
        }
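        # prefer the legacy (non-zipfile) format so older PyTorch versions can load
        # the checkpoint; fall back if this torch build does not accept the kwarg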
        try:
            torch.save(meta,
                       os.path.join(self.run_config['model_save_path'],
                                    '%s.pth' % flag),
                       _use_new_zipfile_serialization=False)
        except TypeError:
            torch.save(
                meta,
                os.path.join(self.run_config['model_save_path'],
                             '%s.pth' % flag))

    def load_checkpoint(self, use_optimizer, use_epoch, use_miou):
        state_dict = torch.load(self.run_config['pretrain'],
                                map_location=self.device)
        self.model.load_state_dict(state_dict['model'])
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']

    def load_checkpoint_with_changed(self, use_optimizer, use_epoch, use_miou):
        state_dict = torch.load(self.run_config['pretrain'],
                                map_location=self.device)
        pretrain_dict = state_dict['model']
        model_dict = self.model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and 'edge' not in k
        }
        model_dict.update(pretrain_dict)
        self.model.load_state_dict(model_dict)
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']
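
A hedged sketch of how the trainer methods above are typically driven from an outer epoch loop. The `trainer` name, the `'epochs'` key and the exact `train_epoch(epoch)` signature are assumptions made for illustration; only `valid_epoch`, `save_checkpoint` and the `run_config` keys shown above come from the snippet.

# Sketch only: outer loop over the trainer defined above (trainer, the 'epochs'
# key and train_epoch(epoch) are assumptions for illustration).
best_miou = trainer.run_config.get('best_miou', 0.0)
for epoch in range(trainer.run_config['start_epoch'],
                   trainer.run_config['epochs']):
    trainer.train_epoch(epoch)
    miou = trainer.valid_epoch(epoch)
    trainer.save_checkpoint(epoch, best_miou, flag='last')
    if miou > best_miou:
        best_miou = miou
        trainer.save_checkpoint(epoch, best_miou, flag='best')
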
Example #28
def train(model,
          device,
          trainloader,
          testloader,
          optimizer,
          criterion,
          metric,
          epochs,
          learning_rate,
          swa=True,
          enable_scheduler=True,
          model_arch=''):
    '''
    Function to perform model training.
    '''
    model.to(device)
    steps = 0
    running_loss = 0
    running_metric = 0
    print_every = 100

    train_losses = []
    test_losses = []
    train_metrics = []
    test_metrics = []

    if swa:
        # initialize stochastic weight averaging
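        # manual mode: without swa_start/swa_freq the averages are maintained
        # explicitly via opt.update_swa() and applied with opt.swap_swa_sgd()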
        opt = SWA(optimizer)
    else:
        opt = optimizer

    # learning rate cosine annealing
    if enable_scheduler:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   len(trainloader),
                                                   eta_min=0.0000001)

    for epoch in range(epochs):

        if enable_scheduler:
            scheduler.step()

        for inputs, labels in trainloader:

            steps += 1
            # Move input and label tensors to the default device
            inputs, labels = inputs.to(device), labels.to(device)

            opt.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels.float())
            loss.backward()
            opt.step()

            running_loss += loss.item()
            running_metric += metric(outputs, labels.float())

            if steps % print_every == 0:
                test_loss = 0
                test_metric = 0
                model.eval()
                with torch.no_grad():
                    for inputs, labels in testloader:
                        inputs, labels = inputs.to(device), labels.to(device)
                        outputs = model(inputs)

                        test_loss += criterion(outputs, labels.float())

                        test_metric += metric(outputs, labels.float())

                print(f"Epoch {epoch+1}/{epochs}.. "
                      f"Train loss: {running_loss/print_every:.3f}.. "
                      f"Test loss: {test_loss/len(testloader):.3f}.. "
                      f"Train metric: {running_metric/print_every:.3f}.. "
                      f"Test metric: {test_metric/len(testloader):.3f}.. ")

                train_losses.append(running_loss / print_every)
                test_losses.append(test_loss / len(testloader))
                train_metrics.append(running_metric / print_every)
                test_metrics.append(test_metric / len(testloader))

                running_loss = 0
                running_metric = 0

                model.train()
                if swa:
                    opt.update_swa()

        save_model(model,
                   model_arch,
                   learning_rate,
                   epochs,
                   train_losses,
                   test_losses,
                   train_metrics,
                   test_metrics,
                   filepath='models_checkpoints')

    if swa:
        opt.swap_swa_sgd()

    return model, train_losses, test_losses, train_metrics, test_metrics
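
One caveat with the manual-mode SWA wrapper above: `swap_swa_sgd()` copies the averaged weights into the model, but any BatchNorm running statistics still correspond to the last non-averaged weights. torchcontrib provides `bn_update` for this; a minimal sketch reusing the `swa`, `opt`, `trainloader` and `model` names from the function above:

# Sketch only: refresh BatchNorm running statistics for the averaged weights
# (reuses swa/opt/trainloader/model from the train() function above).
if swa:
    opt.bn_update(trainloader, model)   # forward passes in train mode to recompute BN stats
model.eval()
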
Example #29
    def train(self, train_loader, eval_loader, epoch):
        # set up the optimizer and learning-rate scheduler
        if self.args.swa:
            logger.info('SWA training')
            base_opt = torch.optim.SGD(self.model.parameters(),
                                       lr=self.args.learning_rate)
            optimizer = SWA(base_opt,
                            swa_start=self.args.swa_start,
                            swa_freq=self.args.swa_freq,
                            swa_lr=self.args.swa_lr)
            scheduler = CyclicLR(
                optimizer,
                base_lr=5e-5,
                max_lr=7e-5,
                step_size_up=(self.args.epochs * len(train_loader) /
                              self.args.batch_accumulation),
                cycle_momentum=False)
        else:
            logger.info('Adam training')
            optimizer = torch.optim.Adam(self.model.parameters(),
                                         lr=self.args.learning_rate)
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.args.warmup,
                num_training_steps=(self.args.epochs * len(train_loader) /
                                    self.args.batch_accumulation))

        bar = tqdm(range(self.args.train_steps), total=self.args.train_steps)
        train_batches = cycle(train_loader)
        loss_sum = 0.0
        start = time.time()
        self.model.train()
        for step in bar:
            batch = next(train_batches)
            input_ids, input_mask, segment_ids, label_ids = [
                t.to(self.device) for t in batch
            ]

            loss, _ = self.model(input_ids=input_ids,
                                 token_type_ids=segment_ids,
                                 attention_mask=input_mask,
                                 labels=label_ids)
            if self.gpu_num > 1:
                loss = loss.mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            # optimizer.update_swa()
            loss_sum += loss.cpu().item()
            train_loss = loss_sum / (step + 1)

            bar.set_description("loss {}".format(train_loss))
            if (step + 1) % self.args.eval_steps == 0:
                logger.info("***** Training result *****")
                logger.info('  time %.2fs ', time.time() - start)
                logger.info("  %s = %s", 'global_step', str(step + 1))
                logger.info("  %s = %s", 'train loss', str(train_loss))
                # run evaluation every eval_steps steps
                self.result = {
                    'epoch': epoch,
                    'global_step': step + 1,
                    'loss': train_loss
                }
                if self.args.swa:
                    optimizer.swap_swa_sgd()
                self.evaluate(eval_loader, epoch)
                if self.args.swa:
                    optimizer.swap_swa_sgd()
        if self.args.swa:
            optimizer.swap_swa_sgd()
        logger.info('The training of epoch %d has finished.', epoch + 1)

    def __init__(self):

        if args.train is not None:
            self.train_tuple = get_tuple(args.train,
                                         bs=args.batch_size,
                                         shuffle=True,
                                         drop_last=False)

        if args.valid is not None:
            valid_bsize = 2048 if args.multiGPU else 50
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        # Select Model, X is default
        if args.model == "X":
            self.model = ModelX(args)
        elif args.model == "V":
            self.model = ModelV(args)
        elif args.model == "U":
            self.model = ModelU(args)
        elif args.model == "D":
            self.model = ModelD(args)
        elif args.model == 'O':
            self.model = ModelO(args)
        else:
            raise NotImplementedError("%s is not implemented." % args.model)

        # Load pre-trained weights from paths
        if args.loadpre is not None:
            self.model.load(args.loadpre)

        # GPU options
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        self.model = self.model.cuda()

        # Losses and optimizer
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.nllloss = nn.NLLLoss()

        if args.train is not None:
            batch_per_epoch = len(self.train_tuple.loader)
            self.t_total = int(batch_per_epoch * args.epochs // args.acc)
            print("Total Iters: %d" % self.t_total)

        def is_backbone(n):
            if "encoder" in n:
                return True
            elif "embeddings" in n:
                return True
            elif "pooler" in n:
                return True
            print("F: ", n)
            return False

        no_decay = ['bias', 'LayerNorm.weight']

        params = list(self.model.named_parameters())
        if args.reg:
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in params if is_backbone(n)],
                    "lr": args.lr
                },
                {
                    "params": [p for n, p in params if not is_backbone(n)],
                    "lr": args.lr * 500
                },
            ]

            for n, p in self.model.named_parameters():
                print(n)

            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)
        else:
            optimizer_grouped_parameters = [{
                'params':
                [p for n, p in params if not any(nd in n for nd in no_decay)],
                'weight_decay':
                args.wd
            }, {
                'params':
                [p for n, p in params if any(nd in n for nd in no_decay)],
                'weight_decay':
                0.0
            }]

            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)

        if args.train is not None:
            self.scheduler = get_linear_schedule_with_warmup(
                self.optim, self.t_total * 0.1, self.t_total)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

        # SWA Method:
        if args.contrib:
            self.optim = SWA(self.optim,
                             swa_start=self.t_total * 0.75,
                             swa_freq=5,
                             swa_lr=args.lr)

        if args.swa:
            self.swa_model = AveragedModel(self.model)
            self.swa_start = self.t_total * 0.75
            self.swa_scheduler = SWALR(self.optim, swa_lr=args.lr)
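
The `AveragedModel`/`SWALR` objects created at the end of this constructor follow the `torch.optim.swa_utils` pattern, but the snippet does not show the loop that drives them. The following is a minimal, self-contained sketch of that pattern with a toy model and dataset; none of these names belong to the trainer above.

import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader, TensorDataset

# Self-contained sketch of the swa_utils pattern (toy data and model for illustration).
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=0.01)
swa_start = 75                                  # step index at which averaging begins

loader = DataLoader(TensorDataset(torch.randn(64, 10),
                                  torch.randint(0, 2, (64,))),
                    batch_size=8)
criterion = nn.CrossEntropyLoss()

step = 0
for epoch in range(10):
    for x, y in loader:
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step >= swa_start:
            swa_model.update_parameters(model)  # accumulate the running average
            swa_scheduler.step()                # anneal to / hold the SWA learning rate
        step += 1

update_bn(loader, swa_model)                    # refresh BN statistics (no-op here: no BN layers)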