Example #1
    def check_parity(amp: bool, manual_reduction: bool):

        # The API should be exactly the same between the sharded and non-sharded variants: a generic closure
        def closure(model,
                    scaler,
                    input_tensor,
                    should_accumulate,
                    _manual_reduction=False):
            accumulate_steps = 3 if should_accumulate else 1

            model.zero_grad()

            def step():
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        loss = model(input_tensor).abs().sum()
                        scaler.scale(loss).backward()
                else:
                    loss = model(input_tensor).abs().sum()
                    loss.backward()

            with model.no_sync() if should_accumulate else suppress():
                for _ in range(accumulate_steps - 1):
                    step()

            if not _manual_reduction:
                step()
            else:
                with model.no_sync():
                    step()

                model.reduce()

        # Any model works. Add one different buffer per rank
        model = _get_mlp()
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        # Make sure that the model starts out non-trainable, so that we check that the buckets are
        # properly reassigned when/if this changes
        next(model.parameters()).requires_grad = False

        sharded_optimizer = OSS(params=model.parameters(),
                                optim=torch.optim.SGD,
                                lr=1e-4,
                                momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model,
            sharded_optimizer=sharded_optimizer,
            broadcast_buffers=True,
            reduce_buffer_size=reduce_buffer_size,
            reduce_fp16=fp16_reduction,
        )

        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(),
                                        lr=1e-4,
                                        momentum=0.99)
        ddp_model = DDP(ddp_model_single,
                        device_ids=[rank],
                        broadcast_buffers=True,
                        find_unused_parameters=True)

        if fp16_reduction:
            from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

            ddp_model.register_comm_hook(
                state=None, hook=fp16_compress_hook)  # type: ignore

        ddp_scaler = TorchGradScaler() if amp else None
        sharded_ddp_scaler = ShardedGradScaler() if amp else None

        # The model should be synchronized between the ranks at construction time; check that it is
        check_same_model_params(sharded_ddp_model, ddp_model)

        # Typical training loop; check that we get exactly the same results as DDP
        for i in range(NUMBER_BATCHS):
            input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                return closure(ddp_model, ddp_scaler, input_tensor,
                               grad_accumulation)

            def closure_sharded(input_tensor=input_tensor):
                return closure(
                    sharded_ddp_model,
                    sharded_ddp_scaler,
                    input_tensor,
                    grad_accumulation,
                    _manual_reduction=manual_reduction,
                )

            # Step/scale both
            if ddp_scaler is not None:
                _ = closure_ddp(input_tensor)
                ddp_scaler.step(ddp_optimizer)
                ddp_scaler.update()
            else:
                ddp_optimizer.step(closure=closure_ddp)

            if sharded_ddp_scaler is not None:
                _ = closure_sharded(input_tensor)
                sharded_ddp_scaler.step(sharded_optimizer)
                sharded_ddp_scaler.update()
            else:
                sharded_optimizer.step(closure=closure_sharded)

            check_same_model_params(sharded_ddp_model, ddp_model,
                                    f"Rank: {rank} - Step {i} broke")

            # Flip the trainability of the first parameter back and forth
            if i == 0 and change_train_graph:
                next(sharded_ddp_model.parameters()).requires_grad = not next(
                    sharded_ddp_model.parameters()).requires_grad
                next(ddp_model.parameters()).requires_grad = not next(
                    ddp_model.parameters()).requires_grad
                check_same_model_params(
                    sharded_ddp_model, ddp_model,
                    f"Rank: {rank} - Trainability refresh {i} broke")
Example #2
    def _maybe_init_amp(self):
        if self.fp16 and self.amp_grad_scaler is None and torch.cuda.is_available():
            self.amp_grad_scaler = GradScaler()
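The guard above allocates the scaler lazily, only when fp16 is requested and CUDA is available. An alternative sketch (an assumption, not what this snippet does) is to always construct the scaler and rely on its `enabled` flag, which lets the rest of the loop call `scale()`/`step()`/`update()` unconditionally:

import torch
from torch.cuda.amp import GradScaler

def make_scaler(fp16: bool) -> GradScaler:
    # With enabled=False the scaler's scale/step/update calls become no-ops,
    # so downstream code needs no None checks.
    return GradScaler(enabled=fp16 and torch.cuda.is_available())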
Example #3
def train(model,
          dataloaders: dict,
          criterion,
          optimizer,
          metrics,
          scheduler,
          reconstructor,
          rundir: Union[str, bytes, os.PathLike],
          stopper,
          device: torch.device,
          num_epochs: int = 1000,
          steps_per_epoch: int = 1000,
          steps_per_validation_epoch: int = 1000,
          steps_per_test_epoch: int = 100,
          early_stopping_begins: int = 0,
          max_flow: float = 2.5,
          dali: bool = False,
          fp16: bool = False):
    # check our inputs
    assert isinstance(model, nn.Module)
    assert isinstance(criterion, nn.Module)
    assert isinstance(optimizer, torch.optim.Optimizer)

    scaler = None
    if fp16:
        scaler = GradScaler()
    # loop over number of epochs!
    for epoch in trange(0, num_epochs):
        # if our learning rate scheduler is a "plateau" scheduler (it steps when the validation metric
        # saturates), we have to pass it the "key metric" from our validation set. Else, just step every epoch
        if scheduler.name == 'plateau' and epoch > 0:
            if hasattr(metrics, 'latest_key'):
                if 'val' in list(metrics.latest_key.keys()):
                    scheduler.step(metrics.latest_key['val'])
        elif epoch > 0:
            scheduler.step()
        # update the learning rate for this epoch
        min_lr = utils.get_minimum_learning_rate(optimizer)
        # store the learning rate for this epoch in our metrics file
        # print('min lr: {}'.format(min_lr))
        metrics.update_lr(min_lr)

        # loop over our training set!
        model, metrics, _ = loop_one_epoch(dataloaders['train'],
                                           model,
                                           criterion,
                                           optimizer,
                                           metrics,
                                           reconstructor,
                                           steps_per_epoch,
                                           train_mode=True,
                                           device=device,
                                           dali=dali,
                                           fp16=fp16,
                                           scaler=scaler)

        # evaluate on validation set
        with torch.no_grad():
            model, metrics, examples = loop_one_epoch(
                dataloaders['val'],
                model,
                criterion,
                optimizer,
                metrics,
                reconstructor,
                steps_per_validation_epoch,
                train_mode=False,
                device=device,
                max_flow=max_flow,
                dali=dali,
                fp16=fp16,
                scaler=scaler)

            # some training protocols do not have test sets, so just reuse validation set for testing inference speed
            key = 'test' if 'test' in dataloaders.keys() else 'val'
            loader = dataloaders[key]
            # evaluate how long inference takes, without the loss calculation, which for some models can have a
            # significant speed impact
            metrics = speedtest(loader,
                                model,
                                metrics,
                                steps_per_test_epoch,
                                device=device,
                                dali=dali,
                                fp16=fp16)

        # use our metrics file to output graphs for this epoch
        viz.visualize_logger(metrics.fname, examples)

        # save a checkpoint
        utils.checkpoint(model, rundir, epoch)
        # # update latest models file
        # projects.write_latest_model(config['model'], config['flow_generator'], rundir, config)

        # input the latest validation loss to the early stopper
        if stopper.name == 'early':
            should_stop, _ = stopper(metrics.latest_key['val'])
        elif stopper.name == 'learning_rate':
            should_stop = stopper(min_lr)
        else:
            # every epoch, increment stopper
            should_stop = stopper()

        if should_stop:
            log.info('Stopping criterion reached!')
            break
    return model
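loop_one_epoch is not shown here; a minimal sketch of the per-batch step it presumably performs when fp16 is enabled, with all names (train_step, the argument layout) being assumptions for illustration:

import torch

def train_step(model, images, targets, criterion, optimizer, scaler=None, fp16=False):
    optimizer.zero_grad()
    # autocast runs the forward pass in mixed precision only when requested
    with torch.cuda.amp.autocast(enabled=fp16):
        outputs = model(images)
        loss = criterion(outputs, targets)
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return loss.detach()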
Example #4
def train_fn(train_loader, teacher_model, model, criterion, optimizer, epoch,
             scheduler, device):
    if CFG.device == 'GPU':
        scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    global_step = 0
    for step, (images, images_annot, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        with torch.no_grad():
            teacher_features, _, _ = teacher_model(images_annot.to(device))
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        if CFG.device == 'GPU':
            with autocast():
                features, _, y_preds = model(images)
                loss = criterion(teacher_features, features, y_preds, labels)
                # record loss
                losses.update(loss.item(), batch_size)
                if CFG.gradient_accumulation_steps > 1:
                    loss = loss / CFG.gradient_accumulation_steps
                scaler.scale(loss).backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), CFG.max_grad_norm)
                if (step + 1) % CFG.gradient_accumulation_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    global_step += 1
        elif CFG.device == 'TPU':
            features, _, y_preds = model(images)
            loss = criterion(teacher_features, features, y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.gradient_accumulation_steps > 1:
                loss = loss / CFG.gradient_accumulation_steps
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       CFG.max_grad_norm)
            if (step + 1) % CFG.gradient_accumulation_steps == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
                global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if CFG.device == 'GPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                print('Epoch: [{0}][{1}/{2}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                      'Grad: {grad_norm:.4f}  '
                      #'LR: {lr:.6f}  '
                      .format(
                       epoch+1, step, len(train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses,
                       remain=timeSince(start, float(step+1)/len(train_loader)),
                       grad_norm=grad_norm,
                       #lr=scheduler.get_lr()[0],
                       ))
        elif CFG.device == 'TPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                xm.master_print('Epoch: [{0}][{1}/{2}] '
                                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                                'Elapsed {remain:s} '
                                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                                'Grad: {grad_norm:.4f}  '
                                #'LR: {lr:.6f}  '
                                .format(
                                epoch+1, step, len(train_loader), batch_time=batch_time,
                                data_time=data_time, loss=losses,
                                remain=timeSince(start, float(step+1)/len(train_loader)),
                                grad_norm=grad_norm,
                                #lr=scheduler.get_lr()[0],
                                ))
    return losses.avg
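train_fn above clips gradients while they are still scaled (clip_grad_norm_ is called before scaler.step without unscaling). The PyTorch AMP documentation recommends unscaling first so the clipping threshold applies to the true gradient magnitudes; a hedged sketch of that variant, not the code above:

import torch
from torch.cuda.amp import GradScaler

def clip_and_step(scaler: GradScaler, optimizer, model, max_grad_norm: float):
    # Unscale the gradients held by this optimizer before clipping them.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # step() skips the update if any gradient is inf/NaN; update() adjusts
    # the scale factor for the next iteration.
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()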
Example #5
class ModelTrainer(object):
    def __init__(self, embed_model, optimizer, scheduler, mixedprec, **kwargs):

        self.__model__ = embed_model

        ## Optimizer (e.g. Adam or SGD)
        Optimizer = importlib.import_module(
            'optimizer.' + optimizer).__getattribute__('Optimizer')
        self.__optimizer__ = Optimizer(self.__model__.parameters(), **kwargs)

        ## Learning rate scheduler
        Scheduler = importlib.import_module(
            'scheduler.' + scheduler).__getattribute__('Scheduler')
        self.__scheduler__, self.lr_step = Scheduler(self.__optimizer__,
                                                     **kwargs)

        ## For mixed precision training
        self.scaler = GradScaler()
        self.mixedprec = mixedprec

        assert self.lr_step in ['epoch', 'iteration']

    # ## ===== ===== ===== ===== ===== ===== ===== =====
    # ## Train network
    # ## ===== ===== ===== ===== ===== ===== ===== =====

    def train_network(self, loader, verbose):

        self.__model__.train()

        stepsize = loader.batch_size

        counter = 0
        index = 0
        loss = 0
        top1 = 0  # EER or accuracy

        tstart = time.time()

        for data, label in loader:

            data = data.transpose(1, 0)

            ## Reset gradients
            self.__model__.zero_grad()

            ## Forward and backward passes
            if self.mixedprec:
                with autocast():
                    nloss, prec1 = self.__model__(data.cuda(), label.cuda())
                self.scaler.scale(nloss).backward()
                self.scaler.step(self.__optimizer__)
                self.scaler.update()
            else:
                nloss, prec1 = self.__model__(data.cuda(), label.cuda())
                nloss.backward()
                self.__optimizer__.step()

            loss += nloss.detach().cpu()
            top1 += prec1.detach().cpu()
            counter += 1
            index += stepsize

            telapsed = time.time() - tstart
            tstart = time.time()

            if verbose:
                sys.stdout.write("\rProcessing (%d) " % (index))
                sys.stdout.write(
                    "Loss %f TEER/TAcc %2.3f%% - %.2f Hz " %
                    (loss / counter, top1 / counter, stepsize / telapsed))
                sys.stdout.flush()

            if self.lr_step == 'iteration': self.__scheduler__.step()

        if self.lr_step == 'epoch': self.__scheduler__.step()

        sys.stdout.write("\n")

        return (loss / counter, top1 / counter)

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Evaluate from list
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def evaluateFromList(self,
                         test_list,
                         test_path,
                         nDataLoaderThread,
                         transform,
                         print_interval=100,
                         num_eval=10,
                         **kwargs):

        self.__model__.eval()

        feats = {}
        tstart = time.time()

        ## Read all lines
        with open(test_list) as f:
            lines = f.readlines()

        ## Get a list of unique file names
        files = sum([x.strip().split(',')[-2:] for x in lines], [])
        setfiles = list(set(files))
        setfiles.sort()

        ## Define test data loader
        test_dataset = test_dataset_loader(setfiles,
                                           test_path,
                                           transform=transform,
                                           num_eval=num_eval,
                                           **kwargs)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=nDataLoaderThread,
            drop_last=False,
        )

        ## Extract features for every image
        for idx, data in enumerate(test_loader):
            inp1 = data[0][0].cuda()
            ref_feat = self.__model__(inp1).detach().cpu()
            feats[data[1][0]] = ref_feat
            telapsed = time.time() - tstart

            if idx % print_interval == 0:
                sys.stdout.write(
                    "\rReading %d of %d: %.2f Hz, embedding size %d" %
                    (idx, len(setfiles), idx / telapsed, ref_feat.size()[1]))

        print('')
        all_scores = []
        all_labels = []
        tstart = time.time()

        ## Read files and compute all scores
        for idx, line in enumerate(lines):

            data = line.strip().split(',')

            ref_feat = feats[data[1]]
            com_feat = feats[data[2]]

            score = F.cosine_similarity(ref_feat, com_feat)

            all_scores.append(score)
            all_labels.append(int(data[0]))

            if idx % print_interval == 0:
                telapsed = time.time() - tstart
                sys.stdout.write("\rComputing %d of %d: %.2f Hz" %
                                 (idx, len(lines), idx / telapsed))
                sys.stdout.flush()

        print('')

        return (all_scores, all_labels)

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Save parameters
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def saveParameters(self, path):

        torch.save(self.__model__.state_dict(), path)

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Load parameters
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def loadParameters(self, path):

        self_state = self.__model__.state_dict()
        loaded_state = torch.load(path)
        for name, param in loaded_state.items():
            origname = name
            if name not in self_state:
                print("%s is not in the model." % origname)
                continue

            if self_state[name].size() != loaded_state[origname].size():
                print("Wrong parameter length: %s, model: %s, loaded: %s" %
                      (origname, self_state[name].size(),
                       loaded_state[origname].size()))
                continue

            self_state[name].copy_(param)
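The mixedprec branch in train_network above duplicates the forward pass; a sketch of collapsing the two paths using the enabled flags of autocast and GradScaler (this assumes the scaler was constructed as GradScaler(enabled=mixedprec), which is not exactly what the class does):

import torch
from torch.cuda.amp import GradScaler, autocast

def forward_backward(model, data, label, optimizer, scaler: GradScaler,
                     mixedprec: bool):
    # With enabled=False both autocast and the scaler degrade to the plain
    # fp32 path, so a single code path covers both branches.
    with autocast(enabled=mixedprec):
        nloss, prec1 = model(data.cuda(), label.cuda())
    scaler.scale(nloss).backward()
    scaler.step(optimizer)
    scaler.update()
    return nloss, prec1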
Example #6
def main():
    opt = parse_args()
    opt.sever_name = gethostname()

    # --- CUDA setting ---
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Random Seed setting ---
    if opt.random_seed is None:
        opt.random_seed = random.randint(1, 10000)
    random.seed(opt.random_seed)
    np.random.seed(opt.random_seed)
    torch.manual_seed(opt.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(opt.random_seed)
        cudnn.deterministic = True

    # --- PATH setting ---
    save_result_dir = Path(__file__).parent / "results" / opt.experiment_name
    opt.save_model_dir = str(save_result_dir / "trained_models")
    opt.save_log_path = str(save_result_dir / "train.log")
    mkdirs(opt.save_model_dir)

    # --- Prepare DataLoader ---
    train_loader, valid_loader = get_train_val_loader(opt)
    opt.src_vocab_size = train_loader.dataset.src_vocab_size
    opt.tgt_vocab_size = train_loader.dataset.tgt_vocab_size

    # --- Prepare Model ---
    model = Transformer(
        src_vocab_size=opt.src_vocab_size,
        tgt_vocab_size=opt.tgt_vocab_size,
        max_position_num=opt.max_position_num,
        d_model=opt.d_model,
        head_num=opt.head_num,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_inner=opt.d_inner,
        layer_num=opt.layer_num,
        dropout=opt.dropout,
        shared_embedding=opt.shared_embedding,
        share_dec_input_output_embed=opt.share_dec_input_output_embed,
        init_weight=opt.init_weight,
        fused_layer_norm=opt.use_fused,
    ).to(device)

    # --- Prepare optimizer and scaler ---
    if opt.use_fused:
        from apex.optimizers import FusedAdam as Adam
    else:
        from torch.optim import Adam
    optimizer = Adam(filter(lambda x: x.requires_grad, model.parameters()),
                     betas=(0.9, 0.98),
                     eps=1e-09,
                     weight_decay=opt.weight_decay)
    scaler = GradScaler(init_scale=65536.0, enabled=opt.use_amp)

    # --- Restart setting ---
    start_cnt = 1
    steps_cnt = 0
    if opt.adapt_NMT is not None:
        ex_name, step_cnt = opt.adapt_NMT.split(',')
        saved_path = f"{Path(__file__).parent}/results/{ex_name}/trained_models/step_{step_cnt}.pth"
        saved_dict = torch.load(saved_path,
                                map_location=lambda storage, loc: storage)
        check_arguments(saved_dict["settings"], opt)
        model.load_state_dict(saved_dict["model"])
        print(f"[Info]Loading complete ({saved_path})")

    if opt.restart is not None:
        start_cnt = opt.restart + 1
        if opt.restart < 500:
            model_name = f"epoch_{opt.restart}.pth"
        else:
            model_name = f"step_{opt.restart}.pth"
        saved_path = f"{opt.save_model_dir}/{model_name}"
        saved_dict = torch.load(saved_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(saved_dict["model"])
        optimizer.load_state_dict(saved_dict["optimizer"])
        scaler.load_state_dict(saved_dict["scaler"])
        steps_cnt = saved_dict["steps_cnt"]
        print(f"[Info]Loading complete ({saved_path})")

    scheduler = Scheduler(
        optimizer=optimizer,
        init_lr=0.,
        end_lr=opt.end_lr,
        warmup_steps=opt.warmup_steps,
        current_steps=steps_cnt,
    )

    # --- DataParallel setting ---
    gpus = [i for i in range(len(opt.gpu_ids.split(',')))]
    if len(gpus) > 1:
        model = nn.DataParallel(model, device_ids=gpus)

    # --- Prepare trainer and validator ---
    validator = ScoreCalculator(
        model=model,
        data_loader=valid_loader,
        references=valid_loader.dataset.tgt_insts,
        bpe=opt.bpe,
        cp_avg_num=opt.check_point_average,
    )
    trainer = NMTTrainer(
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        scaler=scaler,
        scheduler=scheduler,
        opt=opt,
        validator=validator,
    )

    # --- Train ---
    if opt.max_epoch is not None:
        trainer.train_by_epoch(start_cnt)
    else:
        trainer.train_by_step(start_cnt)
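The restart logic above checkpoints the scaler's state alongside the model and optimizer so the current loss scale survives a restart. A small self-contained sketch of that bookkeeping, using the same key names that the restart path reads:

import torch

def save_training_state(path, model, optimizer, scaler, steps_cnt):
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),  # preserves the current loss scale
        "steps_cnt": steps_cnt,
    }, path)

def load_training_state(path, model, optimizer, scaler):
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    scaler.load_state_dict(state["scaler"])
    return state["steps_cnt"]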
Example #7
class Trainer:
    _MASTER_ADDR = 'localhost'
    _MASTER_PORT = '12355'

    def __init__(self, experiment_dir, train_dataset_dir, valid_dataset_dir,
                 gpt2_name_or_path, init_weights_from_checkpoint,
                 worker_batch_size, data_shuffle_seed, freeze_n_layers,
                 learning_rate, n_epochs, validate_each_n_steps, warmup_ratio,
                 n_accum_steps):
        self._experiment_dir = Path(experiment_dir)
        self._train_dataset_dir = Path(train_dataset_dir)
        self._valid_dataset_dir = Path(valid_dataset_dir)
        self._gpt2_name_or_path = gpt2_name_or_path
        self._init_weights_from_checkpoint = init_weights_from_checkpoint
        self._worker_batch_size = worker_batch_size
        self._data_shuffle_seed = data_shuffle_seed
        self._freeze_n_layers = freeze_n_layers
        self._learning_rate = learning_rate
        self._n_epochs = n_epochs
        self._validate_each_n_steps = validate_each_n_steps
        self._warmup_ratio = warmup_ratio
        self._n_accum_steps = n_accum_steps or 1

        self._world_size = torch.cuda.device_count()
        self._tokenizer = load_tokenizer(self._train_dataset_dir)

        self._optimizer = None
        self._scaler = None
        self._rank = None
        self._model = None
        self._train_dl = None
        self._valid_dl = None
        self._global_step = None
        self._samples_seen = None
        self._writer = None
        self._model_params = None

        checkpoint_dir = self._experiment_dir / CHECKPOINTS_DIR_NAME
        checkpoint_dir.mkdir(exist_ok=True)
        self._checkpoint_file_path = checkpoint_dir / 'last.ckpt'
        dataset_meta = read_meta(train_dataset_dir)
        with open(self._experiment_dir / 'meta.json', 'w') as file:
            json.dump(dataset_meta, file, indent=2)

    def run(self):
        get_pretrained_gpt2_with_lm_head(self._gpt2_name_or_path)
        mp.spawn(self._train, nprocs=self._world_size, join=True)

    def _save_checkpoint(self):
        checkpoint = {
            'scaler_state_dict': self._scaler.state_dict(),
            'model_state_dict': self._model.state_dict(),
            'optimizer_state_dict': self._optimizer.state_dict(),
            'scheduler_state_dict': self._scheduler.state_dict(),
            'global_step': self._global_step,
            'samples_seen': self._samples_seen,
            'world_size': self._world_size,
            'gpt2_config_dict': self._model.module.gpt2.config.to_dict(),
            'model_params': self._model_params
        }

        torch.save(checkpoint, self._checkpoint_file_path)

    def _load_checkpoint(self):
        checkpoint = torch.load(self._checkpoint_file_path, map_location='cpu')
        checkpoint_world_size = checkpoint['world_size']
        if checkpoint_world_size != self._world_size:
            raise ValueError(
                f'Checkpoint world size {checkpoint_world_size} does not match with the current '
                f'world size {self._world_size}.')

        self._scaler.load_state_dict(checkpoint['scaler_state_dict'])
        self._model.load_state_dict(checkpoint['model_state_dict'])
        self._optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self._scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self._global_step = checkpoint['global_step']
        self._samples_seen = checkpoint['samples_seen']
        self._train_dl = self._get_dataloader(
            is_train=True, samples_offset=self._samples_seen)

    def _load_only_weights_from_checkpoint(self):
        checkpoint = torch.load(self._init_weights_from_checkpoint,
                                map_location='cpu')
        self._model.load_state_dict(checkpoint['model_state_dict'])

    def _train(self, rank):
        _seed_everything(self._data_shuffle_seed)
        self._setup_ddp(rank)
        self._rank = rank
        self._scaler = GradScaler()
        self._model = self._get_model(self._rank)
        self._optimizer = AdamW(params=self._model.parameters(),
                                lr=self._learning_rate)

        self._global_step = 0
        self._samples_seen = 0

        self._train_dl = self._get_dataloader(is_train=True, samples_offset=0)
        self._valid_dl = self._get_dataloader(is_train=False, samples_offset=0)

        steps_per_epoch = len(self._train_dl)
        num_training_steps = steps_per_epoch * self._n_epochs
        num_warmup_steps = self._warmup_ratio * num_training_steps
        self._scheduler = get_linear_schedule_with_warmup(
            optimizer=self._optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps)

        if self._checkpoint_file_path.is_file():
            self._load_checkpoint()
        elif self._init_weights_from_checkpoint:
            self._load_only_weights_from_checkpoint()

        while True:
            if self._rank == 0:
                self._writer = self._writer or SummaryWriter(
                    self._experiment_dir / 'tb_logs')
                self._train_dl = tqdm.tqdm(self._train_dl,
                                           desc='Train step',
                                           total=num_training_steps,
                                           position=1,
                                           initial=self._global_step)

            for i_step, model_input in enumerate(self._train_dl):
                train_losses_dict = self._train_step(model_input)

                if rank == 0:
                    self._train_dl.set_postfix({
                        'samples_seen':
                        self._samples_seen,
                        'epoch':
                        self._global_step / steps_per_epoch
                    })
                    self._write_tb_logs(train_losses_dict)
                    self._write_tb_logs({
                        'learning-rate':
                        self._optimizer.param_groups[0]['lr']
                    })
                    self._write_tb_logs(
                        {'max_seq_len': model_input.input_ids.size()[1]})

                if self._rank == 0 and self._global_step % self._validate_each_n_steps == 0:
                    valid_loss = self._validate()
                    self._save_checkpoint()
                    self._write_tb_logs({'loss/valid': valid_loss})

                if self._global_step >= num_training_steps:
                    break

            # also leave the outer epoch loop once the step budget is reached
            if self._global_step >= num_training_steps:
                break

        dist.destroy_process_group()

    def _write_tb_logs(self, values_dict):
        for tag, val in values_dict.items():
            self._writer.add_scalar(tag=tag,
                                    scalar_value=val,
                                    global_step=self._global_step)

    def _train_step(self, model_input):
        self._model.train()

        with autocast():
            model_output = self._model(model_input)
            loss = model_output.loss / self._n_accum_steps
            lm_loss = model_output.lm_loss / self._n_accum_steps
            cls_loss = model_output.cls_loss / self._n_accum_steps

        self._scaler.scale(loss).backward()

        if self._global_step % self._n_accum_steps == 0:
            self._scaler.step(self._optimizer)
            self._scaler.update()
            self._optimizer.zero_grad()

        dist.all_reduce(loss)
        dist.all_reduce(lm_loss)
        dist.all_reduce(cls_loss)

        loss = (loss.item() / self._world_size) * self._n_accum_steps
        lm_loss = (lm_loss.item() / self._world_size) * self._n_accum_steps
        cls_loss = (cls_loss.item() / self._world_size) * self._n_accum_steps

        samples_seen = torch.tensor(len(model_input.input_ids),
                                    device=self._rank)
        dist.all_reduce(samples_seen)

        self._samples_seen += samples_seen.item()
        self._global_step += 1
        self._scheduler.step()

        losses_dict = {
            'loss/train': loss,
            'lm_loss/train': lm_loss,
            'cls_loss/train': cls_loss
        }

        return losses_dict

    def _setup_ddp(self, rank):
        os.environ['MASTER_ADDR'] = self._MASTER_ADDR
        os.environ['MASTER_PORT'] = self._MASTER_PORT
        dist.init_process_group("nccl", rank=rank, world_size=self._world_size)

    def _get_model(self, rank):
        model_params = {
            'gpt2_name_or_path': self._gpt2_name_or_path,
            'vocab_size': self._tokenizer.vocab_size,
            'n_classes': 2,
            'end_of_speaker_2_token_id':
            self._tokenizer.end_of_speaker_2_token_id,
            'cls_loss_weight': 0.25
        }

        model = DialogModel(**model_params)
        model = model.to(rank)
        model = DistributedDataParallel(model, device_ids=[rank])
        self._model_params = model_params

        return model

    def _get_dataloader(self, is_train, samples_offset):
        return get_dataloader(dataset_dir=self._train_dataset_dir
                              if is_train else self._valid_dataset_dir,
                              distractor_p=0.5,
                              batch_size=self._worker_batch_size,
                              num_workers=2,
                              sort_chunk_size=self._worker_batch_size * 10,
                              samples_offset=samples_offset,
                              data_shuffle_seed=self._data_shuffle_seed,
                              is_distributed=is_train,
                              pad_token_id=self._tokenizer.pad_token_id,
                              end_of_speaker_1_token_id=self._tokenizer.
                              end_of_speaker_1_token_id,
                              end_of_speaker_2_token_id=self._tokenizer.
                              end_of_speaker_2_token_id)

    @torch.no_grad()
    def _validate(self):
        self._model.eval()

        loss = 0
        valid_dl = tqdm.tqdm(self._valid_dl,
                             desc='Valid step',
                             total=len(self._valid_dl),
                             position=2)
        for model_input in valid_dl:
            with autocast():
                model_output = self._model(model_input)
                loss_on_step = model_output.loss
                loss += loss_on_step.item()

        loss /= len(valid_dl)

        return loss
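_train_step above all-reduces each loss and divides by the world size to log a global mean. A minimal sketch of that reduction in isolation (assumes torch.distributed has been initialized):

import torch
import torch.distributed as dist

def mean_across_ranks(value: torch.Tensor) -> float:
    # all_reduce sums in place across ranks; dividing by the world size
    # turns the sum into a mean suitable for logging.
    value = value.detach().clone()
    dist.all_reduce(value)
    return value.item() / dist.get_world_size()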
Example #8
class Trainer(BaseTrainer):
    def __init__(self, config):
        super(Trainer, self).__init__(config)
        self.datamanager = DataManger(config["data"])

        # model
        self.model = Baseline(
            num_classes=self.datamanager.datasource.get_num_classes("train")
        )

        # summary model
        summary(
            self.model,
            input_size=(3, 256, 128),
            batch_size=config["data"]["batch_size"],
            device="cpu",
        )

        # losses
        cfg_losses = config["losses"]
        self.criterion = Softmax_Triplet_loss(
            num_class=self.datamanager.datasource.get_num_classes("train"),
            margin=cfg_losses["margin"],
            epsilon=cfg_losses["epsilon"],
            use_gpu=self.use_gpu,
        )

        self.center_loss = CenterLoss(
            num_classes=self.datamanager.datasource.get_num_classes("train"),
            feature_dim=2048,
            use_gpu=self.use_gpu,
        )

        # optimizer
        cfg_optimizer = config["optimizer"]
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=cfg_optimizer["lr"],
            weight_decay=cfg_optimizer["weight_decay"],
        )

        self.optimizer_centerloss = torch.optim.SGD(
            self.center_loss.parameters(), lr=0.5
        )

        # learning rate scheduler
        cfg_lr_scheduler = config["lr_scheduler"]
        self.lr_scheduler = WarmupMultiStepLR(
            self.optimizer,
            milestones=cfg_lr_scheduler["steps"],
            gamma=cfg_lr_scheduler["gamma"],
            warmup_factor=cfg_lr_scheduler["factor"],
            warmup_iters=cfg_lr_scheduler["iters"],
            warmup_method=cfg_lr_scheduler["method"],
        )

        # track metric
        self.train_metrics = MetricTracker("loss", "accuracy")
        self.valid_metrics = MetricTracker("loss", "accuracy")

        # save best accuracy for function _save_checkpoint
        self.best_accuracy = None

        # send model to device
        self.model.to(self.device)

        self.scaler = GradScaler()

        # resume model from last checkpoint
        if config["resume"] != "":
            self._resume_checkpoint(config["resume"])

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            result = self._valid_epoch(epoch)

            # add scalars to tensorboard
            self.writer.add_scalars(
                "Loss",
                {
                    "Train": self.train_metrics.avg("loss"),
                    "Val": self.valid_metrics.avg("loss"),
                },
                global_step=epoch,
            )
            self.writer.add_scalars(
                "Accuracy",
                {
                    "Train": self.train_metrics.avg("accuracy"),
                    "Val": self.valid_metrics.avg("accuracy"),
                },
                global_step=epoch,
            )

            # logging result to console
            log = {"epoch": epoch}
            log.update(result)
            for key, value in log.items():
                self.logger.info("    {:15s}: {}".format(str(key), value))

            # save model
            if (
                self.best_accuracy is None
                or self.best_accuracy < self.valid_metrics.avg("accuracy")
            ):
                self.best_accuracy = self.valid_metrics.avg("accuracy")
                self._save_checkpoint(epoch, save_best=True)
            else:
                self._save_checkpoint(epoch, save_best=False)

            # save logs
            self._save_logs(epoch)

    def _train_epoch(self, epoch):
        """Training step"""
        self.model.train()
        self.train_metrics.reset()
        with tqdm(total=len(self.datamanager.get_dataloader("train"))) as epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch}")
            for batch_idx, (data, labels, _) in enumerate(
                self.datamanager.get_dataloader("train")
            ):
                # push data to device
                data, labels = data.to(self.device), labels.to(self.device)

                # zero gradient
                self.optimizer.zero_grad()
                self.optimizer_centerloss.zero_grad()

                with autocast():
                    # forward batch
                    score, feat = self.model(data)

                    # calculate loss and accuracy
                    loss = (
                        self.criterion(score, feat, labels)
                        + self.center_loss(feat, labels) * self.config["losses"]["beta"]
                    )
                    _, preds = torch.max(score.data, dim=1)

                # backward parameters
                # loss.backward()
                self.scaler.scale(loss).backward()

                # backward parameters for center_loss
                for param in self.center_loss.parameters():
                    param.grad.data *= 1.0 / self.config["losses"]["beta"]

                # optimize
                # self.optimizer.step()
                self.scaler.step(self.optimizer)
                self.optimizer_centerloss.step()

                self.scaler.update()

                # update loss and accuracy in MetricTracker
                self.train_metrics.update("loss", loss.item())
                self.train_metrics.update(
                    "accuracy",
                    torch.sum(preds == labels.data).double().item() / data.size(0),
                )

                # update progress bar
                epoch_pbar.set_postfix(
                    {
                        "train_loss": self.train_metrics.avg("loss"),
                        "train_acc": self.train_metrics.avg("accuracy"),
                    }
                )
                epoch_pbar.update(1)
        return self.train_metrics.result()

    def _valid_epoch(self, epoch):
        """Validation step"""
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            with tqdm(total=len(self.datamanager.get_dataloader("val"))) as epoch_pbar:
                epoch_pbar.set_description(f"Epoch {epoch}")
                for batch_idx, (data, labels, _) in enumerate(
                    self.datamanager.get_dataloader("val")
                ):
                    # push data to device
                    data, labels = data.to(self.device), labels.to(self.device)

                    with autocast():
                        # forward batch
                        score, feat = self.model(data)

                        # calculate loss and accuracy
                        loss = (
                            self.criterion(score, feat, labels)
                            + self.center_loss(feat, labels)
                            * self.config["losses"]["beta"]
                        )
                        _, preds = torch.max(score.data, dim=1)

                    # update loss and accuracy in MetricTracker
                    self.valid_metrics.update("loss", loss.item())
                    self.valid_metrics.update(
                        "accuracy",
                        torch.sum(preds == labels.data).double().item() / data.size(0),
                    )

                    # update progress bar
                    epoch_pbar.set_postfix(
                        {
                            "val_loss": self.valid_metrics.avg("loss"),
                            "val_acc": self.valid_metrics.avg("accuracy"),
                        }
                    )
                    epoch_pbar.update(1)
        return self.valid_metrics.result()

    def _save_checkpoint(self, epoch, save_best=True):
        """save model to file"""
        state = {
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "center_loss": self.center_loss.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "optimizer_centerloss": self.optimizer_centerloss.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "best_accuracy": self.best_accuracy,
        }
        filename = os.path.join(self.checkpoint_dir, "model_last.pth")
        self.logger.info("Saving last model: model_last.pth ...")
        torch.save(state, filename)
        if save_best:
            filename = os.path.join(self.checkpoint_dir, "model_best.pth")
            self.logger.info("Saving current best: model_best.pth ...")
            torch.save(state, filename)

    def _resume_checkpoint(self, resume_path):
        """Load model from checkpoint"""
        if not os.path.exists(resume_path):
            raise FileNotFoundError("Resume path does not exist!")
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=self.map_location)
        self.start_epoch = checkpoint["epoch"] + 1
        self.model.load_state_dict(checkpoint["state_dict"])
        self.center_loss.load_state_dict(checkpoint["center_loss"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.optimizer_centerloss.load_state_dict(checkpoint["optimizer_centerloss"])
        self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        self.best_accuracy = checkpoint["best_accuracy"]
        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)
        )

    def _save_logs(self, epoch):
        """Save logs from google colab to google drive"""
        if os.path.isdir(self.logs_dir_saved):
            shutil.rmtree(self.logs_dir_saved)
        destination = shutil.copytree(self.logs_dir, self.logs_dir_saved)
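_train_epoch above drives two optimizers from one scaled backward pass but only routes the main optimizer through the scaler. The PyTorch AMP guidance for multiple optimizers is to step each of them through the same scaler and call update() once; a sketch of that alternative with placeholder names, not the trainer's code:

from torch.cuda.amp import GradScaler

def step_both(scaler: GradScaler, loss, main_optimizer, center_optimizer):
    # One scaled backward pass produces gradients for both parameter groups,
    # so each optimizer is stepped through the scaler before a single update().
    scaler.scale(loss).backward()
    scaler.step(main_optimizer)
    scaler.step(center_optimizer)
    scaler.update()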
Example #10
    def __init__(self, opt):
        super(inpaintModel, self).__init__(opt)

        self.counter = 0

        train_opt = opt['train']

        # set if data should be normalized (-1,1) or not (0,1)
        if self.is_train:
            z_norm = opt['datasets']['train'].get('znorm', False)

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
        self.load()  # load G and D if needed
        self.which_model_G = opt['network_G']['which_model_G']

        # define losses, optimizer and scheduler
        if self.is_train:
            """
            Setup network cap
            """
            # define if the generator will have a final capping mechanism in the output
            self.outm = train_opt.get('finalcap', None)
            """
            Setup batch augmentations
            """
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x)
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  #, "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  #, 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  #, 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)
            """
            Setup frequency separation
            """
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)
            """
            Initialize losses
            """
            #Initialize the losses with the opt parameters
            # Generator losses:
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # Discriminator loss:
            if train_opt['gan_type'] and train_opt['gan_weight']:
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  #TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  #original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy)
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False
            """
            Prepare optimizers
            """
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    self.cri_gan, self.netD, self.netG, train_opt, logger,
                    self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                self.optDstep = True
            """
            Prepare schedulers
            """
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            #Keep log in loss class instead?
            self.log_dict = OrderedDict()
            """
            Configure SWA
            """
            #https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy',
                                                    'cos')
                #TODO: Note: This could be done in resume_training() instead, to prevent creating
                # the swa scheduler and model before they are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs,
                    swa_anneal_strategy)
                self.load_swa()  #load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))
            """
            If using virtual batch
            """
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            self.virtual_batch = virtual_batch if virtual_batch is not None \
                and virtual_batch >= batch_size else batch_size
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()
            """
            Configure AMP
            """
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

        # print network
        """
Example #11
def main(local_rank):
    dist.init_process_group(backend='nccl', init_method='env://')
    cfg.local_rank = local_rank
    torch.cuda.set_device(local_rank)
    cfg.rank = dist.get_rank()
    cfg.world_size = dist.get_world_size()
    print("world_size:", cfg.world_size)
    trainset = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoaderX(local_rank=local_rank,
                               dataset=trainset,
                               batch_size=cfg.batch_size,
                               sampler=train_sampler,
                               num_workers=0,
                               pin_memory=True,
                               drop_last=False)

    backbone = backbones.iefficientnet(False, num_features=cfg.embedding_size)
    backbone.train()
    # Broadcast init parameters

    if cfg.preload:
        print("Loading Backbone Checkpoint '{}'".format(cfg.preload))
        loc = 'cuda:{}'.format(cfg.local_rank)
        backbone.load_state_dict(torch.load(cfg.preload, map_location=loc))
    backbone = backbone.to(local_rank)

    backbone.train()
    for ps in backbone.parameters():
        dist.broadcast(ps, 0)

    # DDP
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[cfg.local_rank])
    backbone.train()
    # Memory classifier
    dist_sample_classifer = DistSampleClassifier(rank=dist.get_rank(),
                                                 local_rank=local_rank,
                                                 world_size=cfg.world_size)
    if cfg.pf_preload:
        file_model = os.path.join(cfg.pf_preload,
                                  str(cfg.local_rank) + '_pf.pth')
        print("Loading pf Checkpoint '{}'".format(file_model))
        loc = 'cuda:{}'.format(cfg.local_rank)
        dist_sample_classifer.load_state_dict(
            torch.load(file_model, map_location=loc))
    dist_sample_classifer = dist_sample_classifer.to(local_rank)

    # Margin softmax
    margin_softmax = MarginSoftmax(s=64.0, m=0.4)

    # Optimizer for backbone and classifier
    optimizer = SGD(
        [
            {
                'params': backbone.parameters()
            },
            {
                'params': dist_sample_classifer.parameters(),
                'lr': cfg.lr
                # 'params': dist_sample_classifer.parameters()
            }
        ],
        lr=cfg.lr,
        momentum=cfg.momentum,
        weight_decay=cfg.weight_decay,
        rescale=cfg.world_size)

    # Lr scheduler
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                  lr_lambda=cfg.lr_func)
    n_epochs = cfg.num_epoch
    start_epoch = 0

    if local_rank == 0:
        writer = SummaryWriter(log_dir='logs/shows')

    #
    total_step = int(
        len(trainset) / cfg.batch_size / dist.get_world_size() * cfg.num_epoch)
    if dist.get_rank() == 0:
        print("Total Step is: %d" % total_step)

    losses = AverageMeter()
    global_step = 0
    train_start = time.time()
    if cfg.fp16:
        scaler = GradScaler()

    for epoch in range(start_epoch, n_epochs):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            # total_label, norm_weight = dist_sample_classifer.prepare(
            #     label, optimizer)
            if cfg.fp16:
                with autocast():
                    total_label, norm_weight = dist_sample_classifer.prepare(
                        label, optimizer)
                    features = F.normalize(backbone(img))

                    # Features all-gather
                    total_features = torch.zeros(features.size()[0] *
                                                 cfg.world_size,
                                                 cfg.embedding_size,
                                                 device=local_rank)
                    dist.all_gather(
                        list(total_features.chunk(cfg.world_size, dim=0)),
                        features.data)
                    total_features.requires_grad = True
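                    # dist.all_gather does not track autograd, so total_features is detached
                    # from each rank's local features; its gradient is routed back to them
                    # manually via dist.reduce_scatter further below.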

                    # Calculate logits
                    logits = dist_sample_classifer(total_features, norm_weight)
                    logits = margin_softmax(logits, total_label)

                    with torch.no_grad():
                        max_fc = torch.max(logits, dim=1, keepdim=True)[0]
                        dist.all_reduce(max_fc, dist.ReduceOp.MAX)

                        # Calculate exp(logits) and all-reduce
                        logits_exp = torch.exp(logits - max_fc)
                        logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
                        dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

                        # Calculate prob
                        logits_exp.div_(logits_sum_exp)

                        # Get one-hot
                        grad = logits_exp
                        index = torch.where(total_label != -1)[0]
                        one_hot = torch.zeros(index.size()[0],
                                              grad.size()[1],
                                              device=grad.device)
                        one_hot.scatter_(1, total_label[index, None], 1)

                        # Calculate loss
                        loss = torch.zeros(grad.size()[0],
                                           1,
                                           device=grad.device)
                        loss[index] = grad[index].gather(
                            1, total_label[index, None])
                        dist.all_reduce(loss, dist.ReduceOp.SUM)
                        loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

                        # Calculate grad
                        grad[index] -= one_hot
                        grad.div_(features.size()[0])
            else:
                total_label, norm_weight = dist_sample_classifer.prepare(
                    label, optimizer)
                features = F.normalize(backbone(img))
                # Features all-gather
                total_features = torch.zeros(features.size()[0] *
                                             cfg.world_size,
                                             cfg.embedding_size,
                                             device=local_rank)
                dist.all_gather(
                    list(total_features.chunk(cfg.world_size, dim=0)),
                    features.data)
                total_features.requires_grad = True

                # Calculate logits
                logits = dist_sample_classifer(total_features, norm_weight)
                logits = margin_softmax(logits, total_label)

                with torch.no_grad():
                    max_fc = torch.max(logits, dim=1, keepdim=True)[0]
                    dist.all_reduce(max_fc, dist.ReduceOp.MAX)

                    # Calculate exp(logits) and all-reduce
                    logits_exp = torch.exp(logits - max_fc)
                    logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
                    dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

                    # Calculate prob
                    logits_exp.div_(logits_sum_exp)

                    # Get one-hot
                    grad = logits_exp
                    index = torch.where(total_label != -1)[0]
                    one_hot = torch.zeros(index.size()[0],
                                          grad.size()[1],
                                          device=grad.device)
                    one_hot.scatter_(1, total_label[index, None], 1)

                    # Calculate loss
                    loss = torch.zeros(grad.size()[0], 1, device=grad.device)
                    loss[index] = grad[index].gather(1, total_label[index,
                                                                    None])
                    dist.all_reduce(loss, dist.ReduceOp.SUM)
                    loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

                    # Calculate grad
                    grad[index] -= one_hot
                    grad.div_(features.size()[0])

            if cfg.fp16:
                scaler.scale(logits).backward(grad)
            else:
                logits.backward(grad)
            if total_features.grad is not None:
                total_features.grad.detach_()
            x_grad = torch.zeros_like(features)

            # Reduce-scatter the feature gradients: each rank receives the summed gradient for its own local features
            dist.reduce_scatter(
                x_grad, list(total_features.grad.chunk(cfg.world_size, dim=0)))
            x_grad.mul_(cfg.world_size)
            # Backward backbone
            if cfg.fp16:
                scaler.scale(features).backward(x_grad)
                scaler.step(optimizer)

            else:
                features.backward(x_grad)
                optimizer.step()

            # Update classifier
            dist_sample_classifer.update()
            optimizer.zero_grad()
            # losses.update(loss_v, 1)
            # if cfg.fp16:
            #     scaler.update()

            losses.update(loss_v, 1)
            if cfg.local_rank == 0 and step % 50 == 0:
                time_now = (time.time() - train_start) / 3600
                time_total = time_now / ((global_step + 1) / total_step)
                time_for_end = time_total - time_now
                writer.add_scalar('time_for_end', time_for_end, global_step)
                writer.add_scalar('loss', loss_v, global_step)
                writer.add_scalar('loss_epoch', loss_v, scheduler.get_lr()[0])
                # writer.add_scalar('loss_lr', loss_v, scheduler.get_lr()[0])
                print(
                    "Speed %d samples/sec   Loss %.4f   Epoch: %d   Global Step: %d lr: %0.6f   Required: %1.f hours"
                    % ((cfg.batch_size * global_step /
                        (time.time() - train_start) * cfg.world_size),
                       losses.avg, epoch, global_step, scheduler.get_lr()[0],
                       time_for_end))
                losses.reset()
            if cfg.fp16:
                scaler.update()

            global_step += 1
        scheduler.step()
        if dist.get_rank() == 0:
            if not os.path.exists(cfg.output):
                os.makedirs(cfg.output)
            torch.save(backbone.module.state_dict(),
                       os.path.join(cfg.output,
                                    str(epoch) + 'backbone.pth'))
        # Save this rank's partial-fc classifier weights
        if not os.path.exists(cfg.output):
            os.makedirs(cfg.output)
        torch.save(
            dist_sample_classifer.state_dict(),
            os.path.join(cfg.output,
                         str(dist.get_rank()) + '_pf.pth'))

    dist.destroy_process_group()
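
# --- Reference sketch (not part of the example above) ------------------------
# The training loop above drives the backward pass manually, calling
# logits.backward(grad) and features.backward(x_grad) with explicitly computed
# gradients, so GradScaler is only used for scaling and for the step/update
# bookkeeping. For comparison, this is a minimal sketch of the canonical
# torch.cuda.amp recipe; `model`, `optimizer`, `loader` and `device` are
# placeholder names, not objects defined in the example.
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast


def amp_reference_loop(model, optimizer, loader, device, fp16=True):
    scaler = GradScaler(enabled=fp16)
    model.train()
    for img, label in loader:
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        with autocast(enabled=fp16):
            loss = F.cross_entropy(model(img), label)
        scaler.scale(loss).backward()  # backward on the scaled loss
        scaler.step(optimizer)  # unscales, then steps only if grads are finite
        scaler.update()  # adjust the loss scale for the next iteration
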
Exemplo n.º 12
0
class inpaintModel(BaseModel):
    def __init__(self, opt):
        super(inpaintModel, self).__init__(opt)

        self.counter = 0

        train_opt = opt['train']

        # set if data should be normalized (-1,1) or not (0,1)
        if self.is_train:
            z_norm = opt['datasets']['train'].get('znorm', False)

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
        self.load()  # load G and D if needed
        self.which_model_G = opt['network_G']['which_model_G']

        # define losses, optimizer and scheduler
        if self.is_train:
            """
            Setup network cap
            """
            # define if the generator will have a final capping mechanism in the output
            self.outm = train_opt.get('finalcap', None)
            """
            Setup batch augmentations
            """
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x)
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  #, "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  #, 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  #, 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)
            """
            Setup frequency separation
            """
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)
            """
            Initialize losses
            """
            #Initialize the losses with the opt parameters
            # Generator losses:
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # Discriminator loss:
            if train_opt['gan_type'] and train_opt['gan_weight']:
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  #TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  #original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy)
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False
            """
            Prepare optimizers
            """
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    self.cri_gan, self.netD, self.netG, train_opt, logger,
                    self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                self.optDstep = True
            """
            Prepare schedulers
            """
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            #Keep log in loss class instead?
            self.log_dict = OrderedDict()
            """
            Configure SWA
            """
            #https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy',
                                                    'cos')
                #TODO: Note: This could be done in resume_training() instead, to prevent creating
                # the swa scheduler and model before they are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs,
                    swa_anneal_strategy)
                self.load_swa()  #load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))
            """
            If using virtual batch
            """
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            self.virtual_batch = virtual_batch if virtual_batch \
                and virtual_batch >= batch_size else batch_size
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()
            """
            Configure AMP
            """
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

        # print network
        """
        TODO:
        Network summary? Make optional with parameter
            could be a selector between traditional print_network() and summary()
        """
        #self.print_network() #TODO

    #https://github.com/Yukariin/DFNet/blob/master/data.py
    def random_mask(self,
                    height=256,
                    width=256,
                    min_stroke=1,
                    max_stroke=4,
                    min_vertex=1,
                    max_vertex=12,
                    min_brush_width_divisor=16,
                    max_brush_width_divisor=10):

        mask = np.ones((height, width))

        min_brush_width = height // min_brush_width_divisor
        max_brush_width = height // max_brush_width_divisor
        max_angle = 2 * np.pi
        num_stroke = np.random.randint(min_stroke, max_stroke + 1)
        average_length = np.sqrt(height * height + width * width) / 8

        for _ in range(num_stroke):
            num_vertex = np.random.randint(min_vertex, max_vertex + 1)
            start_x = np.random.randint(width)
            start_y = np.random.randint(height)

            for _ in range(num_vertex):
                angle = np.random.uniform(0, max_angle)
                length = np.clip(
                    np.random.normal(average_length, average_length // 2), 0,
                    2 * average_length)
                brush_width = np.random.randint(min_brush_width,
                                                max_brush_width + 1)
                end_x = (start_x + length * np.sin(angle)).astype(np.int32)
                end_y = (start_y + length * np.cos(angle)).astype(np.int32)

                cv2.line(mask, (start_y, start_x), (end_y, end_x), 0.,
                         brush_width)

                start_x, start_y = end_x, end_y
        if np.random.random() < 0.5:
            mask = np.fliplr(mask)
        if np.random.random() < 0.5:
            mask = np.flipud(mask)
        return torch.from_numpy(
            mask.reshape((1, ) + mask.shape).astype(np.float32)).unsqueeze(0)

    def masking_images(self):
        mask = self.random_mask(height=self.var_L.shape[2],
                                width=self.var_L.shape[3]).cuda()
        for i in range(self.var_L.shape[0] - 1):
            mask = torch.cat([
                mask,
                self.random_mask(height=self.var_L.shape[2],
                                 width=self.var_L.shape[3]).cuda()
            ],
                             dim=0)

        #self.var_L=self.var_L * mask
        return self.var_L * mask, mask

    def masking_images_with_invert(self):
        mask = self.random_mask(height=self.var_L.shape[2],
                                width=self.var_L.shape[3]).cuda()
        for i in range(self.var_L.shape[0] - 1):
            mask = torch.cat([
                mask,
                self.random_mask(height=self.var_L.shape[2],
                                 width=self.var_L.shape[3]).cuda()
            ],
                             dim=0)

        #self.var_L=self.var_L * mask
        return self.var_L * mask, self.var_L * (1 - mask), mask

    def feed_data(self, data, need_HR=True):
        # LR images
        if self.which_model_G == 'EdgeConnect' or self.which_model_G == 'PRVS':
            self.var_L = data['LR'].to(self.device)
            self.canny_data = data['img_HR_canny'].to(self.device)
            self.grayscale_data = data['img_HR_gray'].to(self.device)
            #self.mask = data['green_mask'].to(self.device)
        else:
            self.var_L = data['LR'].to(self.device)

        if need_HR:  # train or val
            # HR images
            self.var_H = data['HR'].to(self.device)
            # discriminator references
            input_ref = data.get('ref', data['HR'])
            self.var_ref = input_ref.to(self.device)

    def feed_data_batch(self, data, need_HR=True):
        # LR
        self.var_L = data

    def optimize_parameters(self, step):
        # G
        # freeze discriminator while generator is trained to prevent BP
        if self.cri_gan:
            for p in self.netD.parameters():
                p.requires_grad = False

        # batch (mixup) augmentations
        aug = None
        if self.mixup:
            self.var_H, self.var_L, mask, aug = BatchAug(
                self.var_H, self.var_L, self.mixopts, self.mixprob,
                self.mixalpha, self.aux_mixprob, self.aux_mixalpha, self.mix_p)

        if self.which_model_G == 'Pluralistic':
            # pluralistic needs the inpainted area as an image and not only the cut-out
            self.var_L, img_inverted, mask = self.masking_images_with_invert()
        else:
            self.var_L, mask = self.masking_images()

        ### Network forward, generate inpainted fake
        with self.cast():
            # normal
            if self.which_model_G in ('AdaFill', 'MEDFE', 'RFR', 'LBAM', 'DMFN',
                                      'partial', 'Adaptive', 'DFNet', 'RN'):
                self.fake_H = self.netG(self.var_L, mask)
            # 2 rgb images
            elif self.which_model_G in ('CRA', 'pennet', 'deepfillv1',
                                        'deepfillv2', 'Global', 'crfill',
                                        'DeepDFNet'):
                self.fake_H, self.other_img = self.netG(self.var_L, mask)

            # special
            elif self.which_model_G == 'Pluralistic':
                self.fake_H, self.kl_rec, self.kl_g = self.netG(
                    self.var_L, img_inverted, mask)
                save_image(self.fake_H, "self.fake_H_pluralistic.png")

            elif self.which_model_G == 'EdgeConnect':
                self.fake_H, self.other_img = self.netG(
                    self.var_L, self.canny_data, self.grayscale_data, mask)

            elif self.which_model_G == 'FRRN':
                self.fake_H, mid_x, mid_mask = self.netG(self.var_L, mask)

            elif self.which_model_G == 'PRVS':
                self.fake_H, _, edge_small, edge_big = self.netG(
                    self.var_L, mask, self.canny_data)

            elif self.which_model_G == 'CSA':
                #out_c, out_r, csa, csa_d
                coarse_result, self.fake_H, csa, csa_d = self.netG(
                    self.var_L, mask)

            elif self.which_model_G == 'atrous':
                self.fake_H = self.netG(self.var_L)

            else:
                print("Selected model is not implemented.")

        # Merge inpainted data with original data in masked region
        self.fake_H = self.var_L * mask + self.fake_H * (1 - mask)
        #save_image(self.fake_H, 'self_fake_H.png')

        #/with self.cast():
        #self.fake_H = self.netG(self.var_L, mask)

        #self.counter += 1
        #save_image(mask, str(self.counter)+'mask_train.png')
        #save_image(self.fake_H, str(self.counter)+'fake_H_train.png')

        # batch (mixup) augmentations
        # cutout-ed pixels are discarded when calculating loss by masking removed pixels
        if aug == "cutout":
            self.fake_H, self.var_H = self.fake_H * mask, self.var_H * mask

        l_g_total = 0
        """
        Calculate and log losses
        """
        loss_results = []
        # training generator and discriminator
        # update generator (on its own if only training generator or alternatively if training GAN)
        if (self.cri_gan is not True) or (step % self.D_update_ratio == 0
                                          and step > self.D_init_iters):
            # Cast operations to mixed precision if enabled, else run in a null context
            with self.cast():
                # regular losses
                loss_results, self.log_dict = self.generatorlosses(
                    self.fake_H, self.var_H, self.log_dict, self.f_low)

                # additional losses, in case a model does output more than a normal image
                ###############################
                # deepfillv2 / global / crfill / CRA
                if self.which_model_G in ('deepfillv2', 'Global', 'crfill',
                                          'CRA'):
                    L1Loss = nn.L1Loss()
                    l1_stage1 = L1Loss(self.other_img, self.var_H)

                    self.log_dict.update(l1_stage1=l1_stage1)
                    loss_results.append(l1_stage1)

                # edge-connect
                if self.which_model_G == 'EdgeConnect':
                    L1Loss = nn.L1Loss()
                    l1_edge = L1Loss(self.other_img, self.canny_data)

                    self.log_dict.update(l1_edge=l1_edge)
                    loss_results.append(l1_edge)
                ###############################
                # csa
                if self.which_model_G == 'CSA':
                    #coarse_result, refine_result, csa, csa_d = g_model(masked, mask)
                    L1Loss = nn.L1Loss()
                    recon_loss = L1Loss(coarse_result, self.var_H) + L1Loss(
                        self.fake_H, self.var_H)

                    from models.modules.csa_loss import ConsistencyLoss
                    cons = ConsistencyLoss()

                    cons_loss = cons(csa, csa_d, self.var_H, mask)

                    self.log_dict.update(recon_loss=recon_loss)
                    loss_results.append(recon_loss)
                    self.log_dict.update(cons_loss=cons_loss)
                    loss_results.append(cons_loss)
                ###############################
                # pluralistic (encoder kl loss)
                if self.which_model_G == 'Pluralistic':
                    loss_kl_rec = self.kl_rec.mean()
                    loss_kl_g = self.kl_g.mean()

                    self.log_dict.update(loss_kl_rec=loss_kl_rec)
                    loss_results.append(loss_kl_rec)
                    self.log_dict.update(loss_kl_g=loss_kl_g)
                    loss_results.append(loss_kl_g)
                ###############################
                # deepfillv1
                if self.which_model_G == 'deepfillv1':
                    from models.modules.deepfillv1_loss import ReconLoss
                    ReconLoss_ = ReconLoss(1, 1, 1, 1)
                    reconstruction_loss = ReconLoss_(self.var_H,
                                                     self.other_img,
                                                     self.fake_H, mask)

                    self.log_dict.update(
                        reconstruction_loss=reconstruction_loss)
                    loss_results.append(reconstruction_loss)
                ###############################
                # pennet
                if self.which_model_G == 'pennet':
                    L1Loss = nn.L1Loss()
                    if self.other_img is not None:
                        pyramid_loss = 0
                        for _, f in enumerate(self.other_img):
                            pyramid_loss += L1Loss(
                                f,
                                torch.nn.functional.interpolate(
                                    self.var_H,
                                    size=f.size()[2:4],
                                    mode='bilinear',
                                    align_corners=True))

                        self.log_dict.update(pyramid_loss=pyramid_loss)
                        loss_results.append(pyramid_loss)
                ###############################
                # FRRN
                if self.which_model_G == 'FRRN':
                    L1Loss = nn.L1Loss()
                    # generator step loss, summed over the intermediate steps
                    mid_l1_loss = 0
                    for idx in range(len(mid_x) - 1):
                        mid_l1_loss += L1Loss(mid_x[idx] * mid_mask[idx],
                                              self.var_H * mid_mask[idx])

                    self.log_dict.update(mid_l1_loss=mid_l1_loss)
                    loss_results.append(mid_l1_loss)
                ###############################
                # PRVS
                if self.which_model_G == 'PRVS':
                    L1Loss = nn.L1Loss()
                    #from models.modules.PRVS_loss import edge_loss
                    #[edge_small, edge_big]
                    #adv_loss_0 = self.edge_loss(fake_edge[1], real_edge)
                    #dv_loss_1 = self.edge_loss(fake_edge[0], F.interpolate(real_edge, scale_factor = 0.5))

                    #adv_loss_0 = edge_loss(self, edge_big, self.canny_data, self.grayscale_data)
                    #adv_loss_1 = edge_loss(self, edge_small, torch.nn.functional.interpolate(self.canny_data, scale_factor = 0.5))

                    # l1 instead of discriminator loss
                    edge_big_l1 = L1Loss(edge_big, self.canny_data)
                    edge_small_l1 = L1Loss(
                        edge_small,
                        torch.nn.functional.interpolate(self.canny_data,
                                                        scale_factor=0.5))

                    self.log_dict.update(edge_big_l1=edge_big_l1)
                    loss_results.append(edge_big_l1)
                    self.log_dict.update(edge_small_l1=edge_small_l1)
                    loss_results.append(edge_small_l1)
                ###############################

                #for key, value in self.log_dict.items():
                #    print(key, value)

                l_g_total += sum(loss_results) / self.accumulations

                if self.cri_gan:
                    # adversarial loss
                    l_g_gan = self.adversarial(
                        self.fake_H,
                        self.var_ref,
                        netD=self.netD,
                        stage='generator',
                        fsfilter=self.f_high)  # (sr, hr)
                    self.log_dict['l_g_gan'] = l_g_gan.item()
                    l_g_total += l_g_gan / self.accumulations

            #/with self.cast():

            if self.amp:
                # call backward() on scaled loss to create scaled gradients.
                self.amp_scaler.scale(l_g_total).backward()
            else:
                l_g_total.backward()

            # only step and clear gradient if virtual batch has completed
            if (step + 1) % self.accumulations == 0:
                if self.amp:
                    # unscale gradients of the optimizer's params, call
                    # optimizer.step() if no infs/NaNs in gradients, else, skipped
                    self.amp_scaler.step(self.optimizer_G)
                    # Update GradScaler scale for next iteration.
                    self.amp_scaler.update()
                    #TODO: remove. for debugging AMP
                    #print("AMP Scaler state dict: ", self.amp_scaler.state_dict())
                else:
                    self.optimizer_G.step()
                self.optimizer_G.zero_grad()
                self.optGstep = True

        if self.cri_gan:
            # update discriminator
            # unfreeze discriminator
            for p in self.netD.parameters():
                p.requires_grad = True
            l_d_total = 0

            # Cast operations to mixed precision if enabled, else run in a null context
            with self.cast():
                l_d_total, gan_logs = self.adversarial(
                    self.fake_H,
                    self.var_ref,
                    netD=self.netD,
                    stage='discriminator',
                    fsfilter=self.f_high)  # (sr, hr)

                for g_log in gan_logs:
                    self.log_dict[g_log] = gan_logs[g_log]

                l_d_total /= self.accumulations
            #/with autocast():

            if self.amp:
                # call backward() on scaled loss to create scaled gradients.
                self.amp_scaler.scale(l_d_total).backward()
            else:
                l_d_total.backward()

            # only step and clear gradient if virtual batch has completed
            if (step + 1) % self.accumulations == 0:
                if self.amp:
                    # unscale gradients of the optimizer's params, call
                    # optimizer.step() if no infs/NaNs in gradients, else, skipped
                    self.amp_scaler.step(self.optimizer_D)
                    # Update GradScaler scale for next iteration.
                    self.amp_scaler.update()
                else:
                    self.optimizer_D.step()
                self.optimizer_D.zero_grad()
                self.optDstep = True

    def test(self, data):
        """
        # generating random mask for validation
        self.var_L, mask = self.masking_images()
        if self.which_model_G == 'Pluralistic':
          # pluralistic needs the inpainted area as an image and not only the cut-out
          self.var_L, img_inverted, mask = self.masking_images_with_invert()
        else:
          self.var_L, mask = self.masking_images()
        """

        self.mask = data['green_mask'].float().to(self.device).unsqueeze(0)

        if self.which_model_G == 'Pluralistic':
            img_inverted = self.var_L * (1 - self.mask)

        self.var_L = self.var_L * self.mask

        self.netG.eval()
        with torch.no_grad():
            if self.is_train:
                # normal
                if self.which_model_G in ('AdaFill', 'MEDFE', 'RFR', 'LBAM',
                                          'DMFN', 'partial', 'Adaptive',
                                          'DFNet', 'RN'):
                    self.fake_H = self.netG(self.var_L, self.mask)
                # 2 rgb images
                elif self.which_model_G in ('CRA', 'pennet', 'deepfillv1',
                                            'deepfillv2', 'Global', 'crfill',
                                            'DeepDFNet'):
                    self.fake_H, _ = self.netG(self.var_L, self.mask)

                # special
                elif self.which_model_G == 'Pluralistic':
                    self.fake_H, _, _ = self.netG(self.var_L, img_inverted,
                                                  self.mask)

                elif self.which_model_G == 'EdgeConnect':
                    self.fake_H, _ = self.netG(self.var_L, self.canny_data,
                                               self.grayscale_data, self.mask)

                elif self.which_model_G == 'FRRN':
                    self.fake_H, _, _ = self.netG(self.var_L, self.mask)

                elif self.which_model_G == 'PRVS':
                    self.fake_H, _, _, _ = self.netG(self.var_L, self.mask,
                                                     self.canny_data)

                elif self.which_model_G == 'CSA':
                    _, self.fake_H, _, _ = self.netG(self.var_L, self.mask)

                elif self.which_model_G == 'atrous':
                    self.fake_H = self.netG(self.var_L)
                else:
                    print("Selected model is not implemented.")

        # Merge inpainted data with original data in masked region
        self.fake_H = self.var_L * self.mask + self.fake_H * (1 - self.mask)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.detach()[0].float().cpu()
        #TODO for PPON ?
        #if get stages 1 and 2
        #out_dict['SR_content'] = ...
        #out_dict['SR_structure'] = ...
        return out_dict

    def get_current_visuals_batch(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach().float().cpu()
        out_dict['SR'] = self.fake_H.detach().float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.detach().float().cpu()
        #TODO for PPON ?
        #if get stages 1 and 2
        #out_dict['SR_content'] = ...
        #out_dict['SR_structure'] = ...
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)
        if self.is_train:
            # Discriminator
            if self.cri_gan:
                s, n = self.get_network_description(self.netD)
                if isinstance(self.netD, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netD.__class__.__name__,
                        self.netD.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netD.__class__.__name__)

                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            #TODO: feature network is not being trained, is it necessary to visualize? Maybe just name?
            # maybe show the generatorlosses instead?
            '''
            if self.generatorlosses.cri_fea:  # F, Perceptual Network
                #s, n = self.get_network_description(self.netF)
                s, n = self.get_network_description(self.generatorlosses.netF) #TODO
                #s, n = self.get_network_description(self.generatorlosses.loss_list.netF) #TODO
                if isinstance(self.generatorlosses.netF, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(self.generatorlosses.netF.__class__.__name__,
                                                    self.generatorlosses.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.generatorlosses.netF.__class__.__name__)

                logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
                logger.info(s)
            '''

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading pretrained model for G [{:s}] ...'.format(
                load_path_G))
            strict = self.opt['path'].get('strict', None)
            self.load_network(load_path_G, self.netG, strict)
        if self.opt['is_train'] and self.opt['train']['gan_weight']:
            load_path_D = self.opt['path']['pretrain_model_D']
            if self.opt['is_train'] and load_path_D is not None:
                logger.info('Loading pretrained model for D [{:s}] ...'.format(
                    load_path_D))
                self.load_network(load_path_D, self.netD)

    def load_swa(self):
        if self.opt['is_train'] and self.opt['use_swa']:
            load_path_swaG = self.opt['path']['pretrain_model_swaG']
            if self.opt['is_train'] and load_path_swaG is not None:
                logger.info(
                    'Loading pretrained model for SWA G [{:s}] ...'.format(
                        load_path_swaG))
                self.load_network(load_path_swaG, self.swa_model)

    def save(self, iter_step, latest=None, loader=None):
        self.save_network(self.netG, 'G', iter_step, latest)
        if self.cri_gan:
            self.save_network(self.netD, 'D', iter_step, latest)
        if self.swa:
            # when training with networks that use BN
            # # Update bn statistics for the swa_model only at the end of training
            # if not isinstance(iter_step, int): #TODO: not sure if it should be done only at the end
            self.swa_model = self.swa_model.cpu()
            torch.optim.swa_utils.update_bn(loader, self.swa_model)
            self.swa_model = self.swa_model.cuda()
            # Check swa BN statistics
            # for module in self.swa_model.modules():
            #     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            #         print(module.running_mean)
            #         print(module.running_var)
            #         print(module.momentum)
            #         break
            self.save_network(self.swa_model, 'swaG', iter_step, latest)
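
# --- Reference sketch (not part of the class above) ---------------------------
# optimize_parameters() implements a "virtual batch": each loss is divided by
# self.accumulations and the optimizer/GradScaler only step once every
# `accumulations` iterations. A minimal standalone sketch of that pattern
# follows; `model`, `optimizer`, `loader` and `criterion` are placeholder
# names, not objects defined in this file.
import torch
from torch.cuda.amp import GradScaler, autocast


def accumulation_reference_loop(model, optimizer, loader, criterion,
                                accumulations=4, use_amp=True):
    scaler = GradScaler(enabled=use_amp)
    optimizer.zero_grad()
    for step, (lr_img, hr_img) in enumerate(loader):
        with autocast(enabled=use_amp):
            loss = criterion(model(lr_img), hr_img) / accumulations
        scaler.scale(loss).backward()  # gradients accumulate across micro-batches
        if (step + 1) % accumulations == 0:  # virtual batch completed
            scaler.step(optimizer)  # skipped internally if grads contain inf/NaN
            scaler.update()
            optimizer.zero_grad()
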
Exemplo n.º 13
0
class BaseModel(ConfigModel):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.nb_tasks = len(opt.num_classes)
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain

        self.save_dir = opt.checkpoints_dir  # save all the checkpoints to save_dir

        self.loss_names = []
        self.net_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

        self.target, self.image = None, None
        self.output = None

        self.max_step = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """

        return parser

    @staticmethod
    def default_value(opt):
        return opt

    @exec_times(1)
    def setup(self, task_index):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.task_index = task_index
        logging.info("BaseModel set up")
        if self.isTrain:
            self.schedulers = [
                get_scheduler(optimizer, self.opt)
                for optimizer in self.optimizers
            ]

        if not self.isTrain or self.opt.continue_train:
            self.load_networks(self.opt.load_taskindex, self.opt.load_step,
                               self.opt.load_epoch)

        if self.opt.amp:
            self.scaler = GradScaler()
        self.print_networks()

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.net_names:
            if isinstance(name, str):
                net = getattr(self, "net_" + name)
                net.eval()

    def fit(self):
        """Make models train mode during train time"""
        for name in self.net_names:
            if isinstance(name, str):
                net = getattr(self, "net_" + name)
                net.train()

    def train(self, task_index):

        self.fit()  # set train mode
        # Clear the network's existing gradients; zero_grad(set_to_none=True) requires torch>=1.7
        self.optimizer.zero_grad(set_to_none=True)
        if self.opt.amp:
            with autocast():
                self.forward(
                )  # first call forward to calculate intermediate results
                self.loss_total, self.losses_without_lambda = self.loss_criterion(
                    self.output, self.target, task_index)

        else:
            self.forward()  # note: repeated forward calls may yield a different self.output (?)
            self.loss_total, self.losses_without_lambda = self.loss_criterion(
                self.output, self.target, task_index)

        self.backward()  # calculate gradients for network main
        # self.scaler.unscale_(self.optimizer)

        self.step()

        if self.opt.amp:
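            # update() runs once per iteration, after scaler.step() has been
            # called on every optimizer inside self.step()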
            self.scaler.update()

    def step(self):
        if self.opt.amp:
            for optimizer in self.optimizers:
                self.scaler.step(optimizer)
        else:
            for optimizer in self.optimizers:
                optimizer.step()  # update gradients for network

    def test(self, visualizer=None):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        self.eval()  # set eval mode
        with torch.no_grad():
            self.forward()
            if visualizer is not None:
                self.compute_visuals(visualizer)

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        logging.info('learning rate = %.7f' % lr)

    def get_current_losses(self):
        """Return traning losses_without_lambda / errors. fit.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = getattr(
                    self, name)  # may be a scalar tensor or a python float
        return errors_ret

    def save_networks(self, task_index, step, epoch):
        """Save all the networks to the disk.

        Parameters:
            task_index (int) -- current trained task_index
            step (int) -- current trained step
            epoch (Union[int,str]) -- current epoch; used in the file name '%s_%s_%s_net_%s.pth' % (task_index, step, epoch, name)
        """

        task_index, step = int(task_index), int(step)

        for file in os.listdir(self.save_dir):
            if file.endswith('.pth'):
                __task_index, __step, __epoch, *others = file.split('_')
                __task_index, __step, __epoch = int(__task_index), int(
                    __step), int(__epoch)

                if __task_index < task_index \
                        or __task_index == task_index and __step < step \
                        or __task_index == task_index and __step == step and (
                        __epoch < int(epoch) if epoch != 'best' else True):
                    path = os.path.join(self.save_dir, file)
                    logging.info(f'removing outdated checkpoint {path}')
                    # rm_dirs.append(path)
                    os.remove(path)

        for name in self.net_names:
            if isinstance(name, str):
                save_filename = '%s_%s_%s_net_%s.pth' % (task_index, step,
                                                         epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, "net_" + name)
                if is_gpu_avaliable(self.opt):
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.to(device=self.opt.device)
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def load_networks(self, task_index, step, epoch):
        """Load all the networks from the disk.

        Parameters:
            task_index (int) -- current trained task_index
            step (int) -- current trained step
            epoch (int or str) -- current epoch or best epoch; used in the file name f'{task_index}_{step}_{epoch}_net_{name}.pth'
        """
        for name in self.net_names:
            if isinstance(name, str):
                load_filename = f'{task_index}_{step}_{epoch}_net_{name}.pth'
                load_path = os.path.join(self.save_dir, load_filename)
                if not os.path.exists(load_path):
                    logging.warning(
                        f'checkpoint path {load_path} does not exist, please check it!'
                    )
                else:
                    net = getattr(self, "net_" + name)
                    if isinstance(net, torch.nn.DataParallel):
                        net = net.module
                    logging.info('loading the model from %s' % load_path)
                    state_dict = torch.load(load_path,
                                            map_location=str(self.opt.device))
                    if hasattr(state_dict, '_metadata'):
                        del state_dict._metadata

                    # patch InstanceNorm checkpoints prior to 0.4
                    for key in list(state_dict.keys(
                    )):  # need to copy keys here because we mutate in loop
                        self.__patch_instance_norm_state_dict(
                            state_dict, net, key.split('.'))
                    net.load_state_dict(state_dict)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and\
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and\
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict,
                                                  getattr(module, key), keys,
                                                  i + 1)

    def print_networks(self):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        logging.info('---------- Networks initialized -------------')
        for name in self.net_names:
            if isinstance(name, str):
                net = getattr(self, "net_" + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                logging.info(net)
                logging.info(
                    '[Network %s] Total number of parameters : %.3f M' %
                    (name, num_params / 1e6))
        logging.info('-----------------------------------------------')

    def get_matrix_item(self, task_index) -> MatrixItem:
        """For the current task_index, call set_data() and forward() first; get_matrix_item() then returns a MatrixItem."""
        if self.output is None or self.target is None:
            raise ValueError(
                f"Expected model.set_data(data) and model.forward() to be called before model.get_matrix_item(), but missing: {'forward()' if self.output is None else ''} {'set_data()' if self.target is None else ''}"
            )
        return MatrixItem(self.output[task_index], self.target,
                          self.loss_criterion)

    def init_net_with_dataparaller(self, opt, net):
        """Wrap a network for (multi-)GPU execution and register its device.

        Parameters:
            opt (options) -- experiment options; opt.gpu_ids lists the GPUs the network runs on, e.g. 0,1,2
            net (network) -- the network to be wrapped

        Return the wrapped network.
        """

        if len(opt.gpu_ids) > 0:
            assert (torch.cuda.is_available())
            # for none distributed data parallel, only DataParallel
            # net = DataParallel(net, device_ids=opt.gpu_ids)
            # net = DataParallel(net, device_ids=[0])
            net = DataParallel(net)

            # self.cuda()
            # net.to(device=self.opt.device)
            logging.info('Data parallel wrapper: DataParallel')
            return net
        else:
            raise ValueError('Expected GPU training, but got no gpu_ids')

    def __getattr__(self, item):
        if "loss" in item and hasattr(self.loss_criterion, item):
            return getattr(self.loss_criterion, item)
        for net in self._get_all_nets():
            if "net" in item and hasattr(net, item):
                return getattr(net, item)
        raise AttributeError(
            f"Model {self.opt.model_name} object has no attribute '{item}'")

    def set_data(self, data: PseudoData):
        """Unpack input _data from the dataloader and perform necessary pre-processing steps.
        """
        self.image = data.image
        self.target: MultiOutput = data.target

    def forward(self):
        """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
        if self.image is None:
            raise ValueError(
                "Expected model.set_data(data) to be called before forward()")

        task_outputs = self.net_main(self.image)
        self.output: MultiOutput = MultiOutput(task_outputs)

    # @torchsnooper.snoop()
    def backward(self):
        """Calculate losses_without_lambda, gradients, and update network weights; called in every training iteration"""
        if self.opt.amp:
            for loss in self.losses_without_lambda[:-1]:
                self.scaler.scale(loss).backward(retain_graph=True)
            self.scaler.scale(self.losses_without_lambda[-1]).backward()

        else:
            for loss in self.losses_without_lambda[:-1]:
                loss.backward(retain_graph=True)
            self.losses_without_lambda[-1].backward()

    def _get_all_nets(self):
        nets = self.__dict__
        return [nets["net_" + net_name] for net_name in self.net_names]

    def compute_visuals(self, visualizer):
        """Calculate additional visualization"""

        pass

    def cuda(self, device=None):
        """cuda: net"""
        if device is None:
            device = self.opt.device
        for name in self.net_names:
            if isinstance(name, str):
                net = getattr(self, "net_" + name)
                # net=net.cuda(device)
                net.to(device)
                logging.info(f'[Network {name}] move to cuda({device})')

    def init_optimizers(self, opt):

        self.cuda()

        if self.isTrain:
            if self.opt.optimizer_type == 'adam':
                self.optimizer = torch.optim.Adam(self.net_main.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            elif opt.optimizer_type == 'sgd':
                self.optimizer = torch.optim.SGD(self.net_main.parameters(),
                                                 lr=opt.lr)
            elif opt.optimizer_type == 'adamw':
                self.optimizer = torch.optim.AdamW(self.net_main.parameters(),
                                                   lr=opt.lr,
                                                   betas=(opt.beta1, 0.999))
            else:
                raise ValueError(
                    f"Expected opt.optimizer_type in ['adam', 'sgd', 'adamw'], but got {opt.optimizer_type}"
                )
            self.net_main = self.init_net_with_dataparaller(opt, self.net_main)
            self.optimizers = [self.optimizer]
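
# --- Reference sketch (not part of the class above) ---------------------------
# BaseModel shares a single GradScaler across several losses and optimizers:
# backward() scales and backpropagates every loss, step() calls scaler.step()
# once per optimizer, and train() calls scaler.update() once per iteration.
# A minimal standalone sketch of that pattern; `model`, `optimizers`,
# `losses_fn` and `batch` are placeholder names.
import torch
from torch.cuda.amp import GradScaler, autocast


def amp_multi_optimizer_step(model, optimizers, losses_fn, batch, scaler):
    for opt in optimizers:
        opt.zero_grad(set_to_none=True)
    with autocast():
        losses = losses_fn(model(batch))  # e.g. a list of per-task losses
    for loss in losses[:-1]:
        scaler.scale(loss).backward(retain_graph=True)
    scaler.scale(losses[-1]).backward()
    for opt in optimizers:
        scaler.step(opt)  # each step is skipped if its gradients overflowed
    scaler.update()  # exactly one update per iteration, after all steps
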
Exemplo n.º 14
0
def main_worker(gpu, ngpus_per_node, args):
    global global_step
    global start_time

    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        # Compute the global rank from the node rank and the local GPU index
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    hps = Hyperparameters(args)
    sample_path, save_path, load_path, log_path = mkdir(args)
    if not args.distributed or (args.rank % ngpus_per_node == 0):
        log, log_train, log_eval = get_logger(log_path, args.model_name)
    else:
        log, log_train, log_eval = None, None, None
    model = build_model(hps, log)
    if args.distributed:  # Multiple processes, single GPU per process
        if args.gpu is not None:
            def _transform_(m):
                return nn.parallel.DistributedDataParallel(
                    m, device_ids=[args.gpu], output_device=args.gpu, check_reduction=True)

            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            model.multi_gpu_wrapper(_transform_)
            args.bsz = int(args.bsz / ngpus_per_node)
            args.workers = 0
        else:
            assert 0, "DistributedDataParallel constructor should always set the single device scope"
    elif args.gpu is not None:  # Single process, single GPU per process
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:  # Single process, multiple GPUs per process
        def _transform_(m):
            return nn.DataParallel(m)
        model = model.cuda()
        model.multi_gpu_wrapper(_transform_)


    train_loader, test_loader, synth_loader = load_dataset(args)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scaler = GradScaler()
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    state = {k: v for k, v in args._get_kwargs()}

    if args.load_step == 0:
        # new model
        global_epoch = 0
        global_step = 0
        actnorm_init(train_loader, model, args.gpu)
    else:
        # saved model
        model, optimizer, scheduler, global_epoch, global_step = load_checkpoint(args.load_step, load_path, model, optimizer, scheduler)
        if log is not None:
            log.write('\n ! --- load the model and continue training --- ! \n')
            log_train.write('\n ! --- load the model and continue training --- ! \n')
            log_eval.write('\n ! --- load the model and continue training --- ! \n')
            log.flush()
            log_train.flush()
            log_eval.flush()

    start_time = time.time()
    dateTime = datetime.datetime.fromtimestamp(start_time).strftime('%Y-%m-%d %H:%M:%S')
    print('training starts at ', dateTime)

    for epoch in range(global_epoch + 1, args.epochs + 1):
        training_epoch_loss = train(args.gpu, epoch, train_loader, synth_loader, sample_path, model, optimizer, scaler, scheduler, log_train, args)

        with torch.no_grad():
            eval_epoch_loss = evaluate(args.gpu, epoch, test_loader, model, log_eval)

        if log is not None:
            state['training_loss'] = training_epoch_loss
            state['eval_loss'] = eval_epoch_loss
            state['epoch'] = epoch
            log.write('%s\n' % json.dumps(state))
            log.flush()
            
        if not args.distributed or (args.rank % ngpus_per_node == 0):
            save_checkpoint(save_path, model, optimizer, scaler, scheduler, global_step, epoch)
            print('Epoch {} Model Saved! Loss : {:.4f}'.format(epoch, eval_epoch_loss))
            with torch.no_grad():
                synthesize(args.gpu, sample_path, synth_loader, model, args.num_sample, args.sr)
        gc.collect()

    if log is not None:
        log_train.close()
        log_eval.close()
        log.close()
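The worker above is usually driven by a small launcher. The sketch below shows one common way `main_worker` would be spawned; it is an assumption, not part of the original example, and it reuses the `args` fields (`distributed`, `world_size`, `gpu`) referenced in the code above.

import torch
import torch.multiprocessing as mp

def main(args):
    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        # One process per GPU; the global world size is nodes * GPUs-per-node.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)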
Exemplo n.º 15
class ModelTrainer(object):
    def __init__(self, speaker_model, optimizer, scheduler, gpu, mixedprec,
                 **kwargs):

        self.__model__ = speaker_model

        Optimizer = importlib.import_module(
            'optimizer.' + optimizer).__getattribute__('Optimizer')
        self.__optimizer__ = Optimizer(self.__model__.parameters(), **kwargs)

        Scheduler = importlib.import_module(
            'scheduler.' + scheduler).__getattribute__('Scheduler')
        self.__scheduler__, self.lr_step = Scheduler(self.__optimizer__,
                                                     **kwargs)

        self.scaler = GradScaler()

        self.gpu = gpu

        self.mixedprec = mixedprec

        assert self.lr_step in ['epoch', 'iteration']

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Train network
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def train_network(self, loader, verbose):

        self.__model__.train()

        stepsize = loader.batch_size

        counter = 0
        index = 0
        loss = 0
        top1 = 0  # EER or accuracy

        tstart = time.time()

        for data, data_label in loader:

            data = data.transpose(1, 0)

            self.__model__.zero_grad()

            label = torch.LongTensor(data_label).cuda()

            if self.mixedprec:
                with autocast():
                    nloss, prec1 = self.__model__(data, label)
                self.scaler.scale(nloss).backward()
                self.scaler.step(self.__optimizer__)
                self.scaler.update()
            else:
                nloss, prec1 = self.__model__(data, label)
                nloss.backward()
                self.__optimizer__.step()

            loss += nloss.detach().cpu()
            top1 += prec1
            counter += 1
            index += stepsize

            telapsed = time.time() - tstart
            tstart = time.time()

            if verbose:
                sys.stdout.write("\rProcessing (%d) " % (index))
                sys.stdout.write(
                    "Loss %f TEER/TAcc %2.3f%% - %.2f Hz " %
                    (loss / counter, top1 / counter, stepsize / telapsed))
                sys.stdout.flush()

            if self.lr_step == 'iteration': self.__scheduler__.step()

        if self.lr_step == 'epoch': self.__scheduler__.step()

        sys.stdout.write("\n")

        return (loss / counter, top1 / counter)

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Evaluate from list
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def evaluateFromList(self,
                         test_list,
                         test_path,
                         nDataLoaderThread,
                         print_interval=100,
                         num_eval=10,
                         **kwargs):

        self.__model__.eval()

        lines = []
        files = []
        feats = {}
        tstart = time.time()

        ## Read all lines
        with open(test_list) as f:
            lines = f.readlines()

        ## Get a list of unique file names
        files = sum([x.strip().split()[-2:] for x in lines], [])
        setfiles = list(set(files))
        setfiles.sort()

        ## Define test data loader
        test_dataset = test_dataset_loader(setfiles,
                                           test_path,
                                           num_eval=num_eval,
                                           **kwargs)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=nDataLoaderThread,
            drop_last=False,
        )

        ## Extract features for every utterance
        for idx, data in enumerate(test_loader):
            inp1 = data[0][0].cuda()
            ref_feat = self.__model__(inp1).detach().cpu()
            feats[data[1][0]] = ref_feat
            telapsed = time.time() - tstart

            if idx % print_interval == 0:
                sys.stdout.write(
                    "\rReading %d of %d: %.2f Hz, embedding size %d" %
                    (idx, len(setfiles), idx / telapsed, ref_feat.size()[1]))

        print('')
        all_scores = []
        all_labels = []
        all_trials = []
        tstart = time.time()

        ## Read files and compute all scores
        for idx, line in enumerate(lines):

            data = line.split()

            ## Append random label if missing
            if len(data) == 2: data = [random.randint(0, 1)] + data

            ref_feat = feats[data[1]].cuda()
            com_feat = feats[data[2]].cuda()

            if self.__model__.module.__L__.test_normalize:
                ref_feat = F.normalize(ref_feat, p=2, dim=1)
                com_feat = F.normalize(com_feat, p=2, dim=1)

            dist = F.pairwise_distance(ref_feat.unsqueeze(-1),
                                       com_feat.unsqueeze(-1).transpose(
                                           0, 2)).detach().cpu().numpy()

            score = -1 * numpy.mean(dist)

            all_scores.append(score)
            all_labels.append(int(data[0]))
            all_trials.append(data[1] + " " + data[2])

            if idx % print_interval == 0:
                telapsed = time.time() - tstart
                sys.stdout.write("\rComputing %d of %d: %.2f Hz" %
                                 (idx, len(lines), idx / telapsed))
                sys.stdout.flush()

        print('')

        return (all_scores, all_labels, all_trials)

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Save parameters
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def saveParameters(self, path):

        torch.save(self.__model__.module.state_dict(), path)

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Load parameters
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def loadParameters(self, path):

        self_state = self.__model__.module.state_dict()
        loaded_state = torch.load(path, map_location="cuda:%d" % self.gpu)
        for name, param in loaded_state.items():
            origname = name
            if name not in self_state:
                name = name.replace("module.", "")

                if name not in self_state:
                    print("%s is not in the model." % origname)
                    continue

            if self_state[name].size() != loaded_state[origname].size():
                print("Wrong parameter length: %s, model: %s, loaded: %s" %
                      (origname, self_state[name].size(),
                       loaded_state[origname].size()))
                continue

            self_state[name].copy_(param)
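A hypothetical usage sketch for this trainer. The network class, the data loader, and the 'adam'/'steplr' module names are placeholders; the class above only requires that `optimizer.<name>.Optimizer` and `scheduler.<name>.Scheduler` be importable and that the model is wrapped so that `.module` exists.

speaker_net = torch.nn.DataParallel(MySpeakerNet()).cuda()  # placeholder network
trainer = ModelTrainer(speaker_net,
                       optimizer='adam',     # resolved as optimizer.adam.Optimizer
                       scheduler='steplr',   # resolved as scheduler.steplr.Scheduler
                       gpu=0,
                       mixedprec=True,
                       lr=1e-3)
train_loss, train_top1 = trainer.train_network(train_loader, verbose=True)
trainer.saveParameters("checkpoint_epoch1.pt")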
Exemplo n.º 16
def train(fold):
    args = get_args()
    with open(args.config) as file:
        config_file = yaml.load(file, Loader=yaml.FullLoader)

    wandb.init(
        project="siim2020",
        entity="siim_melanoma",
        # name=f"20200718-effb0-adamw-consineaneal-{fold}",
        name=f"2017-2018-rexnet-test-{fold}",
        #name=f"swav-test-{fold}",
        #name=f"RAdam-b6-384x384-{fold}"
    )
    config = wandb.config  # Initialize config
    config.update(config_file)
    device = config.device

    model_path = config.model_path.format(fold)

    seed_everything(config.seed)
    df = pd.read_csv(config.train_csv_fold)
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_train["image_name"] = config.training_data_path + df_train[
        "image_name"] + ".jpg"

    if config.supplement_data["use_supplement"]:
        print(f"training shape before merge {df_train.shape}")
        df_supplement = pd.read_csv(config.supplement_data["csv_file"])
        df_supplement = df_supplement[df_supplement["tfrecord"] % 2 == 0]
        df_supplement = df_supplement[df_supplement["target"] == 1]
        df_supplement["image_name"] = (config.supplement_data["file_path"] +
                                       df_supplement["image_name"] + ".jpg")
        df_train = pd.concat([df_train, df_supplement]).reset_index(drop=True)
        df_train = df_train.sample(
            frac=1, random_state=config.seed).reset_index(drop=True)
        del df_supplement
        print(f"training shape after merge {df_train.shape}")

    df_valid = df[df.kfold == fold].reset_index(drop=True)
    df_valid["image_name"] = config.training_data_path + df_valid[
        "image_name"] + ".jpg"

    if config.use_metadata:
        df_train, meta_features = get_meta_feature(df_train)
        df_valid, _ = get_meta_feature(df_valid)
    else:
        meta_features = None

    model = get_model(
        config.model_backbone,
        config.model_name,
        config.num_classes,
        config.input_size,
        config.use_metadata,
        meta_features,
    )

    model = model.to(config.device)
    print("watching model")
    wandb.watch(model, log="all")

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    train_aug = albumentations.Compose([
        AdvancedHairAugmentation(hairs_folder="../input/melanoma-hairs/"),
        # albumentations.augmentations.transforms.CenterCrop(64, 64, p=0.8),
        albumentations.augmentations.transforms.RandomBrightnessContrast(),
        albumentations.augmentations.transforms.HueSaturationValue(),
        # Microscope(p=0.4),
        albumentations.augmentations.transforms.RandomResizedCrop(
            config.input_size, config.input_size, scale=(0.7, 1.0), p=0.4),
        albumentations.augmentations.transforms.VerticalFlip(p=0.4),
        albumentations.augmentations.transforms.Cutout(p=0.3),  # doesn't work
        albumentations.ShiftScaleRotate(shift_limit=0.0625,
                                        scale_limit=0.1,
                                        rotate_limit=15),
        albumentations.Flip(p=0.5),
        RandomAugMix(severity=7, width=7, alpha=5, p=0.3),
        # albumentations.augmentations.transforms.Resize(
        #    config.input_size, config.input_size, p=1
        # ),
        albumentations.Normalize(mean,
                                 std,
                                 max_pixel_value=255.0,
                                 always_apply=True),
    ])

    valid_aug = albumentations.Compose([
        albumentations.Normalize(mean,
                                 std,
                                 max_pixel_value=255.0,
                                 always_apply=True),
    ])

    train_images = df_train.image_name.values.tolist()
    # train_images = [
    #    os.path.join(config.training_data_path, i + ".jpg") for i in train_images
    # ]
    train_targets = df_train.target.values

    valid_images = df_valid.image_name.values.tolist()
    # valid_images = [
    #    os.path.join(config.training_data_path, i + ".jpg") for i in valid_images
    # ]
    valid_targets = df_valid.target.values

    train_dataset = ClassificationDataset(
        image_paths=train_images,
        targets=train_targets,
        resize=None,
        augmentations=train_aug,
        meta_features=meta_features,
        df_meta_features=df_train,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        # num_workers=4,
        num_workers=1,
        pin_memory=True,
        shuffle=True,
        #sampler=BalanceClassSampler(labels=train_targets, mode="upsampling"),
        drop_last=True,
    )

    valid_dataset = ClassificationDataset(
        image_paths=valid_images,
        targets=valid_targets,
        resize=None,
        augmentations=valid_aug,
        meta_features=meta_features,
        df_meta_features=df_valid,
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.test_batch_size,
        shuffle=False,
        # num_workers=4,
        num_workers=1,
        pin_memory=True,
        # drop_last=True
    )

    #optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    optimizer = RAdam(model.parameters(), lr=config.lr)
    if config.swa["use_swa"]:
        optimizer = SWA(optimizer, swa_start=12, swa_freq=1)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=2,
                                                           threshold=0.0001,
                                                           mode="max")
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #    optimizer, len(train_loader) * config.epochs
    # )

    #scheduler = torch.optim.lr_scheduler.CyclicLR(
    #   optimizer,
    #   base_lr=config.lr / 10,
    #   max_lr=config.lr * 100,
    #   mode="triangular2",
    #   cycle_momentum=False,
    #)

    #scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #    optimizer, max_lr=3e-3, steps_per_epoch=len(train_loader), epochs=config.epochs
    #)

    es = EarlyStopping(patience=6, mode="max")
    if config.fp16:
        print("************* using fp16 *************")
        scaler = GradScaler()
    else:
        scaler = False

    for epoch in range(config.epochs):
        train_loss = Engine.train(
            train_loader,
            model,
            optimizer,
            device=config.device,
            wandb=wandb,
            accumulation_steps=config.accumulation_steps,
            fp16=config.fp16,
            scaler=scaler,
        )
        predictions, valid_loss = Engine.evaluate(
            valid_loader,
            model,
            device=config.device,
            wandb=wandb,
            epoch=epoch,
            upload_image=False,
            use_sigmoid=True,
        )
        predictions = np.vstack(predictions).ravel()

        auc = metrics.roc_auc_score(valid_targets, predictions)
        print(f"Epoch = {epoch}, AUC = {auc}")
        wandb.log({
            "valid_auc": auc,
        })

        scheduler.step(auc)

        es(auc, model, model_path=model_path)
        if es.early_stop:
            print("Early stopping")
            break
    if config.swa["use_swa"]:
        print("saving the model using SWA")
        optimizer.swap_swa_sgd()
        torch.save(model.state_dict(), config.swa["model_path"].format(fold))

    evaluate_for_best_epoch(
        fold,
        model_path,
        config.device,
        valid_loader,
        config.model_name,
        valid_targets,
        "final",
        meta_features=meta_features,
    )
    if config.swa["use_swa"]:
        model_path = config.swa["model_path"].format(fold)
        evaluate_for_best_epoch(
            fold,
            model_path,
            config.device,
            valid_loader,
            config.model_name,
            valid_targets,
            "swa",
            meta_features=meta_features,
        )
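Engine.train itself is not shown in this example. The sketch below is a minimal, hedged reconstruction of a mixed-precision training step with gradient accumulation in the spirit of the `scaler`, `fp16`, and `accumulation_steps` arguments passed above; the batch layout and loss function are assumptions.

import torch
from torch.cuda.amp import autocast

def train_step_sketch(loader, model, optimizer, device, scaler=None, accumulation_steps=1):
    model.train()
    optimizer.zero_grad()
    loss_fn = torch.nn.BCEWithLogitsLoss()  # assumed loss for a single-logit melanoma target
    for step, (images, targets) in enumerate(loader):
        images, targets = images.to(device), targets.to(device).float()
        with autocast(enabled=bool(scaler)):
            loss = loss_fn(model(images).squeeze(1), targets) / accumulation_steps
        if scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        if (step + 1) % accumulation_steps == 0:
            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()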
Exemplo n.º 17
def train_job(model_name, train_df, valid_df, model_ckpt=None, log=True):

    if log:
        neptune.set_project("utsav/wheat-det")
        neptune.init("utsav/wheat-det", api_token=NEPTUNE_API_TOKEN)
        neptune.create_experiment(
            FLAGS["exp_name"],
            exp_description,
            params=FLAGS,
            upload_source_files="*.txt",
        )
    best_score = 0.0
    start_epoch = 0

    datasets = get_training_datasets(train_df, valid_df)
    train_loader = DataLoader(
        datasets["train"],
        batch_size=FLAGS["batch_size"],
        num_workers=FLAGS["num_workers"],
        shuffle=True,  # sampler=sampler, #
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        datasets["valid"],
        batch_size=FLAGS["batch_size"] * 2,
        shuffle=False,
        num_workers=FLAGS["num_workers"],
        collate_fn=collate_fn,
    )

    if model_ckpt is not None:
        model = get_train_model(model_name, model_ckpt)
    else:
        model = get_train_model(model_name)
    model.to(device)

    optimizer = Ranger(
        model.parameters(),
        lr=FLAGS["learning_rate"],
        alpha=0.5,
        k=6,
        N_sma_threshhold=5,
        weight_decay=FLAGS["weight_decay"],
    )

    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer,
        "min",
        factor=0.5,
        verbose=True,
        patience=FLAGS["scheduler_pat"],
    )

    scaler = GradScaler()
    es = 0
    for epoch in range(start_epoch, FLAGS["num_epochs"]):

        print("-" * 27 + f"Epoch #{epoch+1} started" + "-" * 27)

        train_loss = train_one_epoch(
            train_loader,
            model,
            optimizer,
            device,
            scaler,
            scheduler=None,
            log=log,
        )
        print(f"Average loss for epoch #{epoch+1} : {train_loss}")

        val_metric, val_loss = val_one_epoch(model, val_loader)
        scheduler.step(val_loss)

        print(f"metric/val : {val_metric}")
        print(f"loss/val : {val_loss}")

        if log:
            neptune.log_metric("metric/val", val_metric)
            neptune.log_metric("loss/val", val_loss)

        # if epoch==FLAGS['unfreeze_epoch']:
        #    model = unfreeze_all_layers(model)

        if (val_metric > best_score) or (best_score - val_metric < 0.01):
            es = 0
            if val_metric > best_score:
                best_score = val_metric
            if epoch > 9:
                save_upload(
                    model,
                    optimizer,
                    epoch,
                    scheduler,
                    val_metric,
                    exp_name=FLAGS["exp_name"],
                )
        # else:
        #    if epoch>24:
        #        es+=1
        # if es > FLAGS['early_stop_count']:
        #    print('Early stopping...')
        #    break

        print("-" * 28 + f"Epoch #{epoch+1} ended" + "-" * 28)

    neptune.stop()
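train_one_epoch and val_one_epoch are defined elsewhere. The sketch below is a hedged approximation of a detection training epoch that matches the calling convention used above (loader, model, optimizer, device, scaler, optional scheduler); it assumes a torchvision-style detection model that returns a loss dict in train mode.

from torch.cuda.amp import autocast

def train_one_epoch_sketch(loader, model, optimizer, device, scaler, scheduler=None, log=False):
    model.train()
    running_loss = 0.0
    for images, targets in loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        optimizer.zero_grad()
        with autocast():
            loss_dict = model(images, targets)  # assumed torchvision-style API
            loss = sum(loss_dict.values())
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if scheduler is not None:
            scheduler.step()
        running_loss += loss.item()
    return running_loss / max(len(loader), 1)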
Exemplo n.º 18
    def train(self, args, logger=None):
        """
        Train function of FixMatch.
        From data_loader, it inference training data, computes losses, and update the networks.
        """
        ngpus_per_node = torch.cuda.device_count()

        # lb: labeled, ulb: unlabeled
        self.train_model.train()

        # for gpu profiling
        start_batch = torch.cuda.Event(enable_timing=True)
        end_batch = torch.cuda.Event(enable_timing=True)
        start_run = torch.cuda.Event(enable_timing=True)
        end_run = torch.cuda.Event(enable_timing=True)

        start_batch.record()
        best_eval_acc, best_it = 0.0, 0

        scaler = GradScaler()
        amp_cm = autocast if args.amp else contextlib.nullcontext

        for (x_lb, y_lb, idx), (x_ulb_w, x_ulb_s,
                                _) in zip(self.loader_dict['train_lb'],
                                          self.loader_dict['train_ulb']):

            # prevent the training iterations from exceeding args.num_train_iter
            if self.it > args.num_train_iter:
                break

            end_batch.record()
            torch.cuda.synchronize()
            start_run.record()

            num_lb = x_lb.shape[0]
            num_ulb = x_ulb_w.shape[0]
            assert num_ulb == x_ulb_s.shape[0]

            x_lb, x_ulb_w, x_ulb_s = x_lb.cuda(args.gpu), x_ulb_w.cuda(
                args.gpu), x_ulb_s.cuda(args.gpu)
            y_lb = y_lb.cuda(args.gpu)
            idx = idx.cuda(args.gpu)

            inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))

            # inference and calculate sup/unsup losses
            with amp_cm():
                # logits = self.train_model(inputs)
                # logits_x_lb = logits[:num_lb]
                # logits_x_ulb_w, logits_x_ulb_s = logits[num_lb:].chunk(2)
                # del logits
                logits_x_lb, logits_x_ulb_w, logits_x_ulb_s, sup_soft_labels = self.train_model(
                    inputs, num_lb=num_lb, it=self.it, sup_index=idx)

                # hyper-params for update
                T = self.t_fn(self.it)
                p_cutoff = self.p_fn(self.it)

                if sup_soft_labels is not None:
                    sup_loss = sat_loss(logits_x_lb, sup_soft_labels)
                else:
                    sup_loss = ce_loss(logits_x_lb, y_lb, reduction='mean')

                unsup_loss, mask = consistency_loss(
                    logits_x_ulb_w,
                    logits_x_ulb_s,
                    'ce',
                    T,
                    p_cutoff,
                    use_hard_labels=args.hard_label)

                total_loss = sup_loss + self.lambda_u * unsup_loss

            # parameter updates
            if args.amp:
                scaler.scale(total_loss).backward()
                scaler.step(self.optimizer)
                scaler.update()
            else:
                total_loss.backward()
                self.optimizer.step()

            self.scheduler.step()
            self.train_model.zero_grad()

            with torch.no_grad():
                self._eval_model_update()

            end_run.record()
            torch.cuda.synchronize()

            # tensorboard_dict update
            tb_dict = {}
            tb_dict['train/sup_loss'] = sup_loss.detach()
            tb_dict['train/unsup_loss'] = unsup_loss.detach()
            tb_dict['train/total_loss'] = total_loss.detach()
            tb_dict['train/mask_ratio'] = 1.0 - mask.detach()
            tb_dict['lr'] = self.optimizer.param_groups[0]['lr']
            tb_dict['train/prefetch_time'] = start_batch.elapsed_time(
                end_batch) / 1000.
            tb_dict['train/run_time'] = start_run.elapsed_time(end_run) / 1000.

            if self.it % self.num_eval_iter == 0:
                eval_dict = self.evaluate(args=args)
                tb_dict.update(eval_dict)

                save_path = os.path.join(args.save_dir, args.save_name)

                if tb_dict['eval/top-1-acc'] > best_eval_acc:
                    best_eval_acc = tb_dict['eval/top-1-acc']
                    best_it = self.it

                self.print_fn(
                    f"{self.it} iteration, USE_EMA: {hasattr(self, 'eval_model')}, {tb_dict}, BEST_EVAL_ACC: {best_eval_acc}, at {best_it} iters"
                )

            if not args.multiprocessing_distributed or \
                    (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):

                if self.it % self.num_eval_iter == 0:
                    self.save_model('model_last.pth', save_path)
                    self.save_model('model_{}.pth'.format(self.it), save_path)

                if self.it == best_it:
                    self.save_model('model_best.pth', save_path)

                if self.tb_log is not None:
                    self.tb_log.update(tb_dict, self.it)

            self.it += 1
            del tb_dict
            start_batch.record()
            if self.it > 2**19:
                self.num_eval_iter = 1000

        eval_dict = self.evaluate(args=args)
        eval_dict.update({
            'eval/best_acc': best_eval_acc,
            'eval/best_it': best_it
        })
        return eval_dict
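The call to self._eval_model_update() above maintains a separate evaluation model, typically via an exponential moving average of the training weights. The sketch below shows one common form of such an update; the decay value and buffer handling are assumptions, not taken from this example.

import torch

@torch.no_grad()
def ema_update_sketch(train_model, eval_model, decay=0.999):
    # Blend each eval parameter towards the corresponding training parameter.
    for p_train, p_eval in zip(train_model.parameters(), eval_model.parameters()):
        p_eval.mul_(decay).add_(p_train.detach(), alpha=1.0 - decay)
    # Buffers (e.g. BatchNorm running stats) are usually copied directly.
    for b_train, b_eval in zip(train_model.buffers(), eval_model.buffers()):
        b_eval.copy_(b_train)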
Exemplo n.º 19
    def _train(self, rank):
        _seed_everything(self._data_shuffle_seed)
        self._setup_ddp(rank)
        self._rank = rank
        self._scaler = GradScaler()
        self._model = self._get_model(self._rank)
        self._optimizer = AdamW(params=self._model.parameters(),
                                lr=self._learning_rate)

        self._global_step = 0
        self._samples_seen = 0

        self._train_dl = self._get_dataloader(is_train=True, samples_offset=0)
        self._valid_dl = self._get_dataloader(is_train=False, samples_offset=0)

        steps_per_epoch = len(self._train_dl)
        num_training_steps = steps_per_epoch * self._n_epochs
        num_warmup_steps = self._warmup_ratio * num_training_steps
        self._scheduler = get_linear_schedule_with_warmup(
            optimizer=self._optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps)

        if self._checkpoint_file_path.is_file():
            self._load_checkpoint()
        elif self._init_weights_from_checkpoint:
            self._load_only_weights_from_checkpoint()

        while True:
            if self._rank == 0:
                self._writer = self._writer or SummaryWriter(
                    self._experiment_dir / 'tb_logs')
                self._train_dl = tqdm.tqdm(self._train_dl,
                                           desc='Train step',
                                           total=num_training_steps,
                                           position=1,
                                           initial=self._global_step)

            for i_step, model_input in enumerate(self._train_dl):
                train_losses_dict = self._train_step(model_input)

                if rank == 0:
                    self._train_dl.set_postfix({
                        'samples_seen':
                        self._samples_seen,
                        'epoch':
                        self._global_step / steps_per_epoch
                    })
                    self._write_tb_logs(train_losses_dict)
                    self._write_tb_logs({
                        'learning-rate':
                        self._optimizer.param_groups[0]['lr']
                    })
                    self._write_tb_logs(
                        {'max_seq_len': model_input.input_ids.size()[1]})

                if self._rank == 0 and self._global_step % self._validate_each_n_steps == 0:
                    valid_loss = self._validate()
                    self._save_checkpoint()
                    self._write_tb_logs({'loss/valid': valid_loss})

                if self._global_step >= num_training_steps:
                    break

            # The inner break only exits the epoch loop; also stop the outer loop
            # once the target number of training steps has been reached.
            if self._global_step >= num_training_steps:
                break

        dist.destroy_process_group()
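self._setup_ddp is not shown in this example. The sketch below is one common shape for such a helper on a single node; the master address/port values, the NCCL backend, and the world-size source are assumptions.

import os
import torch
import torch.distributed as dist

def setup_ddp_sketch(rank, world_size, port="29500"):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", port)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)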
Exemplo n.º 20
class Trainer:
    """
    Implements the training logic. Some common configuration (checkpointing frequency, path, validation frequency)
    is read from util.common_opts, which is set via the command line.
    """

    def __init__(
        self,
        game: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        train_data: DataLoader,
        optimizer_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        validation_data: Optional[DataLoader] = None,
        device: torch.device = None,
        callbacks: Optional[List[Callback]] = None,
        grad_norm: float = None,
        aggregate_interaction_logs: bool = True,
    ):
        """
        :param game: A nn.Module that implements forward(); it is expected that forward returns a tuple of (loss, d),
            where loss is differentiable loss to be minimized and d is a dictionary (potentially empty) with auxiliary
            metrics that would be aggregated and reported
        :param optimizer: An instance of torch.optim.Optimizer
        :param optimizer_scheduler: An optimizer scheduler to adjust lr throughout training
        :param train_data: A DataLoader for the training set
        :param validation_data: A DataLoader for the validation set (can be None)
        :param device: A torch.device on which tensors should be stored
        :param callbacks: A list of egg.core.Callback objects that can encapsulate monitoring or checkpointing
        """
        self.game = game
        self.optimizer = optimizer
        self.optimizer_scheduler = optimizer_scheduler
        self.train_data = train_data
        self.validation_data = validation_data
        common_opts = get_opts()
        self.validation_freq = common_opts.validation_freq
        self.device = common_opts.device if device is None else device

        self.should_stop = False
        self.start_epoch = 0  # Can be overwritten by checkpoint loader
        self.callbacks = callbacks if callbacks else []
        self.grad_norm = grad_norm
        self.aggregate_interaction_logs = aggregate_interaction_logs

        self.update_freq = common_opts.update_freq

        if common_opts.load_from_checkpoint is not None:
            print(
                f"# Initializing model, trainer, and optimizer from {common_opts.load_from_checkpoint}"
            )
            self.load_from_checkpoint(common_opts.load_from_checkpoint)

        self.distributed_context = common_opts.distributed_context
        if self.distributed_context.is_distributed:
            print("# Distributed context: ", self.distributed_context)

        if self.distributed_context.is_leader and not any(
            isinstance(x, CheckpointSaver) for x in self.callbacks
        ):
            if common_opts.preemptable:
                assert (
                    common_opts.checkpoint_dir
                ), "checkpointing directory has to be specified"
                d = get_preemptive_checkpoint_dir(common_opts.checkpoint_dir)
                self.checkpoint_path = d
                self.load_from_latest(d)
            else:
                self.checkpoint_path = (
                    None
                    if common_opts.checkpoint_dir is None
                    else pathlib.Path(common_opts.checkpoint_dir)
                )

            if self.checkpoint_path:
                checkpointer = CheckpointSaver(
                    checkpoint_path=self.checkpoint_path,
                    checkpoint_freq=common_opts.checkpoint_freq,
                )
                self.callbacks.append(checkpointer)

        if self.distributed_context.is_leader and common_opts.tensorboard:
            assert (
                common_opts.tensorboard_dir
            ), "tensorboard directory has to be specified"
            tensorboard_logger = TensorboardLogger()
            self.callbacks.append(tensorboard_logger)

        if not self.callbacks:
            self.callbacks = [
                ConsoleLogger(print_train_loss=False, as_json=False),
            ]

        if self.distributed_context.is_distributed:
            device_id = self.distributed_context.local_rank
            torch.cuda.set_device(device_id)
            self.game.to(device_id)

            # NB: here we are doing something that is a bit shady:
            # 1/ the optimizer was created outside of the Trainer instance, so we don't really know
            #    which parameters it optimizes. If it holds something that is not within the Game instance,
            #    then that part will not participate in distributed training.
            # 2/ if the optimizer only holds a subset of the Game parameters, this works, but in a somewhat
            #    undocumented way. The optimizer holds parameters of the non-DistributedDataParallel version
            #    of the Game, while the forward/backward calls happen on the DistributedDataParallel wrapper.
            #    The wrapper syncs gradients of the underlying tensors - which are the very tensors the
            #    optimizer holds - so this works, but only because DDP doesn't take any tensor ownership.

            self.game = torch.nn.parallel.DistributedDataParallel(
                self.game,
                device_ids=[device_id],
                output_device=device_id,
                find_unused_parameters=True,
            )
            self.optimizer.state = move_to(self.optimizer.state, device_id)

        else:
            self.game.to(self.device)
            # NB: some optimizers pre-allocate buffers before actually doing any steps
            # since model is placed on GPU within Trainer, this leads to having optimizer's state and model parameters
            # on different devices. Here, we protect from that by moving optimizer's internal state to the proper device
            self.optimizer.state = move_to(self.optimizer.state, self.device)

        if common_opts.fp16:
            self.scaler = GradScaler()
        else:
            self.scaler = None

    def eval(self, data=None):
        mean_loss = 0.0
        interactions = []
        n_batches = 0
        validation_data = self.validation_data if data is None else data
        self.game.eval()
        with torch.no_grad():
            for batch in validation_data:
                if not isinstance(batch, Batch):
                    batch = Batch(*batch)
                batch = batch.to(self.device)
                optimized_loss, interaction = self.game(*batch)
                if (
                    self.distributed_context.is_distributed
                    and self.aggregate_interaction_logs
                ):
                    interaction = Interaction.gather_distributed_interactions(
                        interaction
                    )
                interaction = interaction.to("cpu")
                mean_loss += optimized_loss

                for callback in self.callbacks:
                    callback.on_batch_end(
                        interaction, optimized_loss, n_batches, is_training=False
                    )

                interactions.append(interaction)
                n_batches += 1

        mean_loss /= n_batches
        full_interaction = Interaction.from_iterable(interactions)

        return mean_loss.item(), full_interaction

    def train_epoch(self):
        mean_loss = 0
        n_batches = 0
        interactions = []

        self.game.train()

        self.optimizer.zero_grad()

        for batch_id, batch in enumerate(self.train_data):
            if not isinstance(batch, Batch):
                batch = Batch(*batch)
            batch = batch.to(self.device)

            context = autocast() if self.scaler else nullcontext()
            with context:
                optimized_loss, interaction = self.game(*batch)

                if self.update_freq > 1:
                    # throughout EGG, we minimize _mean_ loss, not sum
                    # hence, we need to account for that when aggregating grads
                    optimized_loss = optimized_loss / self.update_freq

            if self.scaler:
                self.scaler.scale(optimized_loss).backward()
            else:
                optimized_loss.backward()

            if batch_id % self.update_freq == self.update_freq - 1:
                if self.scaler:
                    self.scaler.unscale_(self.optimizer)

                if self.grad_norm:
                    torch.nn.utils.clip_grad_norm_(
                        self.game.parameters(), self.grad_norm
                    )
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()

            n_batches += 1
            mean_loss += optimized_loss.detach()
            if (
                self.distributed_context.is_distributed
                and self.aggregate_interaction_logs
            ):
                interaction = Interaction.gather_distributed_interactions(interaction)
            interaction = interaction.to("cpu")

            for callback in self.callbacks:
                callback.on_batch_end(interaction, optimized_loss, batch_id)

            interactions.append(interaction)

        if self.optimizer_scheduler:
            self.optimizer_scheduler.step()

        mean_loss /= n_batches
        full_interaction = Interaction.from_iterable(interactions)
        return mean_loss.item(), full_interaction

    def train(self, n_epochs):
        for callback in self.callbacks:
            callback.on_train_begin(self)

        for epoch in range(self.start_epoch, n_epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch + 1)

            train_loss, train_interaction = self.train_epoch()

            for callback in self.callbacks:
                callback.on_epoch_end(train_loss, train_interaction, epoch + 1)

            validation_loss = validation_interaction = None
            if (
                self.validation_data is not None
                and self.validation_freq > 0
                and (epoch + 1) % self.validation_freq == 0
            ):
                for callback in self.callbacks:
                    callback.on_validation_begin(epoch + 1)
                validation_loss, validation_interaction = self.eval()

                for callback in self.callbacks:
                    callback.on_validation_end(
                        validation_loss, validation_interaction, epoch + 1
                    )

            if self.should_stop:
                for callback in self.callbacks:
                    callback.on_early_stopping(
                        train_loss,
                        train_interaction,
                        epoch + 1,
                        validation_loss,
                        validation_interaction,
                    )
                break

        for callback in self.callbacks:
            callback.on_train_end()

    def load(self, checkpoint: Checkpoint):
        self.game.load_state_dict(checkpoint.model_state_dict)
        self.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
        if checkpoint.optimizer_scheduler_state_dict:
            self.optimizer_scheduler.load_state_dict(
                checkpoint.optimizer_scheduler_state_dict
            )
        self.start_epoch = checkpoint.epoch

    def load_from_checkpoint(self, path):
        """
        Loads the game, agents, and optimizer state from a file
        :param path: Path to the file
        """
        print(f"# loading trainer state from {path}")
        checkpoint = torch.load(path)
        self.load(checkpoint)

    def load_from_latest(self, path):
        latest_file, latest_time = None, None

        for file in path.glob("*.tar"):
            creation_time = os.stat(file).st_ctime
            if latest_time is None or creation_time > latest_time:
                latest_file, latest_time = file, creation_time

        if latest_file is not None:
            self.load_from_checkpoint(latest_file)
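A hypothetical usage sketch for this Trainer. The game module and the data loaders are placeholders, and the surrounding framework is assumed to have initialized its global options (so that get_opts() works) before the Trainer is constructed.

game = MyGame()  # placeholder: any nn.Module whose forward returns (loss, aux_dict)
optimizer = torch.optim.Adam(game.parameters(), lr=1e-3)
trainer = Trainer(game=game,
                  optimizer=optimizer,
                  train_data=train_loader,
                  validation_data=valid_loader,
                  callbacks=[ConsoleLogger(print_train_loss=True, as_json=True)],
                  grad_norm=1.0)
trainer.train(n_epochs=10)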
Exemplo n.º 21
class ConfigurableStep(Module):

    def __init__(self, opt_step, env):
        super(ConfigurableStep, self).__init__()

        self.step_opt = opt_step
        self.env = env
        self.opt = env['opt']
        self.gen_outputs = opt_step['generator_outputs']
        self.loss_accumulator = LossAccumulator()
        self.optimizers = None
        self.scaler = GradScaler(enabled=self.opt['fp16'])
        self.grads_generated = False
        self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
        self.clip_grad_eps = opt_get(opt_step, ['clip_grad_eps'], None)

        # This is a half-measure that can be used between anomaly_detection and running a potentially problematic
        # trainer bare. With this turned on, the optimizer will not step() if a nan grad is detected. If a model trips
        # this warning 10 times in a row, the training session is aborted and the model state is saved. This has a
        # noticeable effect on training speed, but nowhere near as bad as anomaly_detection.
        self.check_grads_for_nan = opt_get(opt_step, ['check_grads_for_nan'], False)
        self.nan_counter = 0

        self.injectors = []
        if 'injectors' in self.step_opt.keys():
            injector_names = []
            for inj_name, injector in self.step_opt['injectors'].items():
                assert inj_name not in injector_names  # Repeated names are always an error case.
                injector_names.append(inj_name)
                self.injectors.append(create_injector(injector, env))

        losses = []
        self.weights = {}
        if 'losses' in self.step_opt.keys():
            for loss_name, loss in self.step_opt['losses'].items():
                assert loss_name not in self.weights.keys()  # Repeated names are always an error case.
                losses.append((loss_name, create_loss(loss, env)))
                self.weights[loss_name] = loss['weight']
        self.losses = OrderedDict(losses)

    def get_network_for_name(self, name):
        return self.env['generators'][name] if name in self.env['generators'].keys() \
                else self.env['discriminators'][name]

    # Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
    #  This default implementation defines a single optimizer for all Generator parameters.
    #  Must be called after networks are initialized and wrapped.
    def define_optimizers(self):
        opt_configs = [opt_get(self.step_opt, ['optimizer_params'], None)]
        self.optimizers = []
        if opt_configs[0] is None:
            return
        training = self.step_opt['training']
        training_net = self.get_network_for_name(training)
        nets = [training_net]
        training = [training]
        for net_name, net, opt_config in zip(training, nets, opt_configs):
            # Configs can organize parameters by group and specify different learning rates for each group. This only
            # works if the model specifically annotates which parameters belong in which group using PARAM_GROUP.
            optim_params = {'default': {'params': [], 'lr': opt_config['lr']}}
            if opt_config is not None and 'param_groups' in opt_config.keys():
                for k, pg in opt_config['param_groups'].items():
                    optim_params[k] = {'params': [], 'lr': pg['lr']}

            for k, v in net.named_parameters():  # can optimize for a part of the model
                # Make some inference about these parameters, which can be used by some optimizers to treat certain
                # parameters differently. For example, it is considered good practice to not do weight decay on
                # BN & bias parameters. TODO: process the module tree instead of the parameter tree to accomplish the
                # same thing, but in a more effective way.
                if k.endswith(".bias"):
                    v.is_bias = True
                if k.endswith(".weight"):
                    v.is_weight = True
                if ".bn" in k or '.batchnorm' in k or '.bnorm' in k:
                    v.is_bn = True
                # Some models can specify some parameters to be in different groups.
                param_group = "default"
                if hasattr(v, 'PARAM_GROUP'):
                    if v.PARAM_GROUP in optim_params.keys():
                        param_group = v.PARAM_GROUP
                    else:
                        logger.warning(f'Model specifies a custom param group {v.PARAM_GROUP} which is not configured. '
                                       f'The same LR will be used for all parameters.')

                if v.requires_grad:
                    optim_params[param_group]['params'].append(v)
                else:
                    if self.env['rank'] <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))

            if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
                opt = torch.optim.Adam(list(optim_params.values()),
                                       weight_decay=opt_config['weight_decay'],
                                       betas=(opt_config['beta1'], opt_config['beta2']))
            elif self.step_opt['optimizer'] == 'adamw':
                opt = torch.optim.AdamW(list(optim_params.values()),
                                       weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                       betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
            elif self.step_opt['optimizer'] == 'lars':
                from trainer.optimizers.larc import LARC
                from trainer.optimizers.sgd import SGDNoBiasMomentum
                optSGD = SGDNoBiasMomentum(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'],
                                           weight_decay=opt_config['weight_decay'])
                opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient'])
            elif self.step_opt['optimizer'] == 'sgd':
                from torch.optim import SGD
                opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay'])
            opt._config = opt_config  # This is a bit seedy, but we will need these configs later.
            opt._config['network'] = net_name
            self.optimizers.append(opt)

    # Returns all optimizers used in this step.
    def get_optimizers(self):
        assert self.optimizers is not None
        return self.optimizers

    # Returns optimizers which are opting in for default LR scheduling.
    def get_optimizers_with_default_scheduler(self):
        assert self.optimizers is not None
        return self.optimizers

    # Returns the names of the networks this step will train. Other networks will be frozen.
    def get_networks_trained(self):
        if isinstance(self.step_opt['training'], list):
            return self.step_opt['training']
        else:
            return [self.step_opt['training']]

    def get_training_network_name(self):
        if isinstance(self.step_opt['training'], list):
            return self.step_opt['training'][0]
        else:
            return self.step_opt['training']

    # Performs all forward and backward passes for this step given an input state. All input states are lists of
    # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
    # steps might use. These tensors are automatically detached and accumulated into chunks.
    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True, no_ddp_sync=False):
        local_state = {}  # <-- Will store the entire local state to be passed to injectors & losses.
        new_state = {}  # <-- Will store state values created by this step for returning to ExtensibleTrainer.
        for k, v in state.items():
            local_state[k] = v[grad_accum_step]
        local_state['train_nets'] = str(self.get_networks_trained())

        # Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
        self.env['amp_loss_id'] = amp_loss_id
        self.env['current_step_optimizers'] = self.optimizers
        self.env['training'] = train

        # Inject in any extra dependencies.
        for inj in self.injectors:
            # Don't do injections tagged with eval unless we are not in train mode.
            if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
                continue
            # Likewise, don't do injections tagged with train unless we are not in eval.
            if not train and 'train' in inj.opt.keys() and inj.opt['train']:
                continue
            # Don't do injections tagged with 'after' or 'before' when we are out of spec.
            if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
               'before' in inj.opt.keys() and self.env['step'] > inj.opt['before'] or \
               'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0:
                continue
            if 'no_accum' in inj.opt.keys() and grad_accum_step > 0:
                continue
            training_net = self.get_network_for_name(self.step_opt['training'])
            if no_ddp_sync and hasattr(training_net, 'no_sync'):
                with training_net.no_sync():
                    injected = inj(local_state)
            else:
                injected = inj(local_state)
            local_state.update(injected)
            new_state.update(injected)

        if train and len(self.losses) > 0:
            # Finally, compute the losses.
            total_loss = 0
            for loss_name, loss in self.losses.items():
                # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
                # be very disruptive to a generator.
                if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step'] or \
                   'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
                   'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
                    continue
                if loss.is_stateful():
                    l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state)
                    local_state.update(lstate)
                    new_state.update(lstate)
                else:
                    l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
                total_loss += l * self.weights[loss_name]
                # Record metrics.
                if isinstance(l, torch.Tensor):
                    self.loss_accumulator.add_loss(loss_name, l)
                for n, v in loss.extra_metrics():
                    self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
                loss.clear_metrics()

            # In some cases, the loss could not be set (e.g. all losses have 'after')
            if isinstance(total_loss, torch.Tensor):
                self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
                reset_required = total_loss < self.min_total_loss

                # Scale the loss down by the accumulation factor.
                total_loss = total_loss / self.env['mega_batch_factor']

                # Get dem grads!
                self.scaler.scale(total_loss).backward()

                if reset_required:
                    # You might be scratching your head at this. Why would you zero grad as opposed to not doing a
                    # backwards? Because DDP uses the backward() pass as a synchronization point and there is not a good
                    # way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or
                    # implement it at the loss level.
                    self.get_network_for_name(self.step_opt['training']).zero_grad()
                    self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))

                self.grads_generated = True

        # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
        # we must release the gradients.
        new_state = recursively_detach(new_state)
        return new_state

    # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
    # all self.optimizers.
    def do_step(self, step):
        if not self.grads_generated:
            return
        self.grads_generated = False
        for opt in self.optimizers:
            # Optimizers can be opted out in the early stages of training.
            after = opt._config['after'] if 'after' in opt._config.keys() else 0
            after_network = self.opt['networks'][opt._config['network']]['after'] if 'after' in self.opt['networks'][opt._config['network']].keys() else 0
            after = max(after, after_network)
            if self.env['step'] < after:
                continue
            before = opt._config['before'] if 'before' in opt._config.keys() else -1
            if before != -1 and self.env['step'] > before:
                continue

            nan_found = False
            if self.check_grads_for_nan:
                for pg in opt.param_groups:
                    for p in pg['params']:
                        if not torch.isfinite(p.grad).any():
                            nan_found = True
                            break
                    if nan_found:
                        break
                if nan_found:
                    print("NaN found in grads. Throwing this step out.")
                    self.nan_counter += 1
                else:
                    self.nan_counter = 0

            if self.clip_grad_eps is not None:
                for pg in opt.param_groups:
                    grad_norm = torch.nn.utils.clip_grad_norm_(pg['params'], self.clip_grad_eps)
                    if torch.isnan(grad_norm):
                        nan_found = True
                        self.nan_counter += 1

            if not nan_found:
                self.scaler.step(opt)
                self.scaler.update()

    def get_metrics(self):
        return self.loss_accumulator.as_dict()
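opt_get is used throughout this step but is not defined here. The sketch below is an assumption about its behaviour, inferred from the call sites above: a nested-dict lookup that returns a default on any missing key.

def opt_get_sketch(opt, keys, default=None):
    # e.g. opt_get_sketch(opt_step, ['optimizer_params'], None)
    #      opt_get_sketch(opt_config, ['beta1'], .9)
    node = opt
    for k in keys:
        if not isinstance(node, dict) or k not in node:
            return default
        node = node[k]
    return node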
Exemplo n.º 22
0
    def __init__(
        self,
        game: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        train_data: DataLoader,
        optimizer_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        validation_data: Optional[DataLoader] = None,
        device: torch.device = None,
        callbacks: Optional[List[Callback]] = None,
        grad_norm: float = None,
        aggregate_interaction_logs: bool = True,
    ):
        """
        :param game: A nn.Module that implements forward(); it is expected that forward returns a tuple of (loss, d),
            where loss is differentiable loss to be minimized and d is a dictionary (potentially empty) with auxiliary
            metrics that would be aggregated and reported
        :param optimizer: An instance of torch.optim.Optimizer
        :param optimizer_scheduler: An optimizer scheduler to adjust lr throughout training
        :param train_data: A DataLoader for the training set
        :param validation_data: A DataLoader for the validation set (can be None)
        :param device: A torch.device on which to tensors should be stored
        :param callbacks: A list of egg.core.Callback objects that can encapsulate monitoring or checkpointing
        """
        self.game = game
        self.optimizer = optimizer
        self.optimizer_scheduler = optimizer_scheduler
        self.train_data = train_data
        self.validation_data = validation_data
        common_opts = get_opts()
        self.validation_freq = common_opts.validation_freq
        self.device = common_opts.device if device is None else device

        self.should_stop = False
        self.start_epoch = 0  # Can be overwritten by checkpoint loader
        self.callbacks = callbacks if callbacks else []
        self.grad_norm = grad_norm
        self.aggregate_interaction_logs = aggregate_interaction_logs

        self.update_freq = common_opts.update_freq

        if common_opts.load_from_checkpoint is not None:
            print(
                f"# Initializing model, trainer, and optimizer from {common_opts.load_from_checkpoint}"
            )
            self.load_from_checkpoint(common_opts.load_from_checkpoint)

        self.distributed_context = common_opts.distributed_context
        if self.distributed_context.is_distributed:
            print("# Distributed context: ", self.distributed_context)

        if self.distributed_context.is_leader and not any(
            isinstance(x, CheckpointSaver) for x in self.callbacks
        ):
            if common_opts.preemptable:
                assert (
                    common_opts.checkpoint_dir
                ), "checkpointing directory has to be specified"
                d = get_preemptive_checkpoint_dir(common_opts.checkpoint_dir)
                self.checkpoint_path = d
                self.load_from_latest(d)
            else:
                self.checkpoint_path = (
                    None
                    if common_opts.checkpoint_dir is None
                    else pathlib.Path(common_opts.checkpoint_dir)
                )

            if self.checkpoint_path:
                checkpointer = CheckpointSaver(
                    checkpoint_path=self.checkpoint_path,
                    checkpoint_freq=common_opts.checkpoint_freq,
                )
                self.callbacks.append(checkpointer)

        if self.distributed_context.is_leader and common_opts.tensorboard:
            assert (
                common_opts.tensorboard_dir
            ), "tensorboard directory has to be specified"
            tensorboard_logger = TensorboardLogger()
            self.callbacks.append(tensorboard_logger)

        # fall back to a default console logger if no callbacks were provided or added above
        if not self.callbacks:
            self.callbacks = [
                ConsoleLogger(print_train_loss=False, as_json=False),
            ]

        if self.distributed_context.is_distributed:
            device_id = self.distributed_context.local_rank
            torch.cuda.set_device(device_id)
            self.game.to(device_id)

            # NB: here we are doing something that is a bit shady:
            # 1/ the optimizer was created outside of the Trainer instance, so we don't really know
            #    which parameters it optimizes. If it holds something that is not within the Game instance,
            #    those parameters will not participate in distributed training
            # 2/ if the optimizer only holds a subset of the Game parameters, this works, but in a somewhat
            #    undocumented way. In fact, the optimizer holds parameters of the non-DistributedDataParallel
            #    version of the Game, whereas the forward/backward calls happen on the DistributedDataParallel
            #    wrapper. The wrapper syncs gradients of the underlying tensors - which are the ones the optimizer
            #    holds - so it works, but only because DDP does not take ownership of any tensors.

            self.game = torch.nn.parallel.DistributedDataParallel(
                self.game,
                device_ids=[device_id],
                output_device=device_id,
                find_unused_parameters=True,
            )
            self.optimizer.state = move_to(self.optimizer.state, device_id)

        else:
            self.game.to(self.device)
            # NB: some optimizers pre-allocate buffers before actually doing any steps.
            # Since the model is placed on the GPU within the Trainer, this can leave the optimizer's state and the
            # model parameters on different devices. Here, we protect against that by moving the optimizer's internal
            # state to the proper device.
            self.optimizer.state = move_to(self.optimizer.state, self.device)

        if common_opts.fp16:
            self.scaler = GradScaler()
        else:
            self.scaler = None
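For reference, a minimal game module that satisfies the contract described in the docstring above (forward returns a differentiable loss plus a dictionary of auxiliary metrics) might look like the following sketch; the toy architecture and names are purely illustrative.

import torch.nn as nn
import torch.nn.functional as F


class ToyGame(nn.Module):
    # illustrative only: forward() returns (loss, aux_metrics_dict) as the Trainer expects
    def __init__(self, n_features=10, n_classes=5):
        super().__init__()
        self.net = nn.Linear(n_features, n_classes)

    def forward(self, x, labels):
        logits = self.net(x)
        loss = F.cross_entropy(logits, labels)
        acc = (logits.argmax(dim=-1) == labels).float().mean()
        return loss, {"acc": acc.detach()}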
Exemplo n.º 23
0
def invert(
    data_loader,
    loss_fn,
    optimizer,
    steps=10,
    scheduler=None,
    use_amp=False,
    grad_norm_fn=None,
    callback_fn=None,
    plot=False,
    fig_path=None,
    track_per_batch=False,
    track_grad_norm=False,
):

    assert valid_data_loader(
        data_loader), f"invalid data_loader: {data_loader}"

    params = sum((p_group['params'] for p_group in optimizer.param_groups), [])
    # one lr per parameter (rather than per group), so it can be zipped with params below
    lrs = [
        p_group['lr'] for p_group in optimizer.param_groups
        for _ in p_group['params']
    ]
    device = params[0].device
    USE_AMP = (device.type == 'cuda') and use_amp
    if USE_AMP:
        scaler = GradScaler()

    num_batches = len(data_loader)
    track_len = steps * num_batches if track_per_batch else steps
    metrics = pd.DataFrame({'step': [None] * track_len})

    def process_result(res):
        if isinstance(res, dict):
            loss = res['loss']
            info = res
            for k, v in info.items():
                info[k] = v.item() if isinstance(v, torch.Tensor) else v
        elif isinstance(res, tuple):
            loss, info = res
        else:
            loss = res
            info = {'loss': loss.item()}
        return loss, info

    print(flush=True)

    # if callback_fn:
    #     callback_fn(0, None)

    with tqdmEpoch(steps, num_batches) as pbar:
        for epoch in range(steps):
            for batch_i, data in enumerate(data_loader):

                optimizer.zero_grad()

                if USE_AMP:
                    with autocast():
                        res = loss_fn(data)
                    loss, info = process_result(res)
                    scaler.scale(loss).backward()
                    grad_scale = scaler.get_scale()
                else:
                    res = loss_fn(data)
                    loss, info = process_result(res)
                    loss.backward()
                    grad_scale = 1

                if USE_AMP:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                if scheduler is not None:
                    scheduler.step(loss)

                if track_grad_norm or grad_norm_fn:
                    # XXX: probably shouldn't multiply with lr
                    total_norm = torch.norm(
                        torch.stack([
                            p.grad.detach().norm() / grad_scale  # * lr
                            for p, lr in zip(params, lrs)
                        ])).item()

                    if grad_norm_fn:
                        rescale_coef = grad_norm_fn(total_norm) / total_norm
                        for param in params:
                            param.grad.detach().mul_(rescale_coef)

                    info['|grad|'] = total_norm

                pbar.set_postfix(**{
                    k: v
                    for k, v in info.items() if ']' not in k
                },
                                 refresh=False)
                pbar.update()

                if track_per_batch:
                    batch_total = epoch * num_batches + batch_i
                    step = batch_total
                    # step = epoch + (batch_i + 1) / num_batches
                else:
                    step = epoch
                    # step = epoch + 1 + batch_i / num_batches

                for k, v in info.items():
                    if k not in metrics:  # add new column
                        metrics[k] = None
                    if metrics[k][step] is None:
                        metrics[k][step] = v
                    else:
                        metrics[k][step] += v

                if not track_per_batch and batch_i == 0:
                    metrics['step'][epoch] = epoch + 1
                if track_per_batch:
                    metrics['step'][batch_total] = (batch_total +
                                                    1) / num_batches
                # batch end

            if not track_per_batch:
                for k, v in metrics.items():
                    if k != 'step':
                        metrics[k][epoch] /= num_batches

            if callback_fn:
                callback_fn(epoch + 1, metrics.iloc[step])
            # epoch end

    print(flush=True)

    if plot and steps > 1:
        plot_metrics(metrics, fig_path=fig_path, smoothing=0)

    return metrics
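In `invert` above, the true gradient norm is recovered by dividing each per-parameter norm by `scaler.get_scale()`. An alternative sketch (assuming an `optimizer`, a `GradScaler`, a flattened `params` list and a `use_amp` flag analogous to the ones above) is to unscale the gradients in place before measuring or clipping them:

import torch


def backward_and_step(loss, optimizer, scaler, params, use_amp, max_norm=None):
    # illustrative sketch: measure the unscaled gradient norm, optionally clip, then step
    if use_amp:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # gradients now hold their true values
    else:
        loss.backward()
    total_norm = torch.norm(
        torch.stack([p.grad.detach().norm() for p in params if p.grad is not None]))
    if max_norm is not None:
        torch.nn.utils.clip_grad_norm_(params, max_norm)
    if use_amp:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    return total_norm.item()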
Exemplo n.º 24
0
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    else:
        raise NotImplementedError(
            "architectures other than UNet haven't been added yet!")

    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 eps=1e-7,
                                 amsgrad=True,
                                 weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(
        model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)

    #for name, param in model.named_parameters():
    #    print(name, param.size())
    print('Total trainable tensors:', num, flush=True)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    ITER_SAVE_INTERVAL = 300
    EPOCH_SAVE_INTERVAL = 5

    model_path = args.model_saving_path + args.model_name + "_batch_size_" + str(
        args.batch_size) + "_lr_" + str(args.lr)
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    ds = SimulationDataset(args.data_folder,
                           total_sample_number=args.total_sample_number)
    means = [
        1e-3, 1e-3
    ]  #ds.means; #[Hy_meanR, Hy_meanI, Ex_meanR, Ex_meanI, Ez_meanR, Ez_meanI];
    print("means: ", means)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(
        ds,
        [int(0.9 * len(ds)), len(ds) - int(0.9 * len(ds))])

    #print("total training samples: %d, total test samples: %d" % (len(train_ds), len(test_ds)), flush=True)
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0)
    test_loader = DataLoader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=0)

    train_mean = 0
    test_mean = 0
    # first get the mean-absolute-field value:
    for sample_batched in train_loader:
        train_mean += torch.mean(torch.abs(sample_batched["field"]))
    for sample_batched in test_loader:
        test_mean += torch.mean(torch.abs(sample_batched["field"]))
    train_mean /= len(train_loader)
    test_mean /= len(test_loader)

    print(
        "total training samples: %d, total test samples: %d, train_abs_mean: %f, test_abs_mean: %f"
        % (len(train_ds), len(test_ds), train_mean, test_mean),
        flush=True)

    # for visualizing the graph:
    #writer = SummaryWriter('runs/'+args.model_name)

    #test_input = None
    #for sample in train_loader:
    #    test_input = sample['structure']
    #    break
    #writer.add_graph(model, test_input.to(device))
    #writer.close()

    df = pd.DataFrame(columns=[
        'epoch', 'train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'
    ])

    train_loss_history = []
    train_phys_reg_history = []

    test_loss_history = []
    test_phys_reg_history = []

    start_epoch = 0
    if (args.continue_train):
        print("Restoring weights from ",
              model_path + "/last_model.pt",
              flush=True)
        checkpoint = torch.load(model_path + "/last_model.pt")
        start_epoch = checkpoint['epoch']
        model = checkpoint['model']
        model.lr_scheduler = checkpoint['lr_scheduler']
        model.optimizer = checkpoint['optimizer']
        df = read_csv(model_path + '/' + 'df.csv')

    scaler = GradScaler()

    best_loss = 1e4
    last_epoch_data_loss = 1.0
    last_epoch_physical_loss = 1.0
    for step in range(start_epoch, args.epoch):
        print("epoch: ", step, flush=True)
        reg_norm = regConstScheduler(step, args, last_epoch_data_loss,
                                     last_epoch_physical_loss)
        # training
        for sample_batched in train_loader:
            model.optimizer.zero_grad()

            x_batch_train, y_batch_train = sample_batched['structure'].to(
                device), sample_batched['field'].to(device)
            with autocast():
                logits = model(x_batch_train, bn_training=True)
                #calculate the loss using the ground truth
                loss = model.loss_fn(logits, y_batch_train)
                logits = logits[:, :, 1:-1, :]
                # print("loss: ", loss, flush=True)

                # Calculate physical residue
                pattern = (x_batch_train * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                pattern = torch.cat(
                    (torch.ones([pattern.shape[0], 1, 1, 256],
                                dtype=torch.float32,
                                device=device) * n_sub**2, pattern),
                    dim=2)
                #fields = logits; # predicted fields [Hy_R, Hy_I, Ex_R, Ex_I, Ez_R, Ez_I]
                fields = torch.cat((y_batch_train[:, :, 0:1, :], logits,
                                    y_batch_train[:, :, -1:, :]),
                                   dim=2)
                FD_Hy = H_to_H(-fields[:, 0] * means[0],
                               -fields[:, 1] * means[1], dL, omega, pattern)
                #phys_regR = 10*model.loss_fn(FD_Hy[:, 0]/means[0], logits[:, 0])/reg_norm;
                #phys_regI = 10*model.loss_fn(FD_Hy[:, 1]/means[1], logits[:, 1])/reg_norm;
                phys_regR = model.loss_fn(FD_Hy[:, 0] / means[0],
                                          logits[:, 0]) * reg_norm
                phys_regI = model.loss_fn(FD_Hy[:, 1] / means[1],
                                          logits[:, 1]) * reg_norm
                loss += phys_regR + phys_regI

            # run backward and the optimizer step outside of the autocast context
            scaler.scale(loss).backward()
            scaler.step(model.optimizer)
            scaler.update()

            #loss.backward()
            #model.optimizer.step()

        #Save the weights at the end of each epoch
        checkpoint = {
            'epoch': step,
            'model': model,
            'optimizer': model.optimizer,
            'lr_scheduler': model.lr_scheduler
        }
        torch.save(checkpoint, model_path + "/last_model.pt")

        # evaluation
        train_loss = 0
        train_phys_reg = 0
        for sample_batched in train_loader:
            x_batch_train, y_batch_train = sample_batched['structure'].to(
                device), sample_batched['field'].to(device)

            with torch.no_grad():
                logits = model(x_batch_train, bn_training=False)
                loss = model.loss_fn(logits, y_batch_train)
                logits = logits[:, :, 1:-1, :]
                # Calculate physical residue
                pattern = (x_batch_train * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                pattern = torch.cat(
                    (torch.ones([pattern.shape[0], 1, 1, 256],
                                dtype=torch.float32,
                                device=device) * n_sub**2, pattern),
                    dim=2)
                #fields = logits; # predicted fields [Hy_R, Hy_I, Ex_R, Ex_I, Ez_R, Ez_I]
                fields = torch.cat((y_batch_train[:, :, 0:1, :], logits,
                                    y_batch_train[:, :, -1:, :]),
                                   dim=2)
                FD_Hy = H_to_H(-fields[:, 0] * means[0],
                               -fields[:, 1] * means[1], dL, omega, pattern)
                #phys_regR = 10*model.loss_fn(FD_Hy[:, 0]/means[0], fields[:, 0, 1:-1, :])/reg_norm;
                #phys_regI = 10*model.loss_fn(FD_Hy[:, 1]/means[1], fields[:, 1, 1:-1, :])/reg_norm;
                phys_regR = model.loss_fn(FD_Hy[:, 0] / means[0],
                                          fields[:, 0, 1:-1, :]) * reg_norm
                phys_regI = model.loss_fn(FD_Hy[:, 1] / means[1],
                                          fields[:, 1, 1:-1, :]) * reg_norm

                #loss = loss + phys_reg1 + phys_reg2 + phys_reg3;
                train_loss += loss
                train_phys_reg += 0.5 * (phys_regR + phys_regI)

        train_loss /= len(train_loader)
        train_phys_reg /= len(train_loader)

        test_loss = 0
        test_phys_reg = 0
        for sample_batched in test_loader:
            x_batch_test, y_batch_test = sample_batched['structure'].to(
                device), sample_batched['field'].to(device)

            with torch.no_grad():
                logits = model(x_batch_test, bn_training=False)
                loss = model.loss_fn(logits, y_batch_test)
                logits = logits[:, :, 1:-1, :]
                # Calculate physical residue
                pattern = (x_batch_test * (n_Si - n_air) + n_air)**2
                # rescale the 0/1 pattern into dielectric constant
                pattern = torch.cat(
                    (torch.ones([pattern.shape[0], 1, 1, 256],
                                dtype=torch.float32,
                                device=device) * n_sub**2, pattern),
                    dim=2)
                #fields = logits; # predicted fields [Hy_R, Hy_I, Ex_R, Ex_I, Ez_R, Ez_I]
                fields = torch.cat(
                    (y_batch_test[:, :, 0:1, :], logits, y_batch_test[:, :,
                                                                      -1:, :]),
                    dim=2)
                FD_Hy = H_to_H(-fields[:, 0] * means[0],
                               -fields[:, 1] * means[1], dL, omega, pattern)
                #phys_regR = 10*model.loss_fn(FD_Hy[:, 0]/means[0], fields[:, 0, 1:-1, :])/reg_norm;
                #phys_regI = 10*model.loss_fn(FD_Hy[:, 1]/means[1], fields[:, 1, 1:-1, :])/reg_norm;
                phys_regR = model.loss_fn(FD_Hy[:, 0] / means[0],
                                          fields[:, 0, 1:-1, :])
                phys_regI = model.loss_fn(FD_Hy[:, 1] / means[1],
                                          fields[:, 1, 1:-1, :])

                test_loss += loss
                test_phys_reg += 0.5 * (phys_regR + phys_regI)

        test_loss /= len(test_loader)
        test_phys_reg /= len(test_loader)
        last_epoch_data_loss = test_loss
        last_epoch_physical_loss = test_phys_reg.detach().clone()
        test_phys_reg *= reg_norm

        print(
            'train loss: %.5f, train phys reg: %.5f, test loss: %.5f, test phys reg: %.5f, last_physical_loss: %.5f'
            % (train_loss, train_phys_reg, test_loss, test_phys_reg,
               last_epoch_physical_loss),
            flush=True)

        model.lr_scheduler.step()

        # pandas.DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent
        df = pd.concat(
            [
                df,
                pd.DataFrame([{
                    'epoch': step + 1,
                    'lr': str(model.lr_scheduler.get_last_lr()),
                    'train_loss': train_loss.item(),
                    'train_phys_reg': train_phys_reg.item(),
                    'test_loss': test_loss.item(),
                    'test_phys_reg': test_phys_reg.item(),
                }])
            ],
            ignore_index=True)

        df.to_csv(model_path + '/' + 'df.csv', index=False)

        if (test_loss < best_loss):
            best_loss = test_loss
            checkpoint = {
                'epoch': step,
                'model': model,
                'optimizer': model.optimizer,
                'lr_scheduler': model.lr_scheduler
            }
            torch.save(checkpoint, model_path + "/best_model.pt")
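The checkpoints above pickle the whole model, optimizer and scheduler objects. A more portable sketch, assuming the same `model.optimizer` / `model.lr_scheduler` attributes as in this example, would store and restore state_dicts instead:

import torch


def save_training_state(model, path, epoch):
    # store only state_dicts; the objects are re-created from the config on restore
    torch.save(
        {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': model.optimizer.state_dict(),
            'lr_scheduler_state': model.lr_scheduler.state_dict(),
        }, path)


def restore_training_state(model, path, device='cuda'):
    # model, model.optimizer and model.lr_scheduler must already be constructed
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state'])
    model.optimizer.load_state_dict(checkpoint['optimizer_state'])
    model.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state'])
    return checkpoint['epoch']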
Exemplo n.º 25
0
class NetworkTrainer(object):
    def __init__(self, deterministic=True, fp16=False):
        """
        A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such
        as the training loop and tracking of training and validation losses (and the target metric, if you implement it).
        Training can be terminated early if the validation loss (or the target metric, if implemented) does not improve
        anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values, to get smoother
        results.

        What you need to override:
        - __init__
        - initialize
        - run_online_evaluation (optional)
        - finish_online_evaluation (optional)
        - validate
        - predict_test_case
        """
        self.fp16 = fp16
        self.amp_grad_scaler = None

        if deterministic:
            np.random.seed(12345)
            torch.manual_seed(12345)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(12345)
            cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True

        ################# SET THESE IN self.initialize() ###################################
        self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None
        self.optimizer = None
        self.lr_scheduler = None
        self.tr_gen = self.val_gen = None
        self.was_initialized = False

        ################# SET THESE IN INIT ################################################
        self.output_folder = None
        self.fold = None
        self.loss = None
        self.dataset_directory = None

        ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################
        self.dataset = None  # these can be None for inference mode
        self.dataset_tr = self.dataset_val = None  # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split

        ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED #####################
        self.patience = 50
        self.val_eval_criterion_alpha = 0.9  # alpha * old + (1-alpha) * new
        # if this is too low then the moving average will be too noisy and the training may terminate early. If it is
        # too high the training will take forever
        self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new
        self.train_loss_MA_eps = 5e-4  # new MA must be at least this much better (smaller)
        self.max_num_epochs = 1000
        self.num_batches_per_epoch = 250
        self.num_val_batches_per_epoch = 50
        self.also_val_in_tr_mode = False
        self.lr_threshold = 1e-6  # the network will not terminate training if the lr is still above this threshold

        ################# LEAVE THESE ALONE ################################################
        self.val_eval_criterion_MA = None
        self.train_loss_MA = None
        self.best_val_eval_criterion_MA = None
        self.best_MA_tr_loss_for_patience = None
        self.best_epoch_based_on_MA_tr_loss = None
        self.all_tr_losses = []
        self.all_val_losses = []
        self.all_val_losses_tr_mode = []
        self.all_val_eval_metrics = []  # does not have to be used
        self.epoch = 0
        self.log_file = None
        self.deterministic = deterministic

        self.use_progress_bar = False
        if 'nnunet_use_progress_bar' in os.environ.keys():
            self.use_progress_bar = bool(
                int(os.environ['nnunet_use_progress_bar']))

        ################# Settings for saving checkpoints ##################################
        self.save_every = 50
        self.save_latest_only = True  # if False, a separate model_ep_XXX checkpoint is kept (in addition to
        # model_latest) each time an intermediate checkpoint is created
        self.save_intermediate_checkpoints = True  # whether or not to save checkpoint_latest
        self.save_best_checkpoint = True  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        self.save_final_checkpoint = True  # whether or not to save the final checkpoint

    @abstractmethod
    def initialize(self, training=True):
        """
        create self.output_folder

        modify self.output_folder if you are doing cross-validation (one folder per fold)

        set self.tr_gen and self.val_gen

        call self.initialize_network and self.initialize_optimizer_and_scheduler (important!)

        finally set self.was_initialized to True
        :param training:
        :return:
        """

    @abstractmethod
    def load_dataset(self):
        pass

    def do_split(self):
        """
        This is a suggestion for if your dataset is a dictionary (my personal standard)
        :return:
        """
        splits_file = join(self.dataset_directory, "splits_final.pkl")
        if not isfile(splits_file):
            self.print_to_log_file("Creating new split...")
            splits = []
            all_keys_sorted = np.sort(list(self.dataset.keys()))
            kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
            for i, (train_idx,
                    test_idx) in enumerate(kfold.split(all_keys_sorted)):
                train_keys = np.array(all_keys_sorted)[train_idx]
                test_keys = np.array(all_keys_sorted)[test_idx]
                splits.append(OrderedDict())
                splits[-1]['train'] = train_keys
                splits[-1]['val'] = test_keys
            save_pickle(splits, splits_file)

        splits = load_pickle(splits_file)

        if self.fold == "all":
            tr_keys = val_keys = list(self.dataset.keys())
        else:
            tr_keys = splits[self.fold]['train']
            val_keys = splits[self.fold]['val']

        tr_keys.sort()
        val_keys.sort()

        self.dataset_tr = OrderedDict()
        for i in tr_keys:
            self.dataset_tr[i] = self.dataset[i]

        self.dataset_val = OrderedDict()
        for i in val_keys:
            self.dataset_val[i] = self.dataset[i]

    def plot_progress(self):
        """
        Should probably be improved
        :return:
        """
        try:
            font = {'weight': 'normal', 'size': 18}

            matplotlib.rc('font', **font)

            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()

            x_values = list(range(self.epoch + 1))

            ax.plot(x_values,
                    self.all_tr_losses,
                    color='b',
                    ls='-',
                    label="loss_tr")

            ax.plot(x_values,
                    self.all_val_losses,
                    color='r',
                    ls='-',
                    label="loss_val, train=False")

            if len(self.all_val_losses_tr_mode) > 0:
                ax.plot(x_values,
                        self.all_val_losses_tr_mode,
                        color='g',
                        ls='-',
                        label="loss_val, train=True")
            if len(self.all_val_eval_metrics) == len(x_values):
                ax2.plot(x_values,
                         self.all_val_eval_metrics,
                         color='g',
                         ls='--',
                         label="evaluation metric")

            ax.set_xlabel("epoch")
            ax.set_ylabel("loss")
            ax2.set_ylabel("evaluation metric")
            ax.legend()
            ax2.legend(loc=9)

            fig.savefig(join(self.output_folder, "progress.png"))
            plt.close()
        except IOError:
            self.print_to_log_file("failed to plot: ", sys.exc_info())

    def print_to_log_file(self,
                          *args,
                          also_print_to_console=True,
                          add_timestamp=True):

        timestamp = time()
        dt_object = datetime.fromtimestamp(timestamp)

        if add_timestamp:
            args = ("%s:" % dt_object, *args)

        if self.log_file is None:
            maybe_mkdir_p(self.output_folder)
            timestamp = datetime.now()
            self.log_file = join(
                self.output_folder,
                "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
                (timestamp.year, timestamp.month, timestamp.day,
                 timestamp.hour, timestamp.minute, timestamp.second))
            with open(self.log_file, 'w') as f:
                f.write("Starting... \n")
        successful = False
        max_attempts = 5
        ctr = 0
        while not successful and ctr < max_attempts:
            try:
                with open(self.log_file, 'a+') as f:
                    for a in args:
                        f.write(str(a))
                        f.write(" ")
                    f.write("\n")
                successful = True
            except IOError:
                print(
                    "%s: failed to log: " % datetime.fromtimestamp(timestamp),
                    sys.exc_info())
                sleep(0.5)
                ctr += 1
        if also_print_to_console:
            print(*args)

    def save_checkpoint(self, fname, save_optimizer=True):
        start_time = time()
        state_dict = self.network.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cpu()
        lr_sched_state_dct = None
        if self.lr_scheduler is not None and hasattr(
                self.lr_scheduler, 'state_dict'
        ):  # not isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
            lr_sched_state_dct = self.lr_scheduler.state_dict()
            # WTF is this!?
            # for key in lr_sched_state_dct.keys():
            #    lr_sched_state_dct[key] = lr_sched_state_dct[key]
        if save_optimizer:
            optimizer_state_dict = self.optimizer.state_dict()
        else:
            optimizer_state_dict = None

        self.print_to_log_file("saving checkpoint...")
        save_this = {
            'epoch': self.epoch + 1,
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer_state_dict,
            'lr_scheduler_state_dict': lr_sched_state_dct,
            'plot_stuff': (self.all_tr_losses, self.all_val_losses,
                           self.all_val_losses_tr_mode, self.all_val_eval_metrics),
            'best_stuff': (self.best_epoch_based_on_MA_tr_loss,
                           self.best_MA_tr_loss_for_patience,
                           self.best_val_eval_criterion_MA)
        }
        if self.amp_grad_scaler is not None:
            save_this['amp_grad_scaler'] = self.amp_grad_scaler.state_dict()

        torch.save(save_this, fname)
        self.print_to_log_file("done, saving took %.2f seconds" %
                               (time() - start_time))

    def load_best_checkpoint(self, train=True):
        if self.fold is None:
            raise RuntimeError(
                "Cannot load best checkpoint if self.fold is None")
        if isfile(join(self.output_folder, "model_best.model")):
            self.load_checkpoint(join(self.output_folder, "model_best.model"),
                                 train=train)
        else:
            self.print_to_log_file(
                "WARNING! model_best.model does not exist! Cannot load best checkpoint. Falling "
                "back to load_latest_checkpoint")
            self.load_latest_checkpoint(train)

    def load_latest_checkpoint(self, train=True):
        if isfile(join(self.output_folder, "model_final_checkpoint.model")):
            return self.load_checkpoint(join(self.output_folder,
                                             "model_final_checkpoint.model"),
                                        train=train)
        if isfile(join(self.output_folder, "model_latest.model")):
            return self.load_checkpoint(join(self.output_folder,
                                             "model_latest.model"),
                                        train=train)
        if isfile(join(self.output_folder, "model_best.model")):
            return self.load_best_checkpoint(train)
        raise RuntimeError("No checkpoint found")

    def load_checkpoint(self, fname, train=True):
        self.print_to_log_file("loading checkpoint", fname, "train=", train)
        if not self.was_initialized:
            self.initialize(train)
        # saved_model = torch.load(fname, map_location=torch.device('cuda', torch.cuda.current_device()))
        saved_model = torch.load(fname, map_location=torch.device('cpu'))
        self.load_checkpoint_ram(saved_model, train)

    @abstractmethod
    def initialize_network(self):
        """
        initialize self.network here
        :return:
        """
        pass

    @abstractmethod
    def initialize_optimizer_and_scheduler(self):
        """
        initialize self.optimizer and self.lr_scheduler (if applicable) here
        :return:
        """
        pass

    def load_checkpoint_ram(self, checkpoint, train=True):
        """
        used for if the checkpoint is already in ram
        :param checkpoint:
        :param train:
        :return:
        """
        if not self.was_initialized:
            self.initialize(train)

        new_state_dict = OrderedDict()
        curr_state_dict_keys = list(self.network.state_dict().keys())
        # if the state dict comes from nn.DataParallel but we use a non-parallel model here, the state dict keys do not
        # match. Use a heuristic to make them match
        for k, value in checkpoint['state_dict'].items():
            key = k
            if key not in curr_state_dict_keys and key.startswith('module.'):
                key = key[7:]
            new_state_dict[key] = value

        if self.fp16:
            self._maybe_init_amp()
            if 'amp_grad_scaler' in checkpoint.keys():
                self.amp_grad_scaler.load_state_dict(
                    checkpoint['amp_grad_scaler'])

        self.network.load_state_dict(new_state_dict)
        self.epoch = checkpoint['epoch']
        if train:
            optimizer_state_dict = checkpoint['optimizer_state_dict']
            if optimizer_state_dict is not None:
                self.optimizer.load_state_dict(optimizer_state_dict)

            if self.lr_scheduler is not None and hasattr(
                    self.lr_scheduler, 'load_state_dict'
            ) and checkpoint['lr_scheduler_state_dict'] is not None:
                self.lr_scheduler.load_state_dict(
                    checkpoint['lr_scheduler_state_dict'])

            if issubclass(self.lr_scheduler.__class__, _LRScheduler):
                self.lr_scheduler.step(self.epoch)

        self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[
            'plot_stuff']

        # load best loss (if present)
        if 'best_stuff' in checkpoint.keys():
            self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA = checkpoint[
                'best_stuff']

        # after the training is done, the epoch is incremented one more time in my old code. This results in
        # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
        # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
        if self.epoch != len(self.all_tr_losses):
            self.print_to_log_file(
                "WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
                "due to an old bug and should only appear when you are loading old models. New "
                "models should have this fixed! self.epoch is now set to len(self.all_tr_losses)"
            )
            self.epoch = len(self.all_tr_losses)
            self.all_tr_losses = self.all_tr_losses[:self.epoch]
            self.all_val_losses = self.all_val_losses[:self.epoch]
            self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
            self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]

        self._maybe_init_amp()

    def _maybe_init_amp(self):
        if self.fp16 and self.amp_grad_scaler is None and torch.cuda.is_available():
            self.amp_grad_scaler = GradScaler()

    def plot_network_architecture(self):
        """
        can be implemented (see nnUNetTrainer) but does not have to. Not implemented here because it imposes stronger
        assumptions on the presence of class variables
        :return:
        """
        pass

    def run_training(self):
        _ = self.tr_gen.next()
        _ = self.val_gen.next()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._maybe_init_amp()

        maybe_mkdir_p(self.output_folder)
        self.plot_network_architecture()

        if cudnn.benchmark and cudnn.deterministic:
            warn(
                "torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
                "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
                "If you want deterministic then set benchmark=False")

        if not self.was_initialized:
            self.initialize(True)

        while self.epoch < self.max_num_epochs:
            self.print_to_log_file("\nepoch: ", self.epoch)
            epoch_start_time = time()
            train_losses_epoch = []

            # train one epoch
            self.network.train()

            if self.use_progress_bar:
                with trange(self.num_batches_per_epoch) as tbar:
                    for b in tbar:
                        tbar.set_description("Epoch {}/{}".format(
                            self.epoch + 1, self.max_num_epochs))

                        l = self.run_iteration(self.tr_gen, True)

                        tbar.set_postfix(loss=l)
                        train_losses_epoch.append(l)
            else:
                for _ in range(self.num_batches_per_epoch):
                    l = self.run_iteration(self.tr_gen, True)
                    train_losses_epoch.append(l)

            self.all_tr_losses.append(np.mean(train_losses_epoch))
            self.print_to_log_file("train loss : %.4f" %
                                   self.all_tr_losses[-1])

            with torch.no_grad():
                # validation with train=False
                self.network.eval()
                val_losses = []
                for b in range(self.num_val_batches_per_epoch):
                    l = self.run_iteration(self.val_gen, False, True)
                    val_losses.append(l)
                self.all_val_losses.append(np.mean(val_losses))
                self.print_to_log_file("validation loss: %.4f" %
                                       self.all_val_losses[-1])

                if self.also_val_in_tr_mode:
                    self.network.train()
                    # validation with train=True
                    val_losses = []
                    for b in range(self.num_val_batches_per_epoch):
                        l = self.run_iteration(self.val_gen, False)
                        val_losses.append(l)
                    self.all_val_losses_tr_mode.append(np.mean(val_losses))
                    self.print_to_log_file(
                        "validation loss (train=True): %.4f" %
                        self.all_val_losses_tr_mode[-1])

            self.update_train_loss_MA()  # needed for lr scheduler and stopping of training

            continue_training = self.on_epoch_end()

            epoch_end_time = time()

            if not continue_training:
                # allows for early stopping
                break

            self.epoch += 1
            self.print_to_log_file("This epoch took %f s\n" %
                                   (epoch_end_time - epoch_start_time))

        self.epoch -= 1  # if we don't do this we can get a problem with loading model_final_checkpoint.

        if self.save_final_checkpoint:
            self.save_checkpoint(
                join(self.output_folder, "model_final_checkpoint.model"))
        # now we can delete latest as it will be identical with final
        if isfile(join(self.output_folder, "model_latest.model")):
            os.remove(join(self.output_folder, "model_latest.model"))
        if isfile(join(self.output_folder, "model_latest.model.pkl")):
            os.remove(join(self.output_folder, "model_latest.model.pkl"))

    def maybe_update_lr(self):
        # maybe update learning rate
        if self.lr_scheduler is not None:
            assert isinstance(
                self.lr_scheduler,
                (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))

            if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
                # lr scheduler is updated with moving average val loss. should be more robust
                self.lr_scheduler.step(self.train_loss_MA)
            else:
                self.lr_scheduler.step(self.epoch + 1)
        self.print_to_log_file("lr is now (scheduler) %s" %
                               str(self.optimizer.param_groups[0]['lr']))

    def maybe_save_checkpoint(self):
        """
        Saves a checkpoint every self.save_every epochs.
        :return:
        """
        if self.save_intermediate_checkpoints and (self.epoch % self.save_every
                                                   == (self.save_every - 1)):
            self.print_to_log_file("saving scheduled checkpoint file...")
            if not self.save_latest_only:
                self.save_checkpoint(
                    join(self.output_folder,
                         "model_ep_%03.0d.model" % (self.epoch + 1)))
            self.save_checkpoint(join(self.output_folder,
                                      "model_latest.model"))
            self.print_to_log_file("done")

    def update_eval_criterion_MA(self):
        """
        If self.all_val_eval_metrics is unused (len=0) then we fall back to using -self.all_val_losses for the MA that
        determines early stopping (the criterion is maximized rather than minimized, hence the minus sign in that case)
        :return:
        """
        if self.val_eval_criterion_MA is None:
            if len(self.all_val_eval_metrics) == 0:
                self.val_eval_criterion_MA = -self.all_val_losses[-1]
            else:
                self.val_eval_criterion_MA = self.all_val_eval_metrics[-1]
        else:
            if len(self.all_val_eval_metrics) == 0:
                """
                We use alpha * old - (1 - alpha) * new here because 'new' in this case is the validation loss and lower
                is better, so we need to negate it.
                """
                self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA - (
                        1 - self.val_eval_criterion_alpha) * \
                                             self.all_val_losses[-1]
            else:
                self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA + (
                        1 - self.val_eval_criterion_alpha) * \
                                             self.all_val_eval_metrics[-1]

    def manage_patience(self):
        # update patience
        continue_training = True
        if self.patience is not None:
            # if best_MA_tr_loss_for_patience and best_epoch_based_on_MA_tr_loss were not yet initialized,
            # initialize them
            if self.best_MA_tr_loss_for_patience is None:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA

            if self.best_epoch_based_on_MA_tr_loss is None:
                self.best_epoch_based_on_MA_tr_loss = self.epoch

            if self.best_val_eval_criterion_MA is None:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA

            # check if the current epoch is the best one according to moving average of validation criterion. If so
            # then save 'best' model
            # Do not use this for validation. This is intended for test set prediction only.
            #self.print_to_log_file("current best_val_eval_criterion_MA is %.4f0" % self.best_val_eval_criterion_MA)
            #self.print_to_log_file("current val_eval_criterion_MA is %.4f" % self.val_eval_criterion_MA)

            if self.val_eval_criterion_MA > self.best_val_eval_criterion_MA:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
                #self.print_to_log_file("saving best epoch checkpoint...")
                if self.save_best_checkpoint:
                    self.save_checkpoint(
                        join(self.output_folder, "model_best.model"))

            # Now see if the moving average of the train loss has improved. If yes then reset patience, else
            # increase patience
            if self.train_loss_MA + self.train_loss_MA_eps < self.best_MA_tr_loss_for_patience:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA
                self.best_epoch_based_on_MA_tr_loss = self.epoch
                #self.print_to_log_file("New best epoch (train loss MA): %03.4f" % self.best_MA_tr_loss_for_patience)
            else:
                pass
                #self.print_to_log_file("No improvement: current train MA %03.4f, best: %03.4f, eps is %03.4f" %
                #                       (self.train_loss_MA, self.best_MA_tr_loss_for_patience, self.train_loss_MA_eps))

            # if patience has reached its maximum then finish training (provided lr is low enough)
            if self.epoch - self.best_epoch_based_on_MA_tr_loss > self.patience:
                if self.optimizer.param_groups[0]['lr'] > self.lr_threshold:
                    #self.print_to_log_file("My patience ended, but I believe I need more time (lr > 1e-6)")
                    self.best_epoch_based_on_MA_tr_loss = self.epoch - self.patience // 2
                else:
                    #self.print_to_log_file("My patience ended")
                    continue_training = False
            else:
                pass
                #self.print_to_log_file(
                #    "Patience: %d/%d" % (self.epoch - self.best_epoch_based_on_MA_tr_loss, self.patience))

        return continue_training

    def on_epoch_end(self):
        # does not have to do anything, but can be used to update self.all_val_eval_metrics
        self.finish_online_evaluation()

        self.plot_progress()

        self.maybe_update_lr()

        self.maybe_save_checkpoint()

        self.update_eval_criterion_MA()

        continue_training = self.manage_patience()
        return continue_training

    def update_train_loss_MA(self):
        if self.train_loss_MA is None:
            self.train_loss_MA = self.all_tr_losses[-1]
        else:
            self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
                                 self.all_tr_losses[-1]

    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                output = self.network(data)
                del data
                l = self.loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)

            if do_backprop:
                l.backward()
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy()

    def run_online_evaluation(self, *args, **kwargs):
        """
        Can be implemented, does not have to
        :param output_torch:
        :param target_npy:
        :return:
        """
        pass

    def finish_online_evaluation(self):
        """
        Can be implemented, does not have to
        :return:
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs):
        pass

    def find_lr(self,
                num_iters=1000,
                init_value=1e-6,
                final_value=10.,
                beta=0.98):
        """
        stolen and adapted from here: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
        :param num_iters:
        :param init_value:
        :param final_value:
        :param beta:
        :return:
        """
        import math
        self._maybe_init_amp()
        mult = (final_value / init_value)**(1 / num_iters)
        lr = init_value
        self.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []

        for batch_num in range(1, num_iters + 1):
            # +1 because this one here is not designed to have negative loss...
            loss = self.run_iteration(
                self.tr_gen, do_backprop=True,
                run_online_evaluation=False).data.item() + 1

            # Compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_loss = avg_loss / (1 - beta**batch_num)

            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                break

            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss

            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))

            # Update the lr for the next step
            lr *= mult
            self.optimizer.param_groups[0]['lr'] = lr

        import matplotlib.pyplot as plt
        lrs = [10**i for i in log_lrs]
        fig = plt.figure()
        plt.xscale('log')
        plt.plot(lrs[10:-5], losses[10:-5])
        plt.savefig(join(self.output_folder, "lr_finder.png"))
        plt.close()
        return log_lrs, losses
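The trainer above drives early stopping off exponential moving averages of the losses (alpha * old + (1 - alpha) * new) plus a patience counter on the training-loss MA. A standalone sketch of that bookkeeping, using the same update rule but not the class's exact stopping rules, could look like this:

class MovingAverageEarlyStopper:
    # illustrative sketch: MA-smoothed loss with a patience counter, as described above

    def __init__(self, alpha=0.93, eps=5e-4, patience=50):
        self.alpha = alpha          # smoothing factor: alpha * old + (1 - alpha) * new
        self.eps = eps              # a new MA must be at least this much smaller to count
        self.patience = patience
        self.loss_ma = None
        self.best_ma = None
        self.epochs_without_improvement = 0

    def update(self, epoch_loss):
        if self.loss_ma is None:
            self.loss_ma = epoch_loss
        else:
            self.loss_ma = self.alpha * self.loss_ma + (1 - self.alpha) * epoch_loss
        if self.best_ma is None or self.loss_ma + self.eps < self.best_ma:
            self.best_ma = self.loss_ma
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement <= self.patience  # True: keep training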
Exemplo n.º 26
0
class ClassificationTask(ClassyTask):
    """Basic classification training task.

    This task encapsulates all of the components and steps needed to
    train a classifier using a :class:`classy_vision.trainer.ClassyTrainer`.

    Assumes a train / test phase per epoch and that the datasets
    have the same API as the map-style Dataset class in
    `torch.utils.data.dataset <https://pytorch.org/docs/stable/data.html
    #torch.utils.data.Dataset>`_ (in particular, this task makes use of
    the len).  If you are using an `IterableDataset <https://pytorch.org/docs/
    stable/data.html#torch.utils.data.IterableDataset>`_ then a custom task
    may be appropriate.


    :var loss: Loss (see :class:`classy_vision.losses.ClassyLoss`) function used
        for computing the loss in each forward pass
    :var datasets: Mapping from a ``phase_type`` in ["train", "test"]
        to dataset used for training (or testing)
    :var meters: List of meters (see :class:`classy_vision.meters.ClassyMeter`)
        to calculate during training
    :var num_epochs: Number of epochs (passes over dataset) to train
    :var test_only: Used to only run the test phase
    :var base_model: Model to be trained, unwrapped in DDP or DP wrappers
    :var optimizer: Optimizer used in train step
    :var optimizer_schedulers: Dictionary. Key is the name of the optimizer
        option (e.g. lr), value is a ClassyParamScheduler
    :var checkpoint: Serializable dict which represents state in training
    :var phases: List of phase specific information, e.g. if phase is
        train / test.
    :var hooks: List of hooks to apply during training
    :var train: Phase type, if true it means we are training,
        false means testing
    :var distributed_model: Base model, but wrapped in DDP (DistributedDataParallel)
    :var phase_idx: Current phase id, first phase is 0, if task has not started
        training then returns -1
    :var train_phase_idx: Only counts train phases
    :var num_updates: Number of total parameter updates applied to model
        by the optimizer
    :var data_iterator: Iterator which can be used to obtain batches
    :var losses: Loss curve
    :var perf_log: list of training speed measurements, to be logged
    :var clip_grad_norm: maximum gradient norm (default None)
    :var simulated_global_batchsize: batch size simulated via gradient accumulation
    :var optimizer_period: apply optimizer after this many steps; derived from
        simulated_global_batchsize, default 1.
    """

    def __init__(self):
        """Constructs a ClassificationTask"""
        super().__init__()

        self.base_loss = None
        self.datasets = {}
        self.meters = []
        self.num_epochs = 1
        self.test_phase_period = 1
        self.train_phases_per_epoch = 0
        self.test_only = False
        self.base_model = None
        self.optimizer = None
        self.optimizer_schedulers = {}
        self.checkpoint_dict = None
        self.checkpoint_path = None
        self.phases = []
        self.hooks = []
        self.train = True
        self.distributed_model = None
        self.distributed_loss = None
        self.phase_idx = -1
        self.train_phase_idx = -1
        self.num_updates = 0
        self.dataloader = None
        self.data_iterator = None
        self.losses = []
        self.broadcast_buffers_mode: BroadcastBuffersMode = (
            BroadcastBuffersMode.BEFORE_EVAL
        )
        self.amp_args = None
        self.amp_type = None
        self.amp_grad_scaler = None
        self.mixup_transform = None
        self.perf_log = []
        self.last_batch = None
        self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
        self.find_unused_parameters = False
        self.use_gpu = torch.cuda.is_available()
        self.dataloader_mp_context = "spawn"
        self.bn_weight_decay = False
        self._train_only = True
        self.clip_grad_norm = None
        self.simulated_global_batchsize = None
        self.optimizer_period = 1
        self.ddp_bucket_cap_mb = 25
        self.use_sharded_ddp = False
        self.fp16_grad_compress = False

    def set_use_sharded_ddp(self, use_sharded_ddp: bool):
        self.use_sharded_ddp = use_sharded_ddp
        if self.use_sharded_ddp:
            logging.info("Using Sharded DDP")
        return self

    def set_use_gpu(self, use_gpu: bool):
        self.use_gpu = use_gpu

        assert (
            not self.use_gpu or torch.cuda.is_available()
        ), "CUDA required to train on GPUs"

        return self

    def set_clip_grad_norm(self, clip_grad_norm: Optional[float]):
        """Sets maximum gradient norm.

        None means gradient clipping is disabled. Defaults to None."""
        self.clip_grad_norm = clip_grad_norm
        if clip_grad_norm is None:
            logging.info("Disabled gradient norm clipping.")
        else:
            logging.info(
                f"Enabled gradient norm clipping with threshold: {clip_grad_norm}"
            )
        return self

    def set_simulated_global_batchsize(self, simulated_global_batchsize: Optional[int]):
        """Sets a simulated batch size by gradient accumulation.

        Gradient accumulation adds up gradients from multiple minibatches and
        steps the optimizer every N train_steps, where N is optimizer_period.
        When enabled, the very last train_steps might end up not updating the
        model, depending on the number of total steps. None means gradient
        accumulation is disabled. Defaults to None."""
        self.simulated_global_batchsize = simulated_global_batchsize
        return self
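    # Illustrative note (an assumption about typical usage, not original code): the
    # optimizer period is derived from the effective global batch size, e.g. with a
    # per-replica batch of 32 on 8 GPUs the global batch size is 256, so requesting
    # simulated_global_batchsize=1024 gives optimizer_period = 1024 // 256 == 4, i.e.
    # gradients from 4 consecutive train steps are accumulated before optimizer.step().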

    def set_checkpoint(self, checkpoint_path: str):
        """Sets checkpoint on task.

        Args:
            checkpoint_path: The path to load the checkpoint from. Can be a file or a
            directory. See :func:`load_checkpoint` for more information.
        """
        self.checkpoint_path = checkpoint_path
        return self

    def _set_checkpoint_dict(self, checkpoint_dict: Dict[str, Any]):
        """Sets the checkpoint dict in the task. Only used for testing.

        Args:
            checkpoint_dict: A serializable dict representing current task state
        """
        self.checkpoint_dict = checkpoint_dict
        return self

    def set_num_epochs(self, num_epochs: Union[int, float]):
        """Set number of epochs to be run.

        Args:
           num_epochs: Number of epochs to run task
        """
        self.num_epochs = num_epochs
        return self

    def set_test_phase_period(self, test_phase_period: int):
        """Set the period of test phase.

        Args:
            test_phase_period: The period of test phase
        """
        self.test_phase_period = test_phase_period
        return self

    def set_dataset(self, dataset: ClassyDataset, phase_type: str):
        """Set dataset for phase type on task

        Args:
            dataset: ClassyDataset for returning samples.
            phase_type: str must be one of "train" or "test"
        """
        assert phase_type in [
            "train",
            "test",
        ], "phase_type must be in ['train', 'test']"
        self.datasets[phase_type] = dataset
        if phase_type == "train":
            self.train_phases_per_epoch = getattr(dataset, "phases_per_epoch", 1)
        else:
            self._train_only = False
        return self

    def set_dataloader_mp_context(self, dataloader_mp_context: Optional[str]):
        """Set the multiprocessing context used by the dataloader.

        The context can be either 'spawn', 'fork', 'forkserver' or None (uses the
        default context). See
        https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_context
        for more details."""

        self.dataloader_mp_context = dataloader_mp_context
        return self

    def set_optimizer(self, optimizer: ClassyOptimizer):
        """Set optimizer for task

        Args:
            optimizer: optimizer for task
        """
        self.optimizer = optimizer
        return self

    def set_loss(self, loss: ClassyLoss):
        """Set loss function for task

        Args:
            loss: loss for task
        """
        self.base_loss = loss
        return self

    def set_meters(self, meters: List["ClassyMeter"]):
        """Set meters for task

        Args:
            meters: list of meters to compute during training
        """
        self.meters = meters
        return self

    def set_distributed_options(
        self,
        broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.BEFORE_EVAL,
        batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED,
        batch_norm_sync_group_size: int = 0,
        find_unused_parameters: bool = False,
        bucket_cap_mb: int = 25,
        fp16_grad_compress: bool = False,
    ):
        """Set distributed options.

        Args:
            broadcast_buffers_mode: Broadcast buffers mode. See
                :class:`BroadcastBuffersMode` for options.
            batch_norm_sync_mode: Batch normalization synchronization mode. See
                :class:`BatchNormSyncMode` for options.
            batch_norm_sync_group_size: Group size to use for synchronized batch norm.
                0 means that the stats are synchronized across all replicas. For
                efficient synchronization, set it to the number of GPUs in a node (
                usually 8).
            find_unused_parameters: See
                :class:`torch.nn.parallel.DistributedDataParallel` for information.
            bucket_cap_mb: See
                :class:`torch.nn.parallel.DistributedDataParallel` for information.
            fp16_grad_compress: If True, compress gradients to FP16 during DDP
                communication. Requires PyTorch 1.8 or newer.
        Raises:
            RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex
                is not installed.
        """
        self.broadcast_buffers_mode = broadcast_buffers_mode

        if batch_norm_sync_group_size > 0:
            if batch_norm_sync_mode != BatchNormSyncMode.APEX:
                # this should ideally work with PyTorch Sync BN as well, but it
                # fails while initializing DDP for some reason.
                raise ValueError(
                    "batch_norm_sync_group_size can be > 0 only when "
                    "Apex Synchronized Batch Normalization is being used."
                )
        self.batch_norm_sync_group_size = batch_norm_sync_group_size

        if batch_norm_sync_mode == BatchNormSyncMode.DISABLED:
            logging.info("Synchronized Batch Normalization is disabled")
        else:
            if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available:
                raise RuntimeError("apex is not installed")
            msg = f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}"
            if self.batch_norm_sync_group_size > 0:
                msg += f" and group size {batch_norm_sync_group_size}"
            logging.info(msg)
        self.batch_norm_sync_mode = batch_norm_sync_mode

        if find_unused_parameters:
            logging.info("Enabling find_unused_parameters in DDP")

        self.find_unused_parameters = find_unused_parameters
        self.ddp_bucket_cap_mb = bucket_cap_mb

        if fp16_grad_compress:
            if get_torch_version() < [1, 8]:
                raise RuntimeError(
                    "FP16 grad compression is only supported since PyTorch 1.8"
                )
            logging.info("Enabling FP16 grad compression")
        self.fp16_grad_compress = fp16_grad_compress

        return self

    def set_hooks(self, hooks: List["ClassyHook"]):
        """Set hooks for task

        Args:
            hooks: List of hooks to apply during training
        """
        from classy_vision.hooks import ClassyHook

        assert isinstance(hooks, list)
        assert all(isinstance(hook, ClassyHook) for hook in hooks)
        assert len({hook.name() for hook in hooks}) == len(
            hooks
        ), "Cannot have repeated hooks of the same class"
        # TODO (zyan3): we move checkpoint hook to the end of the list because some hooks
        # may change the state of the model, and we want to save changed state in the checkpoint.
        # This is temporary fix.
        non_checkpoint_hooks = [
            hook for hook in hooks if not isinstance(hook, CheckpointHook)
        ]
        checkpoint_hooks = [hook for hook in hooks if isinstance(hook, CheckpointHook)]
        hooks = non_checkpoint_hooks + checkpoint_hooks
        self.hooks = hooks
        return self

    def set_model(self, model: ClassyModel):
        """Set model for task

        Args:
            model: Model to be trained
        """
        self.base_model = model
        return self

    def set_test_only(self, test_only: bool):
        """Set test only flag

        Args:
            test_only: If true, only test phases will be run
        """
        self.test_only = test_only
        return self

    def set_bn_weight_decay(self, bn_weight_decay: bool):
        assert type(bn_weight_decay) == bool

        self.bn_weight_decay = bn_weight_decay
        return self

    def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
        """Disable / enable apex.amp and set the automatic mixed precision parameters.

        apex.amp can be utilized for mixed / half precision training.

        Args:
            amp_args: Dictionary containing arguments to be passed to
            amp.initialize. Set to None to disable amp.  To enable mixed
            precision training, pass amp_args={"opt_level": "O1"} here.
            See https://nvidia.github.io/apex/amp.html for more info.

        Raises:
            RuntimeError: If opt_level is not None and apex is not installed.

        Warning: apex needs to be installed to utilize this feature.
        """
        self.amp_args = amp_args

        if amp_args is None:
            logging.info("AMP disabled")
        else:
            # Check that the requested AMP type is known
            try:
                self.amp_type = AmpType[self.amp_args["amp_type"].upper()]
            except KeyError:
                logging.info("AMP type not specified, defaulting to Apex")
                self.amp_type = AmpType.APEX

            # Check for CUDA availability, required for both Apex and Pytorch AMP
            if not torch.cuda.is_available():
                raise RuntimeError(
                    "AMP is required but CUDA is not supported, cannot enable AMP"
                )

            # Check for Apex availability
            if self.amp_type == AmpType.APEX and not apex_available:
                raise RuntimeError(
                    "Apex AMP is required but Apex is not installed, cannot enable AMP"
                )

            if self.use_sharded_ddp:
                if self.amp_type == AmpType.APEX:
                    raise RuntimeError(
                        "ShardedDDP has been requested, which is incompatible with Apex AMP"
                    )

                if not fairscale_available:
                    raise RuntimeError(
                        "ShardedDDP has been requested, but fairscale is not installed in the current environment"
                    )

            # Set Torch AMP grad scaler, used to prevent gradient underflow
            if self.amp_type == AmpType.PYTORCH:

                if self.use_sharded_ddp:
                    logging.info("Using ShardedGradScaler to manage Pytorch AMP")
                    self.amp_grad_scaler = ShardedGradScaler()
                else:
                    self.amp_grad_scaler = TorchGradScaler()

            logging.info(f"AMP enabled with args {amp_args}")
        return self

    def set_mixup_transform(self, mixup_transform: Optional["MixupTransform"]):
        """Disable / enable mixup transform for data augmentation

        Args:
            mixup_transform: a callable object which performs mixup data augmentation
        """
        self.mixup_transform = mixup_transform
        if mixup_transform is None:
            logging.info("mixup disabled")
        else:
            logging.info("mixup enabled")
        return self

    def set_optimizer_schedulers(self, schedulers):
        self.optimizer_schedulers = schedulers
        return self

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
        """Instantiates a ClassificationTask from a configuration.

        Args:
            config: A configuration for a ClassificationTask.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            A ClassificationTask instance.
        """
        test_only = config.get("test_only", False)
        if not test_only:
            # TODO Make distinction between epochs and phases in optimizer clear
            train_phases_per_epoch = config["dataset"]["train"].get(
                "phases_per_epoch", 1
            )

            optimizer_config = config["optimizer"]
            optimizer_config["num_epochs"] = (
                config["num_epochs"] * train_phases_per_epoch
            )
            optimizer = build_optimizer(optimizer_config)
            param_schedulers = build_optimizer_schedulers(optimizer_config)

        datasets = {}
        phase_types = ["train", "test"]
        for phase_type in phase_types:
            if phase_type in config["dataset"]:
                datasets[phase_type] = build_dataset(config["dataset"][phase_type])
        loss = build_loss(config["loss"])
        amp_args = config.get("amp_args")
        meters = build_meters(config.get("meters", {}))
        model = build_model(config["model"])

        mixup_transform = None
        if config.get("mixup") is not None:
            assert "alpha" in config["mixup"], "key alpha is missing in mixup dict"
            mixup_transform = MixupTransform(
                config["mixup"]["alpha"], config["mixup"].get("num_classes")
            )

        # hooks config is optional
        hooks_config = config.get("hooks")
        hooks = []
        if hooks_config is not None:
            hooks = build_hooks(hooks_config)

        distributed_config = config.get("distributed", {})
        distributed_options = {
            "broadcast_buffers_mode": BroadcastBuffersMode[
                distributed_config.get("broadcast_buffers", "before_eval").upper()
            ],
            "batch_norm_sync_mode": BatchNormSyncMode[
                distributed_config.get("batch_norm_sync_mode", "disabled").upper()
            ],
            "batch_norm_sync_group_size": distributed_config.get(
                "batch_norm_sync_group_size", 0
            ),
            "find_unused_parameters": distributed_config.get(
                "find_unused_parameters", False
            ),
            "bucket_cap_mb": distributed_config.get("bucket_cap_mb", 25),
            "fp16_grad_compress": distributed_config.get("fp16_grad_compress", False),
        }

        task = (
            cls()
            .set_num_epochs(config["num_epochs"])
            .set_test_phase_period(config.get("test_phase_period", 1))
            .set_loss(loss)
            .set_test_only(test_only)
            .set_model(model)
            .set_meters(meters)
            .set_amp_args(amp_args)
            .set_mixup_transform(mixup_transform)
            .set_distributed_options(**distributed_options)
            .set_hooks(hooks)
            .set_bn_weight_decay(config.get("bn_weight_decay", False))
            .set_clip_grad_norm(config.get("clip_grad_norm"))
            .set_simulated_global_batchsize(config.get("simulated_global_batchsize"))
            .set_use_sharded_ddp(config.get("use_sharded_ddp", False))
        )

        if not test_only:
            task.set_optimizer(optimizer)
            task.set_optimizer_schedulers(param_schedulers)

        use_gpu = config.get("use_gpu")
        if use_gpu is not None:
            task.set_use_gpu(use_gpu)

        for phase_type in datasets:
            task.set_dataset(datasets[phase_type], phase_type)

        # NOTE: this is a private member and only meant to be used for
        # logging/debugging purposes. See __repr__ implementation
        task._config = config

        return task

    @property
    def num_batches_per_phase(self):
        """Returns number of batches in current phase iterator"""
        return len(self.data_iterator)

    @property
    def model(self):
        """Returns model used in training (can be wrapped with DDP)"""
        return (
            self.distributed_model if is_distributed_training_run() else self.base_model
        )

    @property
    def loss(self):
        """Returns loss used in training (can be wrapped with DDP)"""
        return self.distributed_loss if self.distributed_loss else self.base_loss

    @property
    def phase_type(self):
        """Returns current phase type. String with value "train" or "test" """
        return "train" if self.train else "test"

    @property
    def eval_phase_idx(self):
        """Returns current evaluation phase"""
        return self.phase_idx - self.train_phase_idx - 1

    def get_total_training_phases(self):
        """
        Returns the total number of "train" phases in the task
        """
        num_training_phases = 0
        for phase in self.phases:
            if phase["train"] is True:
                num_training_phases += 1
        return num_training_phases

    def get_total_test_phases(self):
        """
        Returns the total number of "test" phases in the task
        """
        num_test_phases = 0
        for phase in self.phases:
            if phase["train"] is False:
                num_test_phases += 1
        return num_test_phases

    def _build_phases(self):
        """Returns list of phases from config.

        These phases will look like:
        {
          train: is this a train or test phase?
          optimizer: optimizer settings
        }

        - If this is a test only run, then only test phases will be
        generated
        - If this is a training run with both train and test datasets, then x phases =
          x train phases + x test phases, interleaved. If test_phase_period > 1, test
          phases are only added after test_phase_period train phases. The last phase is
          always a test phase.
        - If this is a training run with only a train dataset, then x phases = x train
          phases.
        """
        if not self.test_only:
            phases = [
                {"train": True}
                for _ in range(math.ceil(self.train_phases_per_epoch * self.num_epochs))
            ]

            if self._train_only:
                return phases

            final_phases = []
            for i, phase in enumerate(phases):
                final_phases.append(phase)
                if (i + 1) % self.test_phase_period == 0:
                    final_phases.append({"train": False})
            if final_phases[-1]["train"]:
                final_phases.append({"train": False})
            return final_phases

        return [{"train": False} for _ in range(self.num_epochs)]

    def build_dataloader_from_dataset(self, dataset, **kwargs):
        """Builds a dataloader from the provided dataset

        Args:
            dataset: A ClassyDataset
            kwargs: Additional kwargs to pass during dataloader construction for
                derived classes
        """
        return dataset.iterator(
            phase_type=self.phase_type,
            current_phase_id=self.train_phase_idx if self.train else 0,
            pin_memory=self.use_gpu and torch.cuda.device_count() > 1,
            multiprocessing_context=mp.get_context(self.dataloader_mp_context),
            **kwargs,
        )

    def build_dataloaders_for_current_phase(self):
        """Builds dataloader(s) for the current phase.

        Deriving classes can override this method to support custom behavior, like
        supporting multiple dataloaders in parallel.
        """
        self.dataloader = self.build_dataloader_from_dataset(
            self.datasets[self.phase_type]
        )

    def prepare_optimizer(self, optimizer, model, loss=None):
        bn_params, other_params = split_batchnorm_params(model)
        if loss is not None:
            bn_params_loss, params_loss = split_batchnorm_params(loss)
            bn_params = bn_params + bn_params_loss
            other_params = other_params + params_loss

        bn_schedulers = self.optimizer_schedulers.copy()
        if not self.bn_weight_decay:
            bn_schedulers["weight_decay"] = 0

        param_groups = [{"params": other_params, **self.optimizer_schedulers}]
        if len(bn_params) > 0:
            param_groups.append({"params": bn_params, **bn_schedulers})
        self.optimizer.set_param_groups(param_groups)

    def prepare(self):
        """Prepares task for training, populates all derived attributes """

        self.phases = self._build_phases()
        self.train = False if self.test_only else self.train

        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
            sync_bn_process_group = apex.parallel.create_syncbn_process_group(
                self.batch_norm_sync_group_size
            )
            self.base_model = apex.parallel.convert_syncbn_model(
                self.base_model, process_group=sync_bn_process_group
            )

        # move the model and loss to the right device
        if self.use_gpu:
            self.base_model, self.base_loss = copy_model_to_gpu(
                self.base_model, self.base_loss
            )
        else:
            self.base_loss.cpu()
            self.base_model.cpu()

        if self.optimizer is not None:
            self.prepare_optimizer(
                optimizer=self.optimizer, model=self.base_model, loss=self.base_loss
            )

        if self.amp_args is not None:
            if self.amp_type == AmpType.APEX:
                # Initialize apex.amp. This updates the model and the PyTorch optimizer (
                # if training, which is wrapped by the ClassyOptimizer in self.optimizer).
                # Please note this must happen before loading the checkpoint, because
                # there is amp state to be restored.
                if self.optimizer is None:
                    self.base_model = apex.amp.initialize(
                        self.base_model, optimizers=None, **self.amp_args
                    )
                else:
                    self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                        self.base_model, self.optimizer.optimizer, **self.amp_args
                    )

        if self.simulated_global_batchsize is not None:
            if self.simulated_global_batchsize % self.get_global_batchsize() != 0:
                raise ValueError(
                    f"Global batch size ({self.get_global_batchsize()}) must divide "
                    f"simulated_global_batchsize ({self.simulated_global_batchsize})"
                )
        else:
            self.simulated_global_batchsize = self.get_global_batchsize()

        self.optimizer_period = (
            self.simulated_global_batchsize // self.get_global_batchsize()
        )
        if self.optimizer_period > 1:
            logging.info(
                f"Using gradient accumulation with a period of {self.optimizer_period}"
            )

        if self.checkpoint_path:
            self.checkpoint_dict = load_and_broadcast_checkpoint(self.checkpoint_path)

        classy_state_dict = (
            None
            if self.checkpoint_dict is None
            else self.checkpoint_dict["classy_state_dict"]
        )

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (
                state_load_success
            ), "Update classy state from checkpoint was unsuccessful."

        self.init_distributed_data_parallel_model()

    def init_distributed_data_parallel_model(self):
        """
        Initialize
        `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/
        docs/stable/nn.html#distributeddataparallel>`_.

        Needed for distributed training. This is where a model should be wrapped by DDP.
        """
        if not is_distributed_training_run():
            return
        assert (
            self.distributed_model is None
        ), "init_ddp_non_elastic must only be called once"

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS
        )

        if self.use_sharded_ddp:
            if not isinstance(self.optimizer, ZeRO):
                raise ValueError(
                    "ShardedDataParallel engine should only be used in conjunction with ZeRO optimizer"
                )
            from fairscale.nn.data_parallel import ShardedDataParallel

            # Replace the original DDP wrap by the shard-aware ShardedDDP
            self.distributed_model = ShardedDataParallel(
                module=self.base_model,
                sharded_optimizer=self.optimizer.optimizer,
                broadcast_buffers=broadcast_buffers,
            )
        else:
            self.distributed_model = init_distributed_data_parallel_model(
                self.base_model,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )
            if self.fp16_grad_compress:

                from torch.distributed.algorithms import ddp_comm_hooks

                # FP16 hook is stateless and only takes a process group as the state.
                # We use the default process group so we set the state to None.
                process_group = None
                self.distributed_model.register_comm_hook(
                    process_group,
                    ddp_comm_hooks.default_hooks.fp16_compress_hook,
                )
        if (
            isinstance(self.base_loss, ClassyLoss)
            and self.base_loss.has_learned_parameters()
        ):
            logging.info("Initializing distributed loss")
            self.distributed_loss = init_distributed_data_parallel_model(
                self.base_loss,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )

    @property
    def where(self):
        """Returns the proportion of training that has completed. If in test
        only mode, returns proportion of testing completed

        Returned value is a float in the range [0, 1)
        """
        current_step = self.num_updates / self.get_global_batchsize()
        num_phases = (
            self.get_total_test_phases()
            if self.test_only
            else self.get_total_training_phases()
        )

        if self.num_batches_per_phase <= 0:
            raise RuntimeError("No batches to read. Is the dataset empty?")

        num_steps = num_phases * self.num_batches_per_phase
        where = current_step / num_steps

        return where

    def get_classy_state(self, deep_copy: bool = False):
        """Returns serialiable state of task

        Args:
            deep_copy: If true, does a deep copy of state before returning.
        """
        optimizer_state = {}
        if self.optimizer is not None:
            optimizer_state = self.optimizer.get_classy_state()

        classy_state_dict = {
            "train": self.train,
            "base_model": self.base_model.get_classy_state(),
            "meters": [meter.get_classy_state() for meter in self.meters],
            "optimizer": optimizer_state,
            "phase_idx": self.phase_idx,
            "train_phase_idx": self.train_phase_idx,
            "num_updates": self.num_updates,
            "losses": self.losses,
            "hooks": {hook.name(): hook.get_classy_state() for hook in self.hooks},
            "loss": {},
        }
        if "train" in self.datasets and self._is_checkpointable_dataset(
            self.datasets["train"]
        ):
            classy_state_dict["train_dataset_iterator"] = self.datasets[
                "train"
            ].get_classy_state()

        if isinstance(self.base_loss, ClassyLoss):
            classy_state_dict["loss"] = self.base_loss.get_classy_state()
        if self.amp_args is not None:
            if self.amp_type == AmpType.APEX:
                classy_state_dict["amp"] = apex.amp.state_dict()

            elif self.amp_grad_scaler is not None:
                classy_state_dict["amp"] = self.amp_grad_scaler.state_dict()

        if deep_copy:
            classy_state_dict = copy.deepcopy(classy_state_dict)
        return classy_state_dict

    def set_classy_state(self, state):
        """Set task state

        Args:
            state: Dict containing state of a task
        """
        # some settings are different in test only
        self.train = False if self.test_only else state["train"]
        if not self.test_only:
            self.phase_idx = state["phase_idx"]
            self.num_updates = state["num_updates"]
            self.train_phase_idx = state["train_phase_idx"]
            self.losses = state["losses"]
            for meter, meter_state in zip(self.meters, state["meters"]):
                meter.set_classy_state(meter_state)

        self.base_model.set_classy_state(state["base_model"])
        if self.optimizer is not None:
            self.optimizer.set_classy_state(state["optimizer"])
        if state.get("loss") and isinstance(self.base_loss, ClassyLoss):
            self.base_loss.set_classy_state(state["loss"])

        if "amp" in state:
            if self.amp_type == AmpType.APEX:
                apex.amp.load_state_dict(state["amp"])
            else:
                self.amp_grad_scaler.load_state_dict(state["amp"])

        for hook in self.hooks:
            # we still want to be able to run when new hooks are added or old
            # hooks are removed
            if hook.name() in state["hooks"]:
                hook.set_classy_state(state["hooks"][hook.name()])
            else:
                logging.warning(f"No state found for hook: {hook.name()}")

        if "train" in self.datasets and self._is_checkpointable_dataset(
            self.datasets["train"]
        ):
            self.datasets["train"].set_classy_state(state.get("train_dataset_iterator"))

    @staticmethod
    def _is_checkpointable_dataset(dataset):
        return hasattr(dataset, "get_classy_state") and hasattr(
            dataset, "set_classy_state"
        )

    def eval_step(self):
        self.last_batch = None

        # Process next sample
        with Timer() as timer:
            sample = next(self.data_iterator)

        assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
            f"Returned sample [{sample}] is not a map with 'input' and"
            + "'target' keys"
        )

        target = sample["target"]
        if self.use_gpu:
            sample = recursive_copy_to_gpu(sample, non_blocking=True)

        # Optional Pytorch AMP context
        torch_amp_context = (
            torch.cuda.amp.autocast()
            if self.amp_type == AmpType.PYTORCH
            else contextlib.suppress()
        )

        with torch.no_grad(), torch_amp_context:
            output = self.model(sample["input"])

            local_loss = self.compute_loss(output, sample)

            loss = local_loss.detach().clone()

            self.check_inf_nan(loss)

            self.losses.append(loss.data.cpu().item())

            self.update_meters(output, sample)

        # Move some data to the task so hooks get a chance to access it
        self.last_batch = LastBatchInfo(
            loss=loss,
            output=output,
            target=target,
            sample=sample,
            step_data={"sample_fetch_time": timer.elapsed_time},
        )

    def check_inf_nan(self, loss):
        if loss == float("inf") or loss == float("-inf") or loss != loss:
            raise FloatingPointError(f"Loss is infinity or NaN: {loss}")

    def _should_do_step(self):
        """Tells if we will be performing an optimizer step.

        Returns True always if there is no gradient accumulation. With gradient
        accumulation returns True only when the gradients will be synchronized and we
        will be performing an optimizer step.
        """
        update_idx = self.num_updates // self.get_global_batchsize()
        return (update_idx % self.optimizer_period) == self.optimizer_period - 1

    def train_step(self):
        """Train step to be executed in train loop."""

        self.last_batch = None

        # Process next sample
        with Timer() as timer:
            sample = next(self.data_iterator)

        assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
            f"Returned sample [{sample}] is not a map with 'input' and"
            + "'target' keys"
        )

        # Copy sample to GPU
        target = sample["target"]
        if self.use_gpu:
            sample = recursive_copy_to_gpu(sample, non_blocking=True)

        if self.mixup_transform is not None:
            sample = self.mixup_transform(sample)

        # Optional Pytorch AMP context
        torch_amp_context = (
            torch.cuda.amp.autocast()
            if self.amp_type == AmpType.PYTORCH
            else contextlib.suppress()
        )

        # only sync with DDP when we need to perform an optimizer step
        # an optimizer step can be skipped if gradient accumulation is enabled
        do_step = self._should_do_step()
        ctx_mgr_model = (
            self.distributed_model.no_sync()
            if self.distributed_model is not None and not do_step
            else contextlib.suppress()
        )
        ctx_mgr_loss = (
            self.distributed_loss.no_sync()
            if self.distributed_loss is not None and not do_step
            else contextlib.suppress()
        )

        with ctx_mgr_model, ctx_mgr_loss:
            # Forward pass
            with torch.enable_grad(), torch_amp_context:
                output = self.model(sample["input"])

                local_loss = self.compute_loss(output, sample)
                loss = local_loss.detach().clone()
                self.losses.append(loss.data.cpu().item())

                self.update_meters(output, sample)

            # Backwards pass + optimizer step
            self.run_optimizer(local_loss)

        self.num_updates += self.get_global_batchsize()

        # Move some data to the task so hooks get a chance to access it
        self.last_batch = LastBatchInfo(
            loss=loss,
            output=output,
            target=target,
            sample=sample,
            step_data={"sample_fetch_time": timer.elapsed_time},
        )

    def compute_loss(self, model_output, sample):
        return self.loss(model_output, sample["target"])

    def run_optimizer(self, loss):
        """Runs backwards pass and update the optimizer"""

        self.check_inf_nan(loss)

        # Gradient accumulation logic. We always set optimizer_period, even
        # if gradient accumulation is disabled. Assumes all batches have the
        # same size
        update_idx = self.num_updates // self.get_global_batchsize()
        do_zero_grad = (update_idx % self.optimizer_period) == 0
        do_step = self._should_do_step()

        if do_zero_grad:
            self.optimizer.zero_grad()

        if self.amp_type == AmpType.APEX:
            with apex.amp.scale_loss(loss, self.optimizer.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.amp_type == AmpType.PYTORCH:
            self.amp_grad_scaler.scale(loss).backward()
        else:
            loss.backward()

        if do_step:
            # Handle gradient accumulation related gradient rescaling
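            # Each accumulated backward() added a full, unscaled gradient, so dividing
            # by the period turns the sum into an average over the simulated global batch.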
            if self.optimizer_period != 1:
                self._rescale_gradients(1 / self.optimizer_period)

            # Clipping must happen after grad accumulation
            if self.clip_grad_norm is not None:
                self._clip_gradients(self.clip_grad_norm)

            if self.amp_type == AmpType.PYTORCH:
                # If using mixed precision, handle underflow-related scaling
                # See https://pytorch.org/docs/stable/amp.html#gradient-scaling
                # for context
                self.amp_grad_scaler.step(self.optimizer, where=self.where)
                self.amp_grad_scaler.update()
            else:
                self.optimizer.step(where=self.where)

    def _rescale_gradients(self, scale):
        for param in master_params(self.optimizer):
            if param.grad is not None:
                param.grad.data.mul_(scale)

    def _clip_gradients(self, max_norm):
        nn.utils.clip_grad_norm_(master_params(self.optimizer), max_norm)

    def update_meters(self, model_output, sample):
        target = sample["target"].detach().cpu()
        model_output = model_output.detach().cpu()

        # Update meters
        for meter in self.meters:
            meter.update(model_output, target, is_train=self.train)

    def synchronize_losses(self):
        """Average the losses across the different replicas"""

        # Average losses across nodes
        losses_tensor = torch.tensor(self.losses)
        synchronized_losses_tensor = all_reduce_mean(losses_tensor)
        self.losses = synchronized_losses_tensor.tolist()

    def advance_phase(self):
        """Performs bookkeeping / task updates between phases

        Increments phase idx, resets meters, resets loss history,
        resets counters, shuffles dataset, rebuilds iterators, and
        sets the train / test state for phase.
        """
        logging.debug("Advancing phase")
        # Reset meters for next phase / epoch
        for meter in self.meters:
            meter.reset()

        # Reset loss history for next epoch
        self.losses = []

        # Setup new phase
        self.phase_idx += 1
        phase = self.phases[self.phase_idx]
        self.train = bool(phase["train"])
        if self.train:
            self.train_phase_idx += 1

        # Re-build dataloader & re-create iterator anytime membership changes.
        self.build_dataloaders_for_current_phase()
        self.create_data_iterators()
        # Set up pytorch module in train vs eval mode, update optimizer.
        self._set_model_train_mode()

    def done_training(self):
        """Stop condition for training"""
        return self.phase_idx + 1 >= len(self.phases)

    def create_data_iterators(self):
        """Creates data iterator(s) for the current phase."""
        # Delete iterator explicitly so that all dataloader processes
        # are cleaned up.
        del self.data_iterator
        self.data_iterator = iter(self.dataloader)

    def _set_model_train_mode(self):
        """Set train mode for model"""
        phase = self.phases[self.phase_idx]
        self.base_model.train(phase["train"])
        self.base_loss.train(phase["train"])

        if (
            self.broadcast_buffers_mode == BroadcastBuffersMode.BEFORE_EVAL
            and not self.train
        ):
            self._broadcast_buffers()

    def _broadcast_buffers(self):
        """Explicitly synchronize buffers across all devices."""
        if self.distributed_model is None:
            return
        buffers = list(self.base_model.buffers())
        if len(buffers) > 0:
            logging.info("Synchronizing buffers before evaluation.")
            for buffer in buffers:
                broadcast(buffer, 0, group=self.distributed_model.process_group)

    # TODO: Functions below should be better abstracted into the dataloader
    # abstraction
    def get_batchsize_per_replica(self):
        """Return local replica's batchsize for dataset (e.g. batchsize per GPU)"""
        return self.datasets[self.phase_type].get_batchsize_per_replica()

    def get_global_batchsize(self):
        """Return global batchsize across all trainers"""
        return self.datasets[self.phase_type].get_global_batchsize()

    def on_start(self):
        for hook in self.hooks:
            hook.on_start(self)

    def on_phase_start(self):
        self.phase_start_time_total = time.perf_counter()

        self.advance_phase()

        for hook in self.hooks:
            hook.on_phase_start(self)

        self.phase_start_time_train = time.perf_counter()

    def on_phase_end(self):
        self.log_phase_end("train")

        if self.train:
            self.optimizer.on_epoch(where=self.where)

        logging.debug("Syncing losses on phase end...")
        self.synchronize_losses()
        logging.debug("...losses synced")

        logging.debug("Syncing meters on phase end...")
        for meter in self.meters:
            meter.sync_state()
        logging.debug("...meters synced")
        barrier()

        for hook in self.hooks:
            hook.on_phase_end(self)
        self.perf_log = []

        self.log_phase_end("total")

    def on_end(self):
        for hook in self.hooks:
            hook.on_end(self)

    def log_phase_end(self, tag):
        if not self.train:
            return

        start_time = (
            self.phase_start_time_train
            if tag == "train"
            else self.phase_start_time_total
        )
        phase_duration = time.perf_counter() - start_time
        im_per_sec = (
            self.get_global_batchsize() * self.num_batches_per_phase
        ) / phase_duration
        self.perf_log.append(
            {
                "tag": tag,
                "phase_idx": self.train_phase_idx,
                "epoch_duration": phase_duration,
                "im_per_sec": im_per_sec,
            }
        )

    def __repr__(self):
        if hasattr(self, "_config"):
            config = json.dumps(self._config, indent=4)
            return f"{super().__repr__()} initialized with config:\n{config}"

        return super().__repr__()
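
# A minimal, hypothetical usage sketch (not part of the original source): it simply
# chains the setters defined above; `my_model`, `my_loss`, `my_optimizer` and
# `my_train_dataset` are placeholders for objects built elsewhere.
task = (
    ClassificationTask()
    .set_num_epochs(10)
    .set_model(my_model)                      # a ClassyModel
    .set_loss(my_loss)                        # a ClassyLoss
    .set_optimizer(my_optimizer)              # a ClassyOptimizer
    .set_meters([])
    .set_amp_args({"amp_type": "pytorch"})    # native AMP; requires CUDA (or pass None)
)
task.set_dataset(my_train_dataset, "train")
task.prepare()                                # builds phases, dataloaders, DDP wrappers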
Exemplo n.º 27
# imports needed by this excerpt; other names (train_ds, train_sampler, model,
# optimizer, epochs, DEVICE, ProgressBar, AverageMeter) are defined earlier in the
# original file
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

# create DataLoaders with samplers
train_dataloader = DataLoader(train_ds,
                              batch_size=100,
                              sampler=train_sampler,
                              shuffle=False)

valid_dataloader = DataLoader(valid_ds, batch_size=100, shuffle=True)

test_dataloader = DataLoader(test_ds, batch_size=100, shuffle=True)

# set LR scheduler
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01, total_steps=len(train_dataloader) * epochs)

# create gradient scaler for mixed precision
scaler = GradScaler()


# train function
def train(dataloader):
    pbar = ProgressBar(n_total=len(dataloader), desc='Training')
    train_loss = AverageMeter()
    model.train()
    for batch_idx, batch in enumerate(dataloader):
        b_features, b_target, b_idx = batch['features'].to(
            DEVICE), batch['target'].to(DEVICE), batch['idx'].to(DEVICE)
        optimizer.zero_grad()
        with autocast():
            logits, probs = model(b_features)
            loss = F.cross_entropy(logits, b_target)
        scaler.scale(loss).backward()
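        # NOTE: the excerpt above is cut off right after the scaled backward pass. The
        # lines below are a hedged sketch of the usual remaining per-batch AMP steps
        # (optimizer step via the scaler, scaler update, per-batch OneCycleLR step),
        # not the original author's code; the AverageMeter call assumes its usual API.
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        train_loss.update(loss.item(), n=b_features.size(0))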
Exemplo n.º 28
    def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
        """Disable / enable apex.amp and set the automatic mixed precision parameters.

        apex.amp can be utilized for mixed / half precision training.

        Args:
            amp_args: Dictionary containing arguments to be passed to
            amp.initialize. Set to None to disable amp.  To enable mixed
            precision training, pass amp_args={"opt_level": "O1"} here.
            See https://nvidia.github.io/apex/amp.html for more info.

        Raises:
            RuntimeError: If opt_level is not None and apex is not installed.

        Warning: apex needs to be installed to utilize this feature.
        """
        self.amp_args = amp_args

        if amp_args is None:
            logging.info("AMP disabled")
        else:
            # Check that the requested AMP type is known
            try:
                self.amp_type = AmpType[self.amp_args["amp_type"].upper()]
            except KeyError:
                logging.info("AMP type not specified, defaulting to Apex")
                self.amp_type = AmpType.APEX

            # Check for CUDA availability, required for both Apex and Pytorch AMP
            if not torch.cuda.is_available():
                raise RuntimeError(
                    "AMP is required but CUDA is not supported, cannot enable AMP"
                )

            # Check for Apex availability
            if self.amp_type == AmpType.APEX and not apex_available:
                raise RuntimeError(
                    "Apex AMP is required but Apex is not installed, cannot enable AMP"
                )

            if self.use_sharded_ddp:
                if self.amp_type == AmpType.APEX:
                    raise RuntimeError(
                        "ShardedDDP has been requested, which is incompatible with Apex AMP"
                    )

                if not fairscale_available:
                    raise RuntimeError(
                        "ShardedDDP has been requested, but fairscale is not installed in the current environment"
                    )

            # Set Torch AMP grad scaler, used to prevent gradient underflow
            if self.amp_type == AmpType.PYTORCH:

                if self.use_sharded_ddp:
                    logging.info("Using ShardedGradScaler to manage Pytorch AMP")
                    self.amp_grad_scaler = ShardedGradScaler()
                else:
                    self.amp_grad_scaler = TorchGradScaler()

            logging.info(f"AMP enabled with args {amp_args}")
        return self
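
# A minimal, hypothetical usage sketch (not part of the original source), showing the
# amp_args shapes this setter understands; `task` stands for a task object exposing
# set_amp_args as above.
task.set_amp_args({"opt_level": "O1"})        # Apex AMP (amp_type defaults to APEX)
task.set_amp_args({"amp_type": "pytorch"})    # native torch.cuda.amp; requires CUDA
task.set_amp_args(None)                       # disable AMP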
Exemplo n.º 29
# imports assumed by this excerpt; GradScaler/autocast fall back to None on PyTorch
# builds without native AMP, matching the None checks used below ("sandesh" is an
# external logging helper used by send_log)
import copy

import numpy as np
import torch
import torch.utils.data as D
from tqdm import notebook

try:
    from torch.cuda.amp import GradScaler, autocast
except ImportError:
    GradScaler, autocast = None, None

try:
    import sandesh
except ImportError:
    sandesh = None


class Learner(object):
    def __init__(self,
                 model,
                 optimizer,
                 loss_func,
                 name="",
                 scheduler=None,
                 device='cpu'):
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.scheduler = scheduler
        self.scaler = None
        self.device = device
        self.metric = None
        self.name = name
        self.log = {}
        self.eth = 0.99
        self.do_autocast = False

    def init_amp(self,
                 init_scale=65536.0,
                 growth_factor=2.0,
                 backoff_factor=0.5,
                 growth_interval=2000,
                 enabled=True,
                 do_autocast=True):
        self.do_autocast = do_autocast
        if GradScaler is not None:
            self.scaler = GradScaler(init_scale=init_scale,
                                     growth_factor=growth_factor,
                                     backoff_factor=backoff_factor,
                                     growth_interval=growth_interval,
                                     enabled=enabled)

    def get_y(self, batch):
        # get Y from Batch, the default is batch[-1] but you can overwrite it
        return batch[-1]

    def get_inds(self, batch):
        # get indices from batch, the default is batch[-1] but you can overwrite it
        return batch[-1]

    def get_x(self, batch):
        # get x from Batch, the default is batch[:-1] but you can overwrite it
        if isinstance(batch, (list, tuple)):
            return batch[:-1]
        else:
            return [batch]

    def run_model(self, model, batch):
        return model(*(x.to(self.device) for x in self.get_x(batch)))

    def calc_loss(self, y_pred, y_true):
        return self.loss_func(y_pred, y_true.to(self.device))

    def one_cycle(self, batch, train=True, do_step=True):
        device = self.device
        self.preprocess_batch(batch, train)
        y_true = self.get_y(batch)
        if autocast is None:
            y_pred = self.run_model(self.model, batch)
            loss = self.calc_loss(y_pred, y_true)
            loss_item = 0 if np.isnan(loss.item()) else loss.item()
        else:
            with autocast(self.do_autocast):
                y_pred = self.run_model(self.model, batch)
                loss = self.calc_loss(y_pred, y_true)
                loss_item = 0 if np.isnan(loss.item()) else loss.item()
        if train:
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()
            if do_step:
                if self.scaler is not None:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
            if np.isnan(loss.item()):
                print('got loss = nan')
            loss_item = 0 if np.isnan(loss.item()) else loss.item()
        return loss_item if train else (loss_item, y_pred.to('cpu').detach())

    def one_training_epoch(self, dl, accumulation_steps=1):
        device = self.device
        torch.cuda.empty_cache()
        avg_loss = 0.
        lossf = 0.
        self.model = self.model.train()
        self.model.zero_grad()
        tk0 = notebook.tqdm(dl)
        for i, batch in enumerate(tk0):
            do_step = (i + 1) % accumulation_steps == 0
            loss_item = self.one_cycle(batch, train=True, do_step=do_step)
            e = min(self.eth, 1 - 1.0 / (i + 1.0))
            lossf = e * lossf + (1 - e) * loss_item
            tk0.set_postfix(loss=lossf)
            avg_loss += loss_item / len(dl)
        tk0.disable = False
        tk0.set_postfix(loss=avg_loss)
        tk0.disable = True
        return avg_loss

    def agg_tta(self, y):
        return np.stack(y,0).mean(0) if not isinstance(y[0],tuple)\
               else tuple(np.stack([yy[i] for yy in y],0).mean(0) for i in range(len(y[0])))

    def preprocess_batch(self, batch, train=True):
        return batch

    def one_eval_epoch(self, dl, tta=1):
        device = self.device
        avg_loss = 0.
        avg_accuracy = 0.
        lossf = 0
        self.model = self.model.eval()
        predss = []
        with torch.no_grad():
            for t in range(tta):
                pred_list = []
                true_list = []
                tk0 = notebook.tqdm(dl)
                for i, batch in enumerate(tk0):
                    loss_item, y_pred = self.one_cycle(batch,
                                                       train=False,
                                                       do_step=False)
                    pred_list.append(y_pred.to('cpu').numpy() if not isinstance(y_pred,tuple) else\
                        tuple(y.to('cpu').numpy() for y in y_pred))
                    y_batch = self.get_y(batch)
                    true_list.append(y_batch.to('cpu').numpy() if not isinstance(y_batch,tuple) else\
                        tuple(y.to('cpu').numpy() for y in y_batch))
                    e = min(self.eth, 1 - 1.0 / (i + 1.0))
                    lossf = e * lossf + (1 - e) * loss_item
                    tk0.set_postfix(loss=lossf)
                    avg_loss += loss_item / len(dl)
#                 y_true=np.concatenate(true_list,0)
                y_true=np.concatenate(true_list,0) if not isinstance(true_list[0],tuple) else\
                    tuple(np.concatenate([p[i] for p in true_list],0) for i in range(len(true_list[0])))
                predss.append(np.concatenate(pred_list,0) if not isinstance(pred_list[0],tuple) else\
                    tuple(np.concatenate([p[i] for p in pred_list],0) for i in range(len(pred_list[0]))))

            preds = self.agg_tta(predss) if tta > 1 else predss[0]
            m = dict() if self.metric is None else self.metric(preds, y_true)
        tk0.disable = False
        tk0.set_postfix(loss=avg_loss, **m)
        tk0.disable = True
        return avg_loss, m

    def send_log(self, **kwargs):
        log = {'model': self.name}
        log.update(kwargs)
        try:
            sandesh.send(log)
        except Exception:
            print(log)

    def save2log(self, **kwargs):
        for key in kwargs.keys():
            if key not in self.log:
                self.log[key] = []
            self.log[key].append(kwargs[key])

    def evaluate(self, ds, num_workers=8, tta=1, dl_args={'shuffle': False}):
        dl = D.DataLoader(ds, num_workers=num_workers, **dl_args)
        return self.one_eval_epoch(dl, tta=tta)

    def fit(self,
            num_epoches,
            train_ds,
            validate_ds=None,
            batch_size=None,
            lr=None,
            accumulation_steps=1,
            num_workers=8,
            send_log=True,
            eval_batch=None,
            reset_best=False,
            make_best=True,
            tta=1,
            train_dl_args={'shuffle': True},
            val_dl_args={'shuffle': False},
            save_checkpoint='best',
            path=''):
        if batch_size is not None:
            train_dl_args['batch_size'] = batch_size
            val_dl_args['batch_size'] = batch_size
        if eval_batch is not None:
            val_dl_args['batch_size'] = eval_batch

        tq = notebook.tqdm(range(num_epoches))
        if lr is not None:
            self.set_lr(lr)
        if reset_best or not hasattr(self, 'best_metric'):
            self.best_model = None
            self.best_metric = np.inf
        for k, epoch in enumerate(tq):
            self.on_epoch_begin(epoch,
                                train_ds=train_ds,
                                validate_ds=validate_ds)
            dl = D.DataLoader(train_ds,
                              num_workers=num_workers,
                              **train_dl_args)
            if next(self.model.parameters()).device != torch.device('cpu'):
                torch.cuda.empty_cache()
            tavg_loss = self.one_training_epoch(
                dl, accumulation_steps=accumulation_steps)
            #             dl=D.DataLoader(validate_ds, batch_size=batch_size if eval_batch is None else eval_batch,
            #                              num_workers=num_workers,**val_dl_args)
            if validate_ds is not None:
                avg_loss, metric = self.evaluate(validate_ds,
                                                 num_workers=num_workers,
                                                 dl_args=val_dl_args,
                                                 tta=tta)
            else:
                avg_loss = tavg_loss
                metric = {}
            if send_log:
                self.send_log(epoch=epoch,
                              tloss=tavg_loss,
                              loss=avg_loss,
                              **metric)
            self.save2log(epoch=epoch,
                          tloss=tavg_loss,
                          loss=avg_loss,
                          **metric)
            m = avg_loss if 'metric' not in metric.keys() else metric['metric']
            if save_checkpoint == 'last':
                self.save_checkpoint(path)
            if m < self.best_metric:
                self.best_metric = m
                self.best_model = copy.deepcopy(self.model.state_dict())
                tq.set_postfix(best_metric=self.best_metric)
                if save_checkpoint == 'best':
                    self.save_checkpoint(path)
            self.on_epoch_end(epoch)

        print('best metric:', self.best_metric)
        if make_best:
            self.model.load_state_dict(self.best_model)

    def save_model(self, path, name=None):
        name = self.name if name is None else name
        torch.save(self.model.state_dict(), f'{path}{name}')

    def load_model(self, path, name=None, map_location=None):
        name = self.name if name is None else name
        self.model.load_state_dict(
            torch.load(f'{path}{name}', map_location=map_location))

    def save_checkpoint(self, path, name=None):
        name = self.name + '.chk' if name is None else name
        checkpoint = {
            'model': self.model.state_dict(),
            'best_model': self.best_model,
            'best_metric': self.best_metric,
            'optimizer': self.optimizer.state_dict(),
            'log': self.log
        }
        if self.scaler:
            checkpoint['scaler'] = self.scaler.state_dict()
        torch.save(checkpoint, f'{path}{name}')

    def load_checkpoint(self, path, name=None):
        name = self.name + '.chk' if name is None else name
        checkpoint = torch.load(f'{path}{name}')
        self.model.load_state_dict(checkpoint['model'])
        self.best_model = checkpoint['best_model']
        self.best_metric = checkpoint['best_metric']
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.log = checkpoint['log']
        if 'scaler' in checkpoint.keys():
            self.scaler = GradScaler()
            self.scaler.load_state_dict(checkpoint['scaler'])
        else:
            self.scaler = None

    def set_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def on_epoch_begin(self, *args, **kargs):
        pass

    def on_epoch_end(self, *args, **kargs):
        pass

    def predict(self,
                ds,
                batch_size=None,
                num_workers=8,
                dl_args={'shuffle': False},
                return_inds=False,
                return_true=False,
                verbose=True,
                do_eval=True):
        device = self.device
        if batch_size is not None:
            dl_args['batch_size'] = batch_size
        dl = D.DataLoader(ds, num_workers=num_workers, **dl_args)
        pred_list = []
        inds_list = []
        true_list = []
        if do_eval:
            self.model = self.model.eval()
        with torch.no_grad():
            tk0 = notebook.tqdm(dl) if verbose else dl
            for i, batch in enumerate(tk0):
                if autocast is None:
                    y_pred = self.run_model(self.model, batch)
                else:
                    with autocast(self.scaler is not None):
                        y_pred = self.run_model(self.model, batch)
                if return_inds:
                    inds_list.append(self.get_inds(batch).to('cpu').numpy())
                if return_true:
                    yb = self.get_y(batch)
                    true_list.append(yb.to('cpu').numpy() if not isinstance(yb,tuple) else\
                                 tuple(y.to('cpu').numpy() for y in yb))
                pred_list.append(y_pred.to('cpu').numpy() if not isinstance(y_pred,tuple) else\
                                 tuple(y.to('cpu').numpy() for y in y_pred))
        pred = np.concatenate(pred_list,0) if not isinstance(pred_list[0],tuple) else\
                tuple(np.concatenate([p[i] for p in pred_list],0) for i in range(len(pred_list[0])))
        out = ()
        if return_inds:
            out = out + (np.concatenate(inds_list, 0), )
        if return_true:
            rt=np.concatenate(true_list,0) if not isinstance(true_list[0],tuple) else\
                    tuple(np.concatenate([p[i] for p in true_list],0) for i in range(len(true_list[0])))
            out = out + (rt, )

        return pred if len(out) == 0 else (pred, ) + out
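
# A minimal, hypothetical usage sketch of the Learner above (not part of the original
# source): `my_model`, `my_optimizer`, `loss_fn`, `train_ds` and `valid_ds` are
# placeholders you would build yourself.
learner = Learner(my_model, my_optimizer, loss_fn, name='baseline', device='cuda')
learner.init_amp()                        # enables GradScaler + autocast when available
learner.fit(num_epoches=5,
            train_ds=train_ds,
            validate_ds=valid_ds,
            batch_size=32,
            accumulation_steps=2,         # optimizer step every 2 batches
            save_checkpoint='best',
            path='./checkpoints/')
preds = learner.predict(valid_ds, batch_size=32)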
Example No. 30
def train_schedule(writer, loader, val_num_steps, validation_loader, device,
                   criterion, net, optimizer, lr_scheduler, num_epochs,
                   is_mixed_precision, num_classes, categories, input_sizes,
                   selector, classes, encoder_only):
    # Poly training schedule
    # Validate and find the best snapshot
    best_mIoU = 0
    net.train()
    epoch = 0
    running_loss = 0.0
    loss_num_steps = max(len(loader) // 10, 1)  # avoid modulo-by-zero on tiny loaders
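    # GradScaler is only needed when autocast runs the forward pass in mixed
    # precision; it scales the loss to avoid fp16 gradient underflow.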
    if is_mixed_precision:
        scaler = GradScaler()

    # Training
    while epoch < num_epochs:
        net.train()
        conf_mat = ConfusionMatrix(num_classes)
        time_now = time.time()
        for i, data in enumerate(loader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            with autocast(is_mixed_precision):
                outputs = net(inputs)['out']

                if encoder_only:
                    labels = labels.unsqueeze(0)
                    if labels.dtype not in (torch.float32, torch.float64):
                        labels = labels.to(torch.float32)
                    labels = torch.nn.functional.interpolate(
                        labels, size=input_sizes[1], mode='nearest')
                    labels = labels.to(torch.int64)
                    labels = labels.squeeze(0)
                else:
                    outputs = torch.nn.functional.interpolate(
                        outputs,
                        size=input_sizes[0],
                        mode='bilinear',
                        align_corners=True)
                conf_mat.update(labels.flatten(), outputs.argmax(1).flatten())
                loss = criterion(outputs, labels)

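            # AMP bookkeeping: backward() runs on the scaled loss, scaler.step()
            # unscales the gradients before the optimizer update, and
            # scaler.update() adapts the loss scale for the next iteration.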
            if is_mixed_precision:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            lr_scheduler.step()
            running_loss += loss.item()
            current_step_num = int(epoch * len(loader) + i + 1)

            # Record losses
            if current_step_num % loss_num_steps == (loss_num_steps - 1):
                print('[%d, %d] loss: %.4f' %
                      (epoch + 1, i + 1, running_loss / loss_num_steps))
                writer.add_scalar('training loss',
                                  running_loss / loss_num_steps,
                                  current_step_num)
                running_loss = 0.0

            # Validate and find the best snapshot
            if current_step_num % val_num_steps == (val_num_steps - 1):
                test_pixel_accuracy, test_mIoU = test_one_set(
                    loader=validation_loader,
                    device=device,
                    net=net,
                    num_classes=num_classes,
                    categories=categories,
                    output_size=input_sizes[2],
                    labels_size=input_sizes[1],
                    selector=selector,
                    classes=classes,
                    is_mixed_precision=is_mixed_precision,
                    encoder_only=encoder_only)
                writer.add_scalar('test pixel accuracy', test_pixel_accuracy,
                                  current_step_num)
                writer.add_scalar('test mIoU', test_mIoU, current_step_num)
                net.train()

                # Record best model (straight to disk)
                if test_mIoU > best_mIoU:
                    best_mIoU = test_mIoU
                    save_checkpoint(net=net,
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler)

        # Evaluate training accuracy (same metrics as validation, but accumulated on the fly during training to save time)
        with autocast(is_mixed_precision):
            acc_global, acc, iu = conf_mat.compute()
        print(categories)
        print(('global correct: {:.2f}\n'
               'average row correct: {}\n'
               'IoU: {}\n'
               'mean IoU: {:.2f}').format(
                   acc_global.item() * 100,
                   ['{:.2f}'.format(i) for i in (acc * 100).tolist()],
                   ['{:.2f}'.format(i) for i in (iu * 100).tolist()],
                   iu.mean().item() * 100))

        train_pixel_acc = acc_global.item() * 100
        train_mIoU = iu.mean().item() * 100
        writer.add_scalar('train pixel accuracy', train_pixel_acc, epoch + 1)
        writer.add_scalar('train mIoU', train_mIoU, epoch + 1)

        epoch += 1
        print('Epoch time: %.2fs' % (time.time() - time_now))
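
For context, train_schedule() steps `lr_scheduler` once per batch under a "Poly training schedule". A minimal sketch of how such a scheduler could be built with torch.optim.lr_scheduler.LambdaLR follows; the helper name and the power of 0.9 are assumptions, since the original example constructs its scheduler elsewhere.

import torch


def build_poly_scheduler(optimizer, num_epochs, iters_per_epoch, power=0.9):
    # Polynomial decay from the base LR toward 0 over the whole run, stepped
    # once per optimizer update (matching lr_scheduler.step() in the loop).
    total_iters = num_epochs * iters_per_epoch
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: max(1.0 - step / total_iters, 0.0) ** power)

Because the decay is applied per iteration rather than per epoch, the learning rate falls smoothly to zero by the final batch of the final epoch.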