Code example #1
    def __init__(self, cfg_dir: str):
        # load config file and initialize the logger and the device
        self.cfg = get_conf(cfg_dir)
        self.logger = self.init_logger(self.cfg.logger)
        self.device = self.init_device()
        # create dataset interfaces and dataloaders for training and validation data
        self.data, self.val_data = self.init_dataloader()
        # create the model, initialize its weights, and move it to the device
        self.model = self.init_model()
        # initialize the optimizer
        self.optimizer, self.lr_scheduler = self.init_optimizer()
        # define loss function
        self.criterion = torch.nn.CrossEntropyLoss()
        # if resuming, load the checkpoint
        self.if_resume()

        # initialize the early_stopping object
        self.early_stopping = EarlyStopping(
            patience=self.cfg.train_params.patience,
            verbose=True,
            delta=self.cfg.train_params.early_stopping_delta,
        )

        # stochastic weight averaging
        if self.cfg.train_params.epochs > self.cfg.train_params.swa_start:
            self.swa_model = AveragedModel(self.model)
            self.swa_scheduler = SWALR(self.optimizer, **self.cfg.SWA)
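The constructor above only builds the SWA objects; a minimal sketch of an epoch loop that would consume them, assuming a hypothetical train() method, the attribute names from the constructor, epochs > swa_start so the SWA objects exist, and the standard torch.optim.swa_utils workflow:

    def train(self):
        # hypothetical epoch loop; attribute names follow the constructor above
        for epoch in range(self.cfg.train_params.epochs):
            for x, y in self.data:
                x, y = x.to(self.device), y.to(self.device)
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(x), y)
                loss.backward()
                self.optimizer.step()

            if epoch >= self.cfg.train_params.swa_start:
                # accumulate the running average and anneal the SWA learning rate
                self.swa_model.update_parameters(self.model)
                self.swa_scheduler.step()
            else:
                self.lr_scheduler.step()

        # recompute BatchNorm statistics for the averaged weights before evaluation
        torch.optim.swa_utils.update_bn(self.data, self.swa_model, device=self.device)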
Code example #2
    def _configure_optimizers(self) -> None:
        """Loads the optimizers."""
        if self._optimizer is not None:
            self._optimizer = self._optimizer(self._network.parameters(),
                                              **self.optimizer_args)
        else:
            self._optimizer = None

        if self._optimizer and self._lr_scheduler is not None:
            if "steps_per_epoch" in self.lr_scheduler_args:
                self.lr_scheduler_args["steps_per_epoch"] = len(
                    self.train_dataloader())

            # Assume lr scheduler should update at each epoch if not specified.
            if "interval" not in self.lr_scheduler_args:
                interval = "epoch"
            else:
                interval = self.lr_scheduler_args.pop("interval")
            self._lr_scheduler = {
                "lr_scheduler":
                self._lr_scheduler(self._optimizer, **self.lr_scheduler_args),
                "interval":
                interval,
            }

        if self.swa_args is not None:
            self._swa_scheduler = {
                "swa_scheduler": SWALR(self._optimizer,
                                       swa_lr=self.swa_args["lr"]),
                "swa_start": self.swa_args["start"],
            }
            self._swa_network = AveragedModel(self._network).to(self.device)
Code example #3
File: model.py Project: Etzelkut/eye_gaze
 def training_epoch_end(self, outputs):
     self.log('epoch_now',
              self.current_epoch,
              on_step=False,
              on_epoch=True,
              logger=True)
     optimizer = self.optimizers(use_pl_optimizer=True)
     self.log('lr_now',
              self.get_lr_inside(optimizer),
              on_step=False,
              on_epoch=True,
              logger=True)
     # https://github.com/PyTorchLightning/pytorch-lightning/issues/3095
     if self.learning_params["swa"] and (
             self.current_epoch >= self.learning_params["swa_start_epoch"]):
         if self.swa_model is None:
             optimizer = self.optimizers(use_pl_optimizer=True)
             print("creating_swa")
             self.swa_model = AveragedModel(self.network)
             self.new_scheduler = SWALR(
                 optimizer,
                 anneal_strategy="linear",
                 anneal_epochs=5,
                 swa_lr=self.learning_params["swa_lr"])
         # https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
         self.swa_model.update_parameters(self.network)
         self.new_scheduler.step()
Code example #4
File: vol_mdn.py Project: eadains/SystemsTrading
    def fit_model(self):
        """
        Fits model. Uses AdamW optimizer, model averaging, and a cosine annealing learning rate schedule.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, 100, 2
        )

        self.swa_model = AveragedModel(self.model)
        swa_start = 750
        swa_scheduler = SWALR(
            optimizer, swa_lr=0.001, anneal_epochs=10, anneal_strategy="cos"
        )

        self.model.train()
        self.swa_model.train()
        for epoch in range(1000):
            optimizer.zero_grad()
            output = self.model(self.x)

            loss = -output.log_prob(self.y.view(-1, 1)).sum()

            loss.backward()
            optimizer.step()

            if epoch > swa_start:
                self.swa_model.update_parameters(self.model)
                swa_scheduler.step()
            else:
                scheduler.step()

            if epoch % 10 == 0:
                print(f"Epoch {epoch} complete. Loss: {loss}")
Code example #5
File: train_prev.py Project: navidkha/road_mtl
    def __init__(self, cfg_dir: str, data_loader: DataLoader, model,
                 labels_definition):
        self.cfg = get_conf(cfg_dir)
        self._labels_definition = labels_definition
        #TODO
        self.logger = self.init_logger(self.cfg.logger)
        #self.dataset = CustomDataset(**self.cfg.dataset)
        self.data = data_loader
        #self.val_dataset = CustomDatasetVal(**self.cfg.val_dataset)
        #self.val_data = DataLoader(self.val_dataset, **self.cfg.dataloader)
        # self.logger.log_parameters({"tr_len": len(self.dataset),
        #                             "val_len": len(self.val_dataset)})
        self.model = model
        #self.model._resnet.conv1.apply(init_weights_normal)
        self.device = self.cfg.train_params.device
        self.model = self.model.to(device=self.device)
        if self.cfg.train_params.optimizer.lower() == "adam":
            self.optimizer = optim.Adam(self.model.parameters(),
                                        **self.cfg.adam)
        elif self.cfg.train_params.optimizer.lower() == "rmsprop":
            self.optimizer = optim.RMSprop(self.model.parameters(),
                                           **self.cfg.rmsprop)
        else:
            raise ValueError(
                f"Unknown optimizer {self.cfg.train_params.optimizer}")

        self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100)
        self.criterion = nn.BCELoss()

        if self.cfg.logger.resume:
            # load checkpoint
            print("Loading checkpoint")
            save_dir = self.cfg.directory.load
            checkpoint = load_checkpoint(save_dir, self.device)
            self.model.load_state_dict(checkpoint["model"])
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            self.epoch = checkpoint["epoch"]
            self.e_loss = checkpoint["e_loss"]
            self.best = checkpoint["best"]
            print(
                f"{datetime.now():%Y-%m-%d %H:%M:%S} "
                f"Loading checkpoint was successful, start from epoch {self.epoch}"
                f" and loss {self.best}")
        else:
            self.epoch = 1
            self.best = np.inf
            self.e_loss = []

        # initialize the early_stopping object
        self.early_stopping = EarlyStopping(
            patience=self.cfg.train_params.patience,
            verbose=True,
            delta=self.cfg.train_params.early_stopping_delta,
        )

        # stochastic weight averaging
        self.swa_model = AveragedModel(self.model)
        self.swa_scheduler = SWALR(self.optimizer, **self.cfg.SWA)
Code example #6
def train(num_epochs, model, data_loader, val_loader, val_every, device, file_name):
    learning_rate = 0.0001
    from torch.optim.swa_utils import AveragedModel, SWALR
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from segmentation_models_pytorch.losses import SoftCrossEntropyLoss, JaccardLoss
    from adamp import AdamP

    criterion = [SoftCrossEntropyLoss(smooth_factor=0.1), JaccardLoss('multiclass', classes=12)]
    optimizer = AdamP(params=model.parameters(), lr=learning_rate, weight_decay=1e-6)
    swa_scheduler = SWALR(optimizer, swa_lr=learning_rate)
    swa_model = AveragedModel(model)
    look = Lookahead(optimizer, la_alpha=0.5)

    print('Start training..')
    best_miou = 0
    for epoch in range(num_epochs):
        hist = np.zeros((12, 12))
        model.train()
        for step, (images, masks, _) in enumerate(data_loader):
            loss = 0
            images = torch.stack(images)  # (batch, channel, height, width)
            masks = torch.stack(masks).long()  # (batch, channel, height, width)

            # move tensors to the device for GPU computation
            images, masks = images.to(device), masks.to(device)

            # inference
            outputs = model(images)
            for i in criterion:
                loss += i(outputs, masks)
            # loss computation (cross entropy loss)

            look.zero_grad()
            loss.backward()
            look.step()

            outputs = torch.argmax(outputs.squeeze(), dim=1).detach().cpu().numpy()
            hist = add_hist(hist, masks.detach().cpu().numpy(), outputs, n_class=12)
            acc, acc_cls, mIoU, fwavacc = label_accuracy_score(hist)
            # print loss and mIoU every fixed number of steps
            if (step + 1) % 25 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, mIoU: {:.4f}'.format(
                    epoch + 1, num_epochs, step + 1, len(data_loader), loss.item(), mIoU))

        # run validation at the given interval and save the best model
        if (epoch + 1) % val_every == 0:
            avrg_loss, val_miou = validation(epoch + 1, model, val_loader, criterion, device)
            if val_miou > best_miou:
                print('Best performance at epoch: {}'.format(epoch + 1))
                print('Save model in', saved_dir)
                best_miou = val_miou
                save_model(model, file_name = file_name)

        if epoch > 3:
            swa_model.update_parameters(model)
            swa_scheduler.step()
Code example #7
def weight_averaging(model_class, checkpoint_paths, data_loader, device):
    from torch.optim.swa_utils import AveragedModel, update_bn

    model = model_class.load_from_checkpoint(checkpoint_paths[0])
    swa_model = AveragedModel(model)

    for path in checkpoint_paths:
        model = model_class.load_from_checkpoint(path)
        swa_model.update_parameters(model)

    swa_model = swa_model.to(device)
    update_bn(data_loader, swa_model, device)
    return swa_model
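A brief usage sketch for weight_averaging; the checkpoint directory, the LightningModule subclass MyLitModel, and train_loader below are assumptions, not part of the original example:

import glob

import torch

# hypothetical checkpoint directory and model class
checkpoint_paths = sorted(glob.glob("checkpoints/*.ckpt"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MyLitModel and train_loader are hypothetical names
swa_model = weight_averaging(MyLitModel, checkpoint_paths, train_loader, device)
torch.save(swa_model.module.state_dict(), "swa_weights.pt")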
Code example #8
def average_model_weights(checkpoint_path, average_fn, checkpoint_N):
    checkpoint_files = [
        os.path.join(checkpoint_path, file_name)
        for file_name in os.listdir(checkpoint_path)
        if file_name.endswith(".pt")
    ]

    def ckpt_key(ckpt):
        return int(ckpt.split('_')[-1].split('.')[0])

    try:
        checkpoint_files = sorted(checkpoint_files, key=ckpt_key)
    except:
        logging.warning(
            "Checkpoint names do not match the expected pattern, which may "
            "cause an inconsistent averaging order."
        )

    # Select the last N checkpoint
    if checkpoint_N > 0 and checkpoint_N <= len(checkpoint_files):
        checkpoint_files = checkpoint_files[-checkpoint_N:]

    # initialize averaged model with first checkpoint
    model = load_model(checkpoint_files[0])
    averaged_model = AveragedModel(model, avg_fn=average_fn)

    # loop through all checkpoints and update the averaged model
    for checkpoint in checkpoint_files:
        model = load_model(checkpoint)
        averaged_model.update_parameters(model)

    last_checkpoint = torch.load(checkpoint_files[-1])
    opts = last_checkpoint['opts']
    filename = f'{opts.model}_{opts.data}_{last_checkpoint["epoch"]}_averaged.pt'
    save_path = os.path.join(checkpoint_path, filename)

    if opts.precision[-3:] == ".16":
        model.half()
    else:
        model.float()

    torch.save(
        {
            'epoch': last_checkpoint['epoch'] + 1,
            'model_state_dict': averaged_model.module.state_dict(),
            'loss': 0,  # dummy just to work with validate script
            'train_accuracy': 0,  # dummy just to work with validate script
            'opts': opts
        },
        save_path)

    return averaged_model
Code example #9
    def before_run(self, runner):
        """Construct the averaged model which will keep track of the running
        averages of the parameters of the model."""
        model = runner.model
        self.model = AveragedModel(model)

        self.meta = runner.meta

        if self.meta is None:
            self.meta = dict()
            self.meta.setdefault('hook_msgs', dict())

        if 'hook_msgs' not in self.meta:
            self.meta.setdefault('hook_msgs', dict())
Code example #10
def get_swa(optimizer,
            model,
            swa_lr=0.005,
            anneal_epochs=10,
            anneal_strategy="cos"):
    '''
    SWALR Arguments:
        optimizer (torch.optim.Optimizer): wrapped optimizer
        swa_lr (float or list): the learning rate value for all param groups
            together or separately for each group.
        anneal_epochs (int): number of epochs in the annealing phase 
            (default: 10)
        anneal_strategy (str): "cos" or "linear"; specifies the annealing 
            strategy: "cos" for cosine annealing, "linear" for linear annealing
            (default: "cos")
        last_epoch (int): the index of the last epoch (default: -1)
    
    '''
    swa_model = AveragedModel(model)
    # swa_scheduler = SWALR(optimizer, swa_lr=swa_lr)
    # swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=swa_lr)
    swa_scheduler = SWALR(optimizer,
                          swa_lr=swa_lr,
                          anneal_epochs=anneal_epochs,
                          anneal_strategy=anneal_strategy)

    return swa_scheduler, swa_model
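For context, a minimal sketch of how the returned (swa_scheduler, swa_model) pair might be driven from a training loop; the model, optimizer, base scheduler, loader, helper function, and epoch counts here are assumptions:

import torch

swa_scheduler, swa_model = get_swa(optimizer, model, swa_lr=0.005)
swa_start = 75  # hypothetical epoch at which averaging begins

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, optimizer)  # hypothetical helper
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        base_scheduler.step()  # hypothetical base LR scheduler

# refresh BatchNorm statistics before evaluating the averaged model
torch.optim.swa_utils.update_bn(train_loader, swa_model)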
Code example #11
File: swalr.py Project: khirotaka/enchanter
class SWALRRunner(ClassificationRunner):
    def __init__(self, *args, **kwargs):
        super(SWALRRunner, self).__init__(*args, **kwargs)
        self.swa_model = AveragedModel(self.model)
        self.swa_scheduler = SWALR(self.optimizer, swa_lr=0.05)
        self.swa_start = 5

    def update_scheduler(self, epoch: int) -> None:
        if epoch > self.swa_start:
            self.swa_model.update_parameters(self.model)
            self.swa_scheduler.step()

        else:
            super(SWALRRunner, self).update_scheduler(epoch)

    def train_end(self, outputs):
        update_bn(self.loaders["train"], self.swa_model)
        return super(SWALRRunner, self).train_end(outputs)
Code example #12
    def __init__(self, blocks, channels, features, pre_act=False,
                 radix=1, groups=1, bottleneck_width=64,
                 activation=nn.SiLU, squeeze_excitation=False,
                 bottleneck=False, bottleneck_expansion=4,
                 beta=0, val_lambda=0.333, lr=1e-2,
                 use_swa=False, swa_lr=1e-2, swa_freq=250):
        super(Network, self).__init__()
        self.save_hyperparameters()

        self.net = PolicyValueNetwork(
            blocks=blocks, channels=channels, features=features,
            pre_act=pre_act, activation=activation,
            squeeze_excitation=squeeze_excitation,
            bottleneck=bottleneck, bottleneck_expansion=bottleneck_expansion,
            radix=radix, groups=groups, bottleneck_width=bottleneck_width
        )
        if use_swa:
            self.swa_model = AveragedModel(self.net)

        self.ce = nn.CrossEntropyLoss(reduction='none')
        self.bce = nn.BCEWithLogitsLoss()
Code example #13
    def __init__(self, config: DNNConfig):
        self.config = config
        self.epochs = config.epoch_num
        self.device = config.device

        self.model = tmp_model
        #self.criterion = CustomLoss()

        self.criterion = nn.MSELoss()

        optimizer_kwargs = {
            'lr': config.lr,
            'weight_decay': config.weight_decay
        }
        self.sam = config.issam
        self.optimizer = make_optimizer(self.model,
                                        optimizer_kwargs,
                                        optimizer_name=config.optimizer_name,
                                        sam=config.issam)
        self.scheduler_name = config.scheduler_name
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=self.optimizer, T_max=config.T_max)

        self.isswa = getattr(config, 'isswa', False)
        self.swa_start = getattr(config, 'swa_start', 0)

        if self.isswa:
            self.swa_model = AveragedModel(self.model)
            self.swa_scheduler = SWALR(self.optimizer, swa_lr=0.025)

        #self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.optimizer,
        #                                                      mode=config.mode, factor=config.factor)

        self.loss_log = {
            'train_loss': [],
            'train_score': [],
            'valid_loss': [],
            'valid_score': []
        }
Code example #14
def build_swa_model(cfg: CfgNode, model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer):
    # Instead of copying weights during initialization, the SWA model copies
    # the model weights when self.update_parameters is first called.
    # https://github.com/pytorch/pytorch/blob/1.7/torch/optim/swa_utils.py#L107

    # The SWA model needs to be constructed for all processes in distributed
    # training, otherwise the training can get stuck.
    swa_model = AveragedModel(model)
    lr = cfg.SOLVER.BASE_LR
    lr *= cfg.SOLVER.SWA.LR_FACTOR
    swa_scheduler = SWALR(optimizer, swa_lr=lr)
    return swa_model, swa_scheduler
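A short usage sketch for build_swa_model; the iteration loop, the train_step helper, and the SOLVER.SWA.START_ITER field are assumptions layered on top of the config fields the function itself reads:

# cfg is assumed to define SOLVER.BASE_LR and SOLVER.SWA.LR_FACTOR as above
swa_model, swa_scheduler = build_swa_model(cfg, model, optimizer)

for iteration in range(max_iter):                 # hypothetical iteration loop
    train_step(model, optimizer)                  # hypothetical helper
    if iteration >= cfg.SOLVER.SWA.START_ITER:    # hypothetical config field
        swa_model.update_parameters(model)
        swa_scheduler.step()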
Code example #15
    def __init__(self, model, device, config, fold_num):
        self.config = config
        self.epoch = 0
        self.start_epoch = 0
        self.fold_num = fold_num
        if self.config.stage2:
            self.base_dir = f'./result/stage2/{config.dir}/{config.dir}_fold_{config.fold_num}'
        else:
            self.base_dir = f'./result/{config.dir}/{config.dir}_fold_{config.fold_num}'
        os.makedirs(self.base_dir, exist_ok=True)
        self.log_path = f'{self.base_dir}/log.txt'
        self.best_summary_loss = 10**5

        self.model = model
        self.swa_model = AveragedModel(self.model)
        self.device = device
        self.wandb = True

        self.cutmix = self.config.cutmix_ratio
        self.fmix = self.config.fmix_ratio
        self.smix = self.config.smix_ratio

        self.es = EarlyStopping(patience=8)

        self.scaler = GradScaler()
        self.amp = self.config.amp
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.001
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        self.optimizer, self.scheduler = get_optimizer(
            self.model, self.config.optimizer_name,
            self.config.optimizer_params, self.config.scheduler_name,
            self.config.scheduler_params, self.config.n_epochs)

        self.criterion = get_criterion(self.config.criterion_name,
                                       self.config.criterion_params)
        self.log(f'Fitter prepared. Device is {self.device}')
        set_wandb(self.config, fold_num)
Code example #16
    def load_weights(self,
                     network_fn: Optional[Type[nn.Module]] = None) -> None:
        """Load the network weights."""
        logger.debug("Loading network with pretrained weights.")
        filenames = glob(self.weights_filename)
        if not filenames:
            raise FileNotFoundError(
                f"Could not find any pretrained weights at {self.weights_filename}"
            )
        filename = filenames[0]
        # Load the state dict.
        state_dict = torch.load(filename,
                                map_location=torch.device(self._device))
        self._network_args = state_dict["network_args"]
        weights = state_dict["model_state"]

        # Initializes the network with trained weights.
        if network_fn is not None:
            self._network = network_fn(**self._network_args)
        self._network.load_state_dict(weights)

        if "swa_network" in state_dict:
            self._swa_network = AveragedModel(self._network).to(self.device)
            self._swa_network.load_state_dict(state_dict["swa_network"])
Code example #17
def train_model(indep_vars, dep_var, verbose=True):
    """
    Trains MDNVol network. Uses AdamW optimizer with cosine annealing learning rate schedule.
    Outputs the averaged model over the last 25% of training epochs.

    indep_vars: n x m torch tensor containing independent variables
        n = number of data points
        m = number of input variables
    dep_var: n x 1 torch tensor containing single dependent variable
        n = number of data points
        1 = single output variable
    """
    model = MDN(indep_vars.shape[1], 1, 250, 5)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, 100, 2)

    swa_model = AveragedModel(model)
    swa_start = 750
    swa_scheduler = SWALR(optimizer,
                          swa_lr=0.001,
                          anneal_epochs=10,
                          anneal_strategy="cos")

    model.train()
    swa_model.train()
    for epoch in range(1000):
        optimizer.zero_grad()
        output = model(indep_vars)

        loss = -output.log_prob(dep_var).sum()

        loss.backward()
        optimizer.step()

        if epoch > swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            scheduler.step()

        if epoch % 10 == 0:
            if verbose:
                print(f"Epoch {epoch} complete. Loss: {loss}")

    swa_model.eval()
    return swa_model
Code example #18
File: weight_avg.py Project: muzzynine/examples-1
def average_model_weights(checkpoint_path, average_fn):
    checkpoint_files = [
        os.path.join(checkpoint_path, file_name)
        for file_name in os.listdir(checkpoint_path)
        if file_name.endswith(".pt")
    ]

    # initialize averaged model with first checkpoint
    model = load_model(checkpoint_files[0])
    averaged_model = AveragedModel(model, avg_fn=average_fn)

    # loop through all checkpoints and update the averaged model
    for checkpoint in checkpoint_files:
        model = load_model(checkpoint)
        averaged_model.update_parameters(model)

    last_checkpoint = torch.load(checkpoint_files[-1])
    opts = last_checkpoint['opts']
    filename = f'{opts.model}_{opts.data}_{last_checkpoint["epoch"]}_averaged.pt'
    save_path = os.path.join(checkpoint_path, filename)

    if opts.precision[-3:] == ".16":
        model.half()
    else:
        model.float()

    torch.save(
        {
            'epoch': last_checkpoint['epoch'] + 1,
            'model_state_dict': averaged_model.module.state_dict(),
            'loss': 0,  # dummy just to work with validate script
            'train_accuracy': 0,  # dummy just to work with validate script
            'opts': opts
        },
        save_path)

    return averaged_model
Code example #19
    def __init__(self, config):
        self.config = config
        self.device = 'cuda' if cuda.is_available() else 'cpu'
        
        self.model = MLP(config)
        self.swa_model = AveragedModel(self.model)

        self.optimizer = make_optimizer(self.model, optimizer_name=self.config.optimizer, sam=self.config.sam)
        self.scheduler = make_scheduler(self.optimizer, decay_name=self.config.scheduler,
                                        num_training_steps=self.config.num_training_steps,
                                        num_warmup_steps=self.config.num_warmup_steps)
        self.swa_start = self.config.swa_start
        self.swa_scheduler = SWALR(self.optimizer, swa_lr=self.config.swa_lr)
        self.epoch_num = 0
        self.criterion = self.config.criterion
Code example #20
File: vol_mdn.py Project: eadains/SystemsTrading
class MDNVol:
    def __init__(self, x, y):
        """
        Utility class for fitting the mixture density network for forecasting volatility on the SPX.
        Parameters have been tuned for efficiency.
        x: Input data
        y: Output data used for calculating loss function during training
        """
        self.x = torch.Tensor(x.values)
        self.y = torch.Tensor(y.values)
        self.model = MDN(x.shape[1], 1, 250, 5)

    def fit_model(self):
        """
        Fits model. Uses AdamW optimizer, model averaging, and a cosine annealing learning rate schedule.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, 100, 2
        )

        self.swa_model = AveragedModel(self.model)
        swa_start = 750
        swa_scheduler = SWALR(
            optimizer, swa_lr=0.001, anneal_epochs=10, anneal_strategy="cos"
        )

        self.model.train()
        self.swa_model.train()
        for epoch in range(1000):
            optimizer.zero_grad()
            output = self.model(self.x)

            loss = -output.log_prob(self.y.view(-1, 1)).sum()

            loss.backward()
            optimizer.step()

            if epoch > swa_start:
                self.swa_model.update_parameters(self.model)
                swa_scheduler.step()
            else:
                scheduler.step()

            if epoch % 10 == 0:
                print(f"Epoch {epoch} complete. Loss: {loss}")

    def output_dist(self, x):
        """
        Fits model and returns fitted distribution object.

        Returns: PyTorch MixtureNormal object
        """
        self.fit_model()
        self.swa_model.eval()
        return self.swa_model(torch.Tensor(x).view(1, -1))
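For orientation, a hedged usage sketch of MDNVol; the pandas inputs are assumptions inferred from the constructor's use of .values, and the returned object is whatever mixture distribution the MDN produces (it exposes the usual torch.distributions interface, e.g. .mean):

# hypothetical DataFrame of features and Series of targets
vol = MDNVol(x_df, y_series)
dist = vol.output_dist(x_df.iloc[-1].values)  # distribution for the latest observation
print(dist.mean)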
Code example #21
def test_1(model_path, output_dir, test_loader, addNDVI):
    in_channels = 4
    if (addNDVI):
        in_channels += 1
    model = smp.UnetPlusPlus(
        encoder_name="resnet101",
        encoder_weights="imagenet",
        in_channels=in_channels,
        classes=10,
    )
    # model = smp.DeepLabV3Plus(
    #         encoder_name="timm-regnety_320", #resnet101
    #         encoder_weights="imagenet",
    #         in_channels=4,
    #         classes=8,
    # )
    # If the model is an SWA (averaged) model
    if ("swa" in model_path):
        model = AveragedModel(model)
    model.to(DEVICE)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    for image, image_stretch, image_path, ndvi in test_loader:
        with torch.no_grad():
            image = image.cuda()
            image_stretch = image_stretch.cuda()
            output1 = model(image).cpu().data.numpy()
            output2 = model(image_stretch).cpu().data.numpy()
        output = (output1 + output2) / 2.0
        for i in range(output.shape[0]):
            pred = output[i]
            pred = np.argmax(pred, axis=0) + 1
            pred = np.uint8(pred)
            save_path = os.path.join(
                output_dir,
                image_path[i].split('\\')[-1].replace('.tif', '.png'))
            #print(image_path[i][-10:])
            print(save_path)
            cv2.imwrite(save_path, pred)
Code example #22
File: classification_test.py Project: janosh/aviary
def predict(model_class, test_set, checkpoint_path, device, robust):

    assert isfile(
        checkpoint_path), f"no checkpoint found at '{checkpoint_path}'"
    checkpoint = torch.load(checkpoint_path, map_location=device)

    chk_robust = checkpoint["model_params"]["robust"]
    assert (chk_robust == robust
            ), f"checkpoint['robust'] != robust ({chk_robust} vs {robust})"

    model = model_class(**checkpoint["model_params"], device=device)
    model.to(device)
    model.load_state_dict(checkpoint["state_dict"])

    if "swa" in checkpoint.keys():
        model.swa = checkpoint["swa"]

        model_dict = model.swa["model_state_dict"]
        model.swa["model"] = AveragedModel(model)
        model.swa["model"].load_state_dict(model_dict)

    idx, comp, y_test, output = model.predict(test_set)

    df = pd.DataFrame({"idx": idx, "comp": comp, "y_test": y_test})

    if model.robust:
        mean, log_std = output.chunk(2, dim=1)
        pre_logits_std = torch.exp(log_std).cpu().numpy()
        logits = sampled_softmax(mean, log_std, samples=10).cpu().numpy()
        pre_logits = mean.cpu().numpy()
        for idx, std_al in enumerate(pre_logits_std.T):
            df[f"class_{idx}_std_al"] = std_al

    else:
        pre_logits = output.cpu().numpy()
        logits = softmax(pre_logits, axis=1)

    for idx, (logit, pre_logit) in enumerate(zip(logits.T, pre_logits.T)):
        df[f"class_{idx}_logit"] = logit
        df[f"class_{idx}_pred"] = pre_logit

    return df, y_test, logits, pre_logits
Code example #23
def predict(model_class, test_set, checkpoint_path, device, robust):

    assert isfile(
        checkpoint_path), f"no checkpoint found at '{checkpoint_path}'"
    checkpoint = torch.load(checkpoint_path, map_location=device)

    chk_robust = checkpoint["model_params"]["robust"]
    assert (chk_robust == robust
            ), f"checkpoint['robust'] != robust ({chk_robust} vs  {robust})"

    model = model_class(**checkpoint["model_params"], device=device)
    model.to(device)
    model.load_state_dict(checkpoint["state_dict"])

    normalizer = Normalizer()
    normalizer.load_state_dict(checkpoint["normalizer"])

    if "swa" in checkpoint.keys():
        model.swa = checkpoint["swa"]

        model_dict = model.swa["model_state_dict"]
        model.swa["model"] = AveragedModel(model)
        model.swa["model"].load_state_dict(model_dict)

    idx, comp, y_test, output = model.predict(test_set)

    df = pd.DataFrame({"idx": idx, "comp": comp, "y_test": y_test})

    output = output.cpu().squeeze()  # move preds to CPU in case model ran on GPU
    if robust:
        mean, log_std_al = (x.squeeze() for x in output.chunk(2, dim=1))
        df["pred"] = normalizer.denorm(mean).numpy()
        df["std_al"] = (log_std_al.exp() * normalizer.std).numpy()
    else:
        df["pred"] = normalizer.denorm(output).numpy()

    return df
Code example #24
def DecoderTensorWriting(model_weight_path,
                         decoder_img_output_path,
                         image_root_path,
                         imageNames,
                         if_swa=True):
    device = "cuda:1"
    model = EncoderDecoderNet(inChannels=3,
                              encodedDimension=encodedDimension,
                              drop_ratio=0,
                              layersExpandRatio=layersExpandRatio,
                              channelsExpandRatio=channelsExpandRatio,
                              blockExpandRatio=blockExpandRatio,
                              encoderImgHeight=12,
                              encoderImgWidth=52,
                              ch=12,
                              if_add_plate_infor=True).to(device)
    if if_swa:
        model = AveragedModel(model)
    model.load_state_dict(torch.load(model_weight_path))
    model = model.eval()
    transformer = tv.transforms.Compose([tv.transforms.ToTensor()])
    for i, nameD in enumerate(imageNames):
        imgD = Image.open(os.path.join(image_root_path, nameD)).convert("RGB")
        print("Decoder : ", i)
        print(nameD)
        tImg = transformer(imgD).unsqueeze(dim=0).to(device)
        if if_add_plate_information:
            decoderTensor, encoderT = model(
                tImg,
                torch.from_numpy(np.array([img2Plates[nameD]
                                           ])).float().to(device))
        else:
            decoderTensor, encoderT = model(tImg, None)
        #print(encoderT)
        decoder = torch.sigmoid(decoderTensor).detach().cpu().squeeze(dim=0)
        decoderImg = tv.transforms.ToPILImage()(decoder)
        decoderImg.save(os.path.join(decoder_img_output_path, nameD))
Code example #25
def EncoderTensorWriting(model_weight_path,
                         write_path,
                         image_root_path,
                         imageNames,
                         if_swa=True):
    device = "cuda:1"
    model = EncoderDecoderNet(inChannels=3,
                              encodedDimension=encodedDimension,
                              drop_ratio=0,
                              layersExpandRatio=layersExpandRatio,
                              channelsExpandRatio=channelsExpandRatio,
                              blockExpandRatio=blockExpandRatio,
                              encoderImgHeight=12,
                              encoderImgWidth=52,
                              ch=12,
                              if_add_plate_infor=True).to(device)
    if if_swa:
        model = AveragedModel(model)
    model.load_state_dict(torch.load(model_weight_path))
    model = model.eval()
    transformer = tv.transforms.Compose([tv.transforms.ToTensor()])
    for i, nameE in enumerate(imageNames):
        imgE = Image.open(os.path.join(image_root_path, nameE)).convert("RGB")
        print("Encoder : ", i)
        print(nameE)
        tImg = transformer(imgE).unsqueeze(dim=0).to(device)
        if if_add_plate_information:
            _, encoderTensor = model(
                tImg,
                torch.from_numpy(np.array([img2Plates[nameE]
                                           ])).float().to(device))
        else:
            _, encoderTensor = model(tImg, None)
        encoderTensor = encoderTensor.detach().cpu().numpy()
        encoderTensor = np.squeeze(encoderTensor, axis=0)
        np.save(os.path.join(write_path, nameE), encoderTensor)
Code example #26
    trainTransforms = tv.transforms.Compose([
        tv.transforms.RandomHorizontalFlip(p=0.5),
        tv.transforms.RandomVerticalFlip(p=0.5),
        tv.transforms.RandomApply(
            [tv.transforms.RandomCrop(size=randomCropSize)], p=0.5),
        tv.transforms.RandomApply([tv.transforms.RandomRotation(degrees=60)],
                                  p=0.5),
        tv.transforms.Resize(size=inputImageSize),
        tv.transforms.ToTensor(),
        #tv.transforms.RandomErasing(p=0.2, scale=(0.1, 0.15), ratio=(0.1, 1.))
    ])
    testTransforms = tv.transforms.Compose(
        [tv.transforms.Resize(size=inputImageSize),
         tv.transforms.ToTensor()])
    model = tv.models.resnet50(num_classes=3).to(device)
    swa_model = AveragedModel(model)

    if trainOrTest.lower() == "train":
        ### Optimizer
        optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
        cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                epoch,
                                                                eta_min=0,
                                                                last_epoch=-1)
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=multiplier,
                                           total_epoch=warmEpoch,
                                           after_scheduler=cosine_scheduler)
        swa_scheduler = SWALR(optimizer,
                              swa_lr=LR,
                              anneal_epochs=15,
Code example #27
class SWAHook(Hook):
    r"""SWA Object Detection Hook.
        This hook works together with SWA training config files to train
        SWA object detectors <https://arxiv.org/abs/2012.12645>.
        Args:
            swa_eval (bool): Whether to evaluate the swa model.
                Defaults to True.
            eval_hook (Hook): Hook class that contains evaluation functions.
                Defaults to None.
    """
    def __init__(self, swa_eval=True, eval_hook=None):
        if not isinstance(swa_eval, bool):
            raise TypeError('swa_eval must be a bool, but got '
                            f'{type(swa_eval)}')
        if swa_eval:
            if not isinstance(eval_hook, (EvalHook, DistEvalHook)):
                raise TypeError('eval_hook must be either an EvalHook or a '
                                'DistEvalHook when swa_eval = True, but got '
                                f'{type(eval_hook)}')
        self.swa_eval = swa_eval
        self.eval_hook = eval_hook

    def before_run(self, runner):
        """Construct the averaged model which will keep track of the running
        averages of the parameters of the model."""
        model = runner.model
        self.model = AveragedModel(model)

        self.meta = runner.meta

        if self.meta is None:
            self.meta = dict()
            self.meta.setdefault('hook_msgs', dict())

        if 'hook_msgs' not in self.meta:
            self.meta.setdefault('hook_msgs', dict())

    @master_only
    def _save_ckpt(self, model, filepath, meta, runner):
        save_checkpoint(model, filepath, runner.optimizer, meta)

        for i in range(20):
            try:
                ckpt = torch.load(filepath, map_location='cpu')
                runner.logger.info(
                    f'Success Saving swa model at swa-training {runner.epoch + 1} epoch'
                )
                break
            except Exception as e:
                save_checkpoint(model, filepath, runner.optimizer, meta)
                continue

    def after_train_epoch(self, runner):
        """Update the parameters of the averaged model, save and evaluate the
        updated averaged model."""
        model = runner.model
        # update the parameters of the averaged model
        self.model.update_parameters(model)

        # save the swa model
        runner.logger.info(
            f'Saving swa model at swa-training {runner.epoch + 1} epoch')
        filename = 'swa_model_{}.pth'.format(runner.epoch + 1)
        filepath = osp.join(runner.work_dir, filename)
        optimizer = runner.optimizer
        self.meta['hook_msgs']['last_ckpt'] = filepath
        self._save_ckpt(self.model.module, filepath, self.meta, runner)

        # evaluate the swa model
        if self.swa_eval:
            self.work_dir = runner.work_dir
            self.rank = runner.rank
            self.epoch = runner.epoch
            self.logger = runner.logger
            self.log_buffer = runner.log_buffer
            self.meta['hook_msgs']['last_ckpt'] = filename
            self.eval_hook.after_train_epoch(self)

    def after_run(self, runner):
        # since BN layers in the backbone are frozen,
        # we do not need to update the BN for the swa model
        pass

    def before_epoch(self, runner):
        pass
Code example #28
class Network(pl.LightningModule):
    # noinspection PyUnusedLocal
    def __init__(self, blocks, channels, features, pre_act=False,
                 radix=1, groups=1, bottleneck_width=64,
                 activation=nn.SiLU, squeeze_excitation=False,
                 bottleneck=False, bottleneck_expansion=4,
                 beta=0, val_lambda=0.333, lr=1e-2,
                 use_swa=False, swa_lr=1e-2, swa_freq=250):
        super(Network, self).__init__()
        self.save_hyperparameters()

        self.net = PolicyValueNetwork(
            blocks=blocks, channels=channels, features=features,
            pre_act=pre_act, activation=activation,
            squeeze_excitation=squeeze_excitation,
            bottleneck=bottleneck, bottleneck_expansion=bottleneck_expansion,
            radix=radix, groups=groups, bottleneck_width=bottleneck_width
        )
        if use_swa:
            self.swa_model = AveragedModel(self.net)

        self.ce = nn.CrossEntropyLoss(reduction='none')
        self.bce = nn.BCEWithLogitsLoss()

    def load_pretrained_value(self, pretrained_model_path):
        copy_pretrained_value(pretrained_model_path=pretrained_model_path,
                              model=self.net.base)

    def forward(self, x):
        policy, value = self.net(x)
        return policy, value

    def training_step(self, batch, batch_idx):
        x1, x2, t1, t2, z, value = batch
        y1, y2 = self((x1, x2))

        t1 = t1.view(-1)
        weight = torch.where(value == -1.0, 0.0, 1.0)

        loss1 = (self.ce(input=y1, target=t1) * z).mean()
        if self.hparams.beta > 0:
            entropy = F.softmax(y1, dim=1) * F.log_softmax(y1, dim=1)
            loss1 += self.hparams.beta * entropy.sum(dim=1).mean()
        loss2 = self.bce(input=y2, target=t2)
        loss3 = F.binary_cross_entropy_with_logits(
            input=y2, target=value, weight=weight
        )
        loss = (loss1 + (1 - self.hparams.val_lambda) * loss2 +
                self.hparams.val_lambda * loss3)

        self.log_dict({
            'loss': loss, 'loss/1': loss1, 'loss/2': loss2, 'loss/3': loss3,
            'accuracy/1':
                (torch.max(y1, dim=1)[1] == t1).type(torch.float32).mean(),
            'accuracy/2':
                ((y2 >= 0) == (t2 >= 0.5)).type(torch.float32).mean()
        })

        return loss

    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int,
                           dataloader_idx: int) -> None:
        if not self.hparams.use_swa:
            return

        if (batch_idx + 1) % self.hparams.swa_freq == 0:
            self.swa_model.update_parameters(self.net)
            self.swa_scheduler.step()

    def validation_step(self, batch, batch_idx):
        x1, x2, t1, t2, z, value = batch
        y1, y2 = self((x1, x2))

        t1 = t1.view(-1)

        loss1 = (self.ce(input=y1, target=t1) * z).mean()
        loss2 = self.bce(input=y2, target=t2)
        loss3 = self.bce(input=y2, target=value)
        loss = (loss1 + (1 - self.hparams.val_lambda) * loss2 +
                self.hparams.val_lambda * loss3)

        entropy1 = (-F.softmax(y1, dim=1) *
                    F.log_softmax(y1, dim=1)).sum(dim=1)
        p2 = y2.sigmoid()
        log1p_ey2 = F.softplus(y2)
        entropy2 = -(p2 * (y2 - log1p_ey2) + (1 - p2) * -log1p_ey2)

        result = {
            'val_loss': loss, 'val_loss/1': loss1, 'val_loss/2': loss2,
            'val_loss/3': loss3,
            'val_accuracy/1':
                (torch.max(y1, dim=1)[1] == t1).type(torch.float32).mean(),
            'val_accuracy/2':
                ((y2 >= 0) == (t2 >= 0.5)).type(torch.float32).mean(),
            'val_entropy/1': entropy1.mean(), 'val_entropy/2': entropy2.mean()
        }
        self.log_dict(result)

        return result

    def test_step(self, batch, batch_idx):
        x1, x2, t1, t2, z, value = batch
        if self.hparams.use_swa:
            y1, y2 = self.swa_model((x1, x2))
        else:
            y1, y2 = self((x1, x2))
        t1 = t1.view(-1)

        loss1 = (self.ce(input=y1, target=t1) * z).mean()
        loss2 = self.bce(input=y2, target=t2)
        loss3 = self.bce(input=y2, target=value)
        loss = (loss1 + (1 - self.hparams.val_lambda) * loss2 +
                self.hparams.val_lambda * loss3)

        entropy1 = (-F.softmax(y1, dim=1) *
                    F.log_softmax(y1, dim=1)).sum(dim=1)
        p2 = y2.sigmoid()
        log1p_ey2 = F.softplus(y2)
        entropy2 = -(p2 * (y2 - log1p_ey2) + (1 - p2) * -log1p_ey2)

        result = {
            'test_loss': loss, 'test_loss/1': loss1, 'test_loss/2': loss2,
            'test_loss/3': loss3,
            'test_accuracy/1':
                (torch.max(y1, dim=1)[1] == t1).type(torch.float32).mean(),
            'test_accuracy/2':
                ((y2 >= 0) == (t2 >= 0.5)).type(torch.float32).mean(),
            'test_entropy/1': entropy1.mean(),
            'test_entropy/2': entropy2.mean()
        }
        self.log_dict(result)

        return result

    # noinspection PyAttributeOutsideInit
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
        if self.hparams.use_swa:
            self.swa_scheduler = SWALR(
                optimizer, swa_lr=self.hparams.swa_lr,
                anneal_strategy='linear', anneal_epochs=10
            )
        return optimizer
Code example #29
    def __init__(self):

        if args.train is not None:
            self.train_tuple = get_tuple(args.train,
                                         bs=args.batch_size,
                                         shuffle=True,
                                         drop_last=False)

        if args.valid is not None:
            valid_bsize = 2048 if args.multiGPU else 50
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        # Select Model, X is default
        if args.model == "X":
            self.model = ModelX(args)
        elif args.model == "V":
            self.model = ModelV(args)
        elif args.model == "U":
            self.model = ModelU(args)
        elif args.model == "D":
            self.model = ModelD(args)
        elif args.model == 'O':
            self.model = ModelO(args)
        else:
            print(args.model, " is not implemented.")

        # Load pre-trained weights from paths
        if args.loadpre is not None:
            self.model.load(args.loadpre)

        # GPU options
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        self.model = self.model.cuda()

        # Losses and optimizer
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.nllloss = nn.NLLLoss()

        if args.train is not None:
            batch_per_epoch = len(self.train_tuple.loader)
            self.t_total = int(batch_per_epoch * args.epochs // args.acc)
            print("Total Iters: %d" % self.t_total)

        def is_backbone(n):
            if "encoder" in n:
                return True
            elif "embeddings" in n:
                return True
            elif "pooler" in n:
                return True
            print("F: ", n)
            return False

        no_decay = ['bias', 'LayerNorm.weight']

        params = list(self.model.named_parameters())
        if args.reg:
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in params if is_backbone(n)],
                    "lr": args.lr
                },
                {
                    "params": [p for n, p in params if not is_backbone(n)],
                    "lr": args.lr * 500
                },
            ]

            for n, p in self.model.named_parameters():
                print(n)

            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)
        else:
            optimizer_grouped_parameters = [{
                'params':
                [p for n, p in params if not any(nd in n for nd in no_decay)],
                'weight_decay':
                args.wd
            }, {
                'params':
                [p for n, p in params if any(nd in n for nd in no_decay)],
                'weight_decay':
                0.0
            }]

            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)

        if args.train is not None:
            self.scheduler = get_linear_schedule_with_warmup(
                self.optim, self.t_total * 0.1, self.t_total)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

        # SWA Method:
        if args.contrib:
            self.optim = SWA(self.optim,
                             swa_start=self.t_total * 0.75,
                             swa_freq=5,
                             swa_lr=args.lr)

        if args.swa:
            self.swa_model = AveragedModel(self.model)
            self.swa_start = self.t_total * 0.75
            self.swa_scheduler = SWALR(self.optim, swa_lr=args.lr)
Code example #30
class HM:
    def __init__(self):

        if args.train is not None:
            self.train_tuple = get_tuple(args.train,
                                         bs=args.batch_size,
                                         shuffle=True,
                                         drop_last=False)

        if args.valid is not None:
            valid_bsize = 2048 if args.multiGPU else 50
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        # Select Model, X is default
        if args.model == "X":
            self.model = ModelX(args)
        elif args.model == "V":
            self.model = ModelV(args)
        elif args.model == "U":
            self.model = ModelU(args)
        elif args.model == "D":
            self.model = ModelD(args)
        elif args.model == 'O':
            self.model = ModelO(args)
        else:
            print(args.model, " is not implemented.")

        # Load pre-trained weights from paths
        if args.loadpre is not None:
            self.model.load(args.loadpre)

        # GPU options
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        self.model = self.model.cuda()

        # Losses and optimizer
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.nllloss = nn.NLLLoss()

        if args.train is not None:
            batch_per_epoch = len(self.train_tuple.loader)
            self.t_total = int(batch_per_epoch * args.epochs // args.acc)
            print("Total Iters: %d" % self.t_total)

        def is_backbone(n):
            if "encoder" in n:
                return True
            elif "embeddings" in n:
                return True
            elif "pooler" in n:
                return True
            print("F: ", n)
            return False

        no_decay = ['bias', 'LayerNorm.weight']

        params = list(self.model.named_parameters())
        if args.reg:
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in params if is_backbone(n)],
                    "lr": args.lr
                },
                {
                    "params": [p for n, p in params if not is_backbone(n)],
                    "lr": args.lr * 500
                },
            ]

            for n, p in self.model.named_parameters():
                print(n)

            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)
        else:
            optimizer_grouped_parameters = [{
                'params':
                [p for n, p in params if not any(nd in n for nd in no_decay)],
                'weight_decay':
                args.wd
            }, {
                'params':
                [p for n, p in params if any(nd in n for nd in no_decay)],
                'weight_decay':
                0.0
            }]

            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)

        if args.train is not None:
            self.scheduler = get_linear_schedule_with_warmup(
                self.optim, self.t_total * 0.1, self.t_total)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

        # SWA Method:
        if args.contrib:
            self.optim = SWA(self.optim,
                             swa_start=self.t_total * 0.75,
                             swa_freq=5,
                             swa_lr=args.lr)

        if args.swa:
            self.swa_model = AveragedModel(self.model)
            self.swa_start = self.t_total * 0.75
            self.swa_scheduler = SWALR(self.optim, swa_lr=args.lr)

    def train(self, train_tuple, eval_tuple):

        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        print("Batches:", len(loader))

        self.optim.zero_grad()

        best_roc = 0.
        ups = 0

        total_loss = 0.

        for epoch in range(args.epochs):

            if args.reg:
                if args.model != "X":
                    print(self.model.model.layer_weights)

            id2ans = {}
            id2prob = {}

            for i, (ids, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):

                if ups == args.midsave:
                    self.save("MID")

                self.model.train()

                if args.swa:
                    self.swa_model.train()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.long(
                ).cuda()

                # Model expects visual feats as tuple of feats & boxes
                logit = self.model(sent, (feats, boxes))

                # Note: LogSoftmax does not change order, hence there should be nothing wrong with taking it as our prediction
                # In fact ROC AUC stays the exact same for logsoftmax / normal softmax, but logsoftmax is better for loss calculation
                # due to stronger penalization & decomplexifying properties (log(a/b) = log(a) - log(b))
                logit = self.logsoftmax(logit)
                score = logit[:, 1]

                if i < 1:
                    print(logit[0, :].detach())

                # Note: This loss is the same as CrossEntropy (we split it up into logsoftmax & neg. log likelihood loss)
                loss = self.nllloss(logit.view(-1, 2), target.view(-1))

                # Scale the loss by batch size (batches differ in size since we do not "drop_last") and divide by acc for gradient accumulation
                # Not scaling the loss worsens performance by ~2 abs%
                loss = loss * logit.size(0) / args.acc
                loss.backward()

                total_loss += loss.detach().item()

                # Acts as argmax - extracting the higher score & the corresponding index (0 or 1)
                _, predict = logit.detach().max(1)
                # Getting labels for accuracy
                for qid, l in zip(ids, predict.cpu().numpy()):
                    id2ans[qid] = l
                # Getting probabilities for Roc auc
                for qid, l in zip(ids, score.detach().cpu().numpy()):
                    id2prob[qid] = l

                if (i + 1) % args.acc == 0:

                    nn.utils.clip_grad_norm_(self.model.parameters(),
                                             args.clip)

                    self.optim.step()

                    if (args.swa) and (ups > self.swa_start):
                        self.swa_model.update_parameters(self.model)
                        self.swa_scheduler.step()
                    else:
                        self.scheduler.step()
                    self.optim.zero_grad()

                    ups += 1

                    # Do Validation in between
                    if ups % 250 == 0:

                        log_str = "\nEpoch(U) %d(%d): Train AC %0.2f RA %0.4f LOSS %0.4f\n" % (
                            epoch, ups, evaluator.evaluate(id2ans) * 100,
                            evaluator.roc_auc(id2prob) * 100, total_loss)

                        # Set loss back to 0 after printing it
                        total_loss = 0.

                        if self.valid_tuple is not None:  # Do Validation
                            acc, roc_auc = self.evaluate(eval_tuple)
                            if roc_auc > best_roc:
                                best_roc = roc_auc
                                best_acc = acc
                                # Only save BEST when no midsave is specified to save space
                                #if args.midsave < 0:
                                #    self.save("BEST")

                            log_str += "\nEpoch(U) %d(%d): DEV AC %0.2f RA %0.4f \n" % (
                                epoch, ups, acc * 100., roc_auc * 100)
                            log_str += "Epoch(U) %d(%d): BEST AC %0.2f RA %0.4f \n" % (
                                epoch, ups, best_acc * 100., best_roc * 100.)

                        print(log_str, end='')

                        with open(self.output + "/log.log", 'a') as f:
                            f.write(log_str)
                            f.flush()

        if (epoch + 1) == args.epochs:
            if args.contrib:
                self.optim.swap_swa_sgd()

        self.save("LAST" + args.train)

    def predict(self, eval_tuple: DataTuple, dump=None, out_csv=True):

        dset, loader, evaluator = eval_tuple
        id2ans = {}
        id2prob = {}

        for i, datum_tuple in enumerate(loader):

            ids, feats, boxes, sent = datum_tuple[:4]

            self.model.eval()

            if args.swa:
                self.swa_model.eval()

            with torch.no_grad():

                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(sent, (feats, boxes))

                # Note: LogSoftmax does not change order, hence there should be nothing wrong with taking it as our prediction
                logit = self.logsoftmax(logit)
                score = logit[:, 1]

                if args.swa:
                    logit = self.swa_model(sent, (feats, boxes))
                    logit = self.logsoftmax(logit)

                _, predict = logit.max(1)

                for qid, l in zip(ids, predict.cpu().numpy()):
                    id2ans[qid] = l

                # Getting probas for Roc Auc
                for qid, l in zip(ids, score.cpu().numpy()):
                    id2prob[qid] = l

        if dump is not None:
            if out_csv:
                evaluator.dump_csv(id2ans, id2prob, dump)
            else:
                evaluator.dump_result(id2ans, dump)

        return id2ans, id2prob

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        id2ans, id2prob = self.predict(eval_tuple, dump=dump)

        acc = eval_tuple.evaluator.evaluate(id2ans)
        roc_auc = eval_tuple.evaluator.roc_auc(id2prob)

        return acc, roc_auc

    def save(self, name):
        if args.swa:
            torch.save(self.swa_model.state_dict(),
                       os.path.join(self.output, "%s.pth" % name))
        else:
            torch.save(self.model.state_dict(),
                       os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)

        state_dict = torch.load("%s" % path)
        new_state_dict = {}
        for key, value in state_dict.items():
            # n_averaged is a buffer added by AveragedModel that the base model cannot load, so we skip it
            if key.startswith("n_averaged"):
                print("n_averaged:", value)
                continue
            # SWA Models will start with module
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
            else:
                new_state_dict[key] = value
        state_dict = new_state_dict
        self.model.load_state_dict(state_dict)