Example #1
import torch
from torch.cuda.amp import autocast, GradScaler


def start_train():
    '''
    Training loop with AMP (automatic mixed precision) and gradient accumulation.
    '''
    use_amp = True
    # Run forward/backward N times before updating the parameters.
    # Goal: enlarge the effective batch (effective batch = batch_size * N).
    iter_size = 8

    myNet = MyNet(use_amp).to("cuda:0")
    myNet = torch.nn.DataParallel(myNet, device_ids=[0, 1])  # data parallelism
    myNet.train()
    # Initialize the gradient scaler before training starts
    scaler = GradScaler() if use_amp else None

    # Resume from a checkpoint (resume_train, checkpoint and optimizer are
    # assumed to be defined elsewhere in the original script)
    if resume_train:
        scaler.load_state_dict(checkpoint['scaler'])  # needed for AMP (automatic mixed precision)
        optimizer.load_state_dict(checkpoint['optimizer'])
        myNet.load_state_dict(checkpoint["model"])

    for epoch in range(1, 100):
        for batch_idx, (input, target) in enumerate(dataloader_train):

            # Move the data to the primary GPU of the data-parallel model
            input = input.to("cuda:0")
            target = target.to("cuda:0")

            # Automatic mixed precision training
            if use_amp:
                # autocast automatically runs supported ops in FP16
                with autocast():
                    # Extract features
                    feature = myNet(input)
                    losses = loss_function(target, feature)
                    loss = losses / iter_size
                scaler.scale(loss).backward()
            else:
                feature = myNet(input)
                losses = loss_function(target, feature)
                loss = losses / iter_size
                loss.backward()

            # Gradient accumulation: update the parameters every iter_size batches
            if (batch_idx + 1) % iter_size == 0:
                # Parameter update
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                # Zero the gradients
                optimizer.zero_grad()
        # The scaler is stateful; its state must be reloaded when resuming training
        state = {
            'model': myNet.state_dict(),  # key name matches the one loaded above
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict() if use_amp else None
        }
        torch.save(state, "filename.pth")
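
The resume branch above reads from a checkpoint dict that the snippet never shows being loaded; a minimal sketch of that assumed step, with keys matching the state dict saved at the end of each epoch:

# Assumed resume setup (not part of the original snippet):
checkpoint = torch.load("filename.pth", map_location="cuda:0")
resume_train = True
# start_train() then restores checkpoint["model"], checkpoint["optimizer"]
# and checkpoint["scaler"] as shown in the resume branch above.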
Example #2
    class AMPTrainer(cls):
        """Pytorch's automatic mixed precision
        requires: pytorch >= 1.6
        see: https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
        """
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            print('init pytorch\'s amp')
            self.scaler = GradScaler()

        def forward_pass(self, data, vars: StageVars):
            with autocast():
                return super().forward_pass(data=data, vars=vars)

        def backward_pass(self, vars: StageVars):
            loss = vars.forward['loss']
            if loss is not None:
                assert loss.dim() == 0, "loss must be reduced"
                with time_elapsed_to_profiler('backward'):
                    self.opt.zero_grad()
                    self.scaler.scale(loss).backward()
                    # allow modifying the gradient directly
                    self.scaler.unscale_(self.opt)

        def optimize(self, vars: StageVars):
            loss = vars.forward['loss']
            if loss is not None:
                with time_elapsed_to_profiler('optimize'):
                    self.scaler.step(self.opt)
                    self.scaler.update()

        def get_state(self):
            # including the amp state
            state = super().get_state()
            state['scaler'] = self.scaler.state_dict()
            return state

        def load_state(self, state):
            # including the amp state
            super().load_state(state)
            print('loading pytorch\'s amp state ...')
            if 'scaler' in state:
                self.scaler.load_state_dict(state['scaler'])
            else:
                print('warning: scaler state is not available')

        def __repr__(self):
            return f'<AMPTrainer {super().__repr__()}>'
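
Since the class above inherits from a variable cls, it is presumably defined inside a mixin factory; a sketch of that assumed pattern (the amp_trainer and BaseTrainer names are hypothetical):

def amp_trainer(cls):
    # Hypothetical factory returning the AMPTrainer mixin shown above,
    # bound to the given base trainer class.
    class AMPTrainer(cls):
        ...  # body as in the example above
    return AMPTrainer

# TrainerCls = amp_trainer(BaseTrainer)
# trainer = TrainerCls(...)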
Example #3
import os

import torch
from torch.cuda.amp import GradScaler
from torch.optim import Adam


def load_model(model_name, model, device, mp=False):
    filepath = os.path.join('models', model_name + '.pt')
    checkpoint = torch.load(filepath, map_location=device)

    model.load_state_dict(checkpoint['model'])
    model.eval()

    optimizer = Adam(model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer'])

    scaler = GradScaler(enabled=mp)
    scaler.load_state_dict(checkpoint['scaler'])
    # scaler.set_growth_interval(500)
    # scaler.set_growth_factor(1)
    # scaler.set_backoff_factor(1)
    # print('Scale', scaler.get_scale())

    epoch = checkpoint['epoch'] or 0
    return model, optimizer, scaler, epoch
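
A matching save-side sketch that writes the keys this loader expects (the save_model helper is hypothetical):

def save_model(model_name, model, optimizer, scaler, epoch):
    # Hypothetical counterpart to load_model above; writes the same keys
    # ('model', 'optimizer', 'scaler', 'epoch') that load_model reads.
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scaler': scaler.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, os.path.join('models', model_name + '.pt'))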
Example #4
import torch.nn as nn
from torch.cuda.amp import GradScaler


class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = GradScaler()

    def __repr__(self) -> str:
        return repr(self.__class__.__name__)

    def __call__(
        self,
        loss,
        optimizer,
        step,
        accum_grad,
        clip_grad=None,
        parameters=None,
        create_graph=False,
    ):
        self._scaler.scale(loss /
                           accum_grad).backward(create_graph=create_graph)
        if step % accum_grad == 0:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(
                    optimizer
                )  # unscale the gradients of optimizer's assigned params in-place
                nn.utils.clip_grad_norm_(parameters, clip_grad)
            self._scaler.step(optimizer)
            self._scaler.update()
            optimizer.zero_grad()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
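
A minimal usage sketch of NativeScaler with autocast and gradient accumulation (model, optimizer, loader and loss_fn are placeholders; the accumulation length of 4 is arbitrary):

from torch.cuda.amp import autocast

scaler = NativeScaler()
accum_grad = 4
for step, (images, targets) in enumerate(loader, start=1):
    # forward in mixed precision, then hand the loss to the scaler wrapper, which
    # accumulates gradients and steps the optimizer every accum_grad batches
    with autocast():
        loss = loss_fn(model(images.cuda()), targets.cuda())
    scaler(loss, optimizer, step=step, accum_grad=accum_grad,
           clip_grad=1.0, parameters=model.parameters())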
Example #5
class Trainer(object):
    def __init__(self, cfgs):

        save_dict = OrderedDict()

        save_dict["fold"] = cfgs["fold"]
        if cfgs["memo"] is not None:
            save_dict["memo"] = cfgs["memo"]  # 1,2,3
        specific_dir = ["{}-{}".format(key, save_dict[key]) for key in save_dict.keys()]

        cfgs["save_dir"] = os.path.join(
            cfgs["save_dir"],
            # cfgs["model"]["meta"],
            # cfgs["model"]["inputs"]["label"],
            "_".join(specific_dir),
        )
        os.makedirs(cfgs["save_dir"], exist_ok=True)

        ####### CONFIGS
        self.cfgs = cfgs

        ####### Logging
        self.tb_writer = utils.get_writer(self.cfgs)
        self.txt_logger = utils.get_logger(self.cfgs)

        self.do_logging = True
        if len(self.cfgs["gpu"]) > 1:
            if dist.get_rank() != 0:
                self.do_logging = False

        if self.do_logging:
            self.txt_logger.write("\n\n----train.py----")
            self.txt_logger.write("\n{}".format(datetime.datetime.now()))
            self.txt_logger.write(
                "\n\nSave Directory: \n{}".format(self.cfgs["save_dir"])
            )
            self.txt_logger.write("\n\nConfigs: \n{}\n".format(self.cfgs))

        ####### MODEL
        model = models.get_model(self.cfgs)
        if len(self.cfgs["gpu"]) > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            self.device = torch.device("cuda:{}".format(self.cfgs["local_rank"]))
            self.model = model.to(self.device)
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[self.cfgs["local_rank"]],
                output_device=self.cfgs["local_rank"],
            )
        else:
            self.device = torch.device("cuda:{}".format(self.cfgs["local_rank"]))
            self.model = model.to(self.device)

        ####### Data

        train_dataset = inputs.get_dataset(self.cfgs, mode="train")
        if len(self.cfgs["gpu"]) > 1:
            train_sampler = DistributedSampler(
                train_dataset,
                num_replicas=len(self.cfgs["gpu"]),
                rank=self.cfgs["local_rank"],
            )
        else:
            train_sampler = None

        self.train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=self.cfgs["batch_size"],
            num_workers=self.cfgs["num_workers"],
            pin_memory=True,
            drop_last=False,
            collate_fn=inputs.get_collater(),
            sampler=train_sampler,
        )

        # if self.do_logging:
        #     self.txt_logger.write("\nDataset: ")
        #     self.txt_logger.write(
        #         "\nTRAIN Abnormal/Normal: {}/{}".format(
        #             len(train_dataset.abnormal_meta_df),
        #             len(train_dataset.normal_meta_df),
        #         )
        #     )

        ####### Opts
        self.optimizer = opts.get_optimizer(self.cfgs, self.model.parameters())
        self.scheduler = opts.get_scheduler(self.cfgs, self.optimizer)
        self.grad_scaler = GradScaler(enabled=self.cfgs["use_amp"])

        ####### Validator
        self.validator = Validator(self.cfgs, self.device)
        # if self.do_logging:
        #     self.txt_logger.write(
        #         "\nVAL   Abnormal/Normal: {}/{}".format(
        #             len(self.validator.val_loader.dataset.abnormal_meta_df),
        #             len(self.validator.val_loader.dataset.normal_meta_df),
        #         )
        #     )

        # if self.cfgs["model"]["val"]["ignore_normal"]:
        #     self.txt_logger.write("\nVAL   Ignore Normal")
        #     self.validator.val_loader.dataset.meta_df = (
        #         self.validator.val_loader.dataset.abnormal_meta_df
        #     )

    def do_train(self):

        ####### Setup Train
        self.epoch, self.iter, self.resume_epoch = 0, 0, 0
        self.tot_val_record = {
            "best": {"det_recl": -1, "det_prec": -1, "det_f1": -1, "loss": np.inf}
        }

        if self.cfgs["model"]["train"]["resume_train"]:
            with open(
                os.path.join(self.cfgs["save_dir"], "tot_val_record.pkl"), "rb"
            ) as f:
                self.tot_val_record = pickle.load(f)
                self.iter, self.resume_epoch = (
                    self.tot_val_record["best"]["iteration"],
                    self.tot_val_record["best"]["epoch"],
                )
                resume_model_dir = os.path.join(
                    self.cfgs["save_dir"], "epoch_{}.pt".format(self.resume_epoch)
                )
                checkpoint = torch.load(resume_model_dir)
                self.model.load_state_dict(checkpoint["model"], strict=True)
                self.optimizer.load_state_dict(checkpoint["optimizer"])
                self.grad_scaler.load_state_dict(checkpoint["scaler"])
                self.txt_logger.write("\n\nResume Training Here! \n\n")

        if self.do_logging:
            self.txt_logger.write("\n\nStart Training! \n\n")
            header_columns = ["epoch", "iter", "time", "train_loss", "val_loss"]
            header_columns += ["det_recl", "det_prec", "det_fppi", "det_f1"]
            header_columns += ["cls_auc", "cls_sens", "cls_spec"]
            header_columns += ["best_epoch"]
            self.txt_logger.log_header(header_columns)

        ####### Train
        self.start_time = time.time()
        self.endurance = 0
        for epoch in range(self.resume_epoch, self.cfgs["model"]["train"]["max_epoch"]):
            # self.train_loader.dataset.shuffle()
            # self.train_loader.dataset.meta_df = (
            #     self.train_loader.dataset.abnormal_meta_df
            # )

            self.one_epoch_steps = len(self.train_loader)
            self.display_step = (
                self.one_epoch_steps // self.cfgs["model"]["train"]["display_interval"]
            )

            self.epoch = epoch
            if self.endurance > self.cfgs["model"]["train"]["endurance"]:
                if self.do_logging:
                    self.txt_logger.write(
                        "\nStop training! No more performance gain expected!"
                    )
                    best_epoch = self.tot_val_record["best"]["epoch"]
                    self.txt_logger.write(
                        "\n\nBest saved at: {}, {} epoch\n\n".format(
                            self.cfgs["save_dir"], best_epoch
                        )
                    )
                break
            self.train_val_one_epoch()

    def train_val_one_epoch(self):

        self.optimizer.zero_grad()
        self.model.train()

        t0 = time.time()

        for i, data in enumerate(self.train_loader):
            t1 = time.time()
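            # NOTE: the forward pass and loss below are not wrapped in
            # torch.cuda.amp.autocast(enabled=self.cfgs["use_amp"]), so mixed
            # precision is not applied here even though a GradScaler is created in
            # __init__ (unless the model applies autocast internally); wrapping
            # them in autocast would be the usual pattern.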
            img = data["img"].permute(0, 3, 1, 2).to(self.device)
            logit = self.model(img)

            t2 = time.time()

            # FIXME: GPU utilization does not go up here
            loss = opts.calc_loss(self.cfgs, self.device, data, logit)

            t3 = time.time()

            self.grad_scaler.scale(loss).backward()
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()

            self.optimizer.zero_grad()

            t4 = time.time()

            # NOTE: Try to avoid excessive CPU-GPU synchronization (.item() calls, or printing values from CUDA tensors).

            if self.do_logging:
                loss = loss.detach().item()
                take_time = tools.convert_time(time.time() - self.start_time)
                train_logs = [loss, "-"]
                self.txt_logger.log_result(
                    [self.epoch, "{}/{}".format(i, self.one_epoch_steps), take_time]
                    + train_logs
                )
                self.tb_writer.write_scalars(
                    {"loss": {"train loss": loss}},
                    self.iter,
                )

                if self.iter % self.display_step == 0:
                    # Visualize
                    # Find abnormal
                    for viz_bi in range(len(data["fp"])):
                        if data["bbox"][viz_bi, 0, -1] != -1:
                            break

                    with torch.no_grad():
                        self.model.eval()
                        det_preds_viz = (
                            self.model(img, mode="viz")["preds"][viz_bi]
                            .detach()
                            .cpu()
                            .numpy()
                        )

                        if len(det_preds_viz) != 0:
                            # sigmoid
                            det_preds_viz[:, -1] = 1 / (
                                1 + np.exp(-1 * det_preds_viz[:, -1])
                            )
                        else:
                            det_preds_viz = np.ones((1, 6)) * -1

                        det_anns_viz = data["bbox"][viz_bi].numpy()

                        self.tb_writer.write_images(
                            data["fp"][viz_bi],
                            data["img"][viz_bi].numpy(),
                            det_preds_viz,
                            det_anns_viz,
                            self.iter,
                            "train",
                        )
                        self.model.train()

            self.iter += 1

            lr0 = self.cfgs["model"]["opts"]["learning_rate"]
            wep = self.cfgs["model"]["opts"]["warmup_epoch"]
            if self.epoch < wep:
                for pg in self.optimizer.param_groups:
                    pg["lr"] = lr0 / wep * (self.epoch + i / self.one_epoch_steps)
            else:
                if self.scheduler is not None:
                    self.scheduler.step(self.epoch - wep + i / self.one_epoch_steps)

            t5 = time.time()
            if self.cfgs["do_profiling"]:
                print("\ndata", t1 - t0)
                print("forward", t2 - t1)
                print("calc loss", t3 - t2)
                print("backward", t4 - t3)
                print("logging", t5 - t4)
            t0 = t5

        if self.epoch > self.cfgs["model"]["val"]["ignore_epoch"]:

            # Do Validation
            val_record, val_viz = self.validator.do_validate(self.model)
            self.tot_val_record[str(self.epoch + 1)] = val_record
            val_best = val_record[self.cfgs["model"]["val"]["best"]]

            # Save Model
            select_metric = self.cfgs["model"]["val"]["best"]
            val_improved = False
            if select_metric == "loss":
                if val_best < self.tot_val_record["best"][select_metric]:
                    val_improved = True
            elif select_metric == "det_f1":
                if val_best > self.tot_val_record["best"][select_metric]:
                    val_improved = True

            if val_improved:
                checkpoint = {
                    "epoch": self.epoch,
                    "model": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "scaler": self.grad_scaler.state_dict(),
                }
                model_name = os.path.join(
                    self.cfgs["save_dir"], "epoch_" + str(self.epoch + 1) + ".pt"
                )
                torch.save(checkpoint, model_name)
                self.tot_val_record["best"] = val_record
                self.tot_val_record["best"]["epoch"] = self.epoch + 1
                self.tot_val_record["best"]["iteration"] = self.iter
                self.endurance = 0
            else:
                self.endurance += 1

            if self.do_logging:
                take_time = utils.tools.convert_time(time.time() - self.start_time)
                vloss = val_record["loss"]
                vbest_epoch = self.tot_val_record["best"]["epoch"]
                metric_keys = ["det_recl", "det_prec", "det_fppi", "det_f1"]
                metric_keys += ["cls_auc", "cls_sens", "cls_spec"]
                val_logs = [vloss] + [val_record[k] for k in metric_keys]
                self.txt_logger.log_result(
                    [self.epoch + 1, self.iter, take_time, loss]
                    + val_logs
                    + [vbest_epoch],
                    txt_write=True,
                )
                self.txt_logger.write("\n", txt_write=True)
                self.tb_writer.write_images(
                    val_viz["fp"],
                    val_viz["img"],
                    val_viz["pred"],
                    val_viz["ann"],
                    self.iter,
                    "val",
                )

                self.tb_writer.write_scalars(
                    {
                        "metrics": {
                            "{}".format(key): val_record[key] for key in metric_keys
                        }
                    },
                    self.iter,
                )
                self.tb_writer.write_scalars({"loss": {"val loss": vloss}}, self.iter)

                with open(
                    os.path.join(self.cfgs["save_dir"], "tot_val_record.pkl"), "wb"
                ) as f:
                    pickle.dump(self.tot_val_record, f)
Example #6
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    logger = get_logger(args.logging_file)
    logger.info("Use GPU: {} for training".format(args.gpu))

    args.rank = args.rank * ngpus_per_node + gpu
    torch.distributed.init_process_group(backend="nccl",
                                         init_method=args.dist_url,
                                         world_size=args.world_size,
                                         rank=args.rank)

    epochs = args.epochs
    input_size = args.input_size
    resume_epoch = args.resume_epoch
    initializer = KaimingInitializer()
    zero_gamma = ZeroLastGamma()
    mix_precision_training = args.mix_precision_training
    is_first_rank = args.rank % ngpus_per_node == 0

    batches_per_epoch = args.num_training_samples // (args.batch_size *
                                                      ngpus_per_node)
    lr = 0.1 * (args.batch_size * ngpus_per_node //
                32) if args.lr == 0 else args.lr

    model = get_model(models, args.model)

    model.apply(initializer)
    if args.last_gamma:
        model.apply(zero_gamma)
        logger.info('Apply zero last gamma init.')

    if is_first_rank and args.model_info:
        summary(model, torch.rand((1, 3, input_size, input_size)))

    parameters = model.parameters() if not args.no_wd else no_decay_bias(model)
    if args.sgd_gc:
        logger.info('Use SGD_GC optimizer.')
        optimizer = SGD_GC(parameters,
                           lr=lr,
                           momentum=args.momentum,
                           weight_decay=args.wd,
                           nesterov=True)
    else:
        optimizer = optim.SGD(parameters,
                              lr=lr,
                              momentum=args.momentum,
                              weight_decay=args.wd,
                              nesterov=True)

    # pass the computed lr so warmup is based on the optimizer's actual base lr
    lr_scheduler = CosineWarmupLr(optimizer,
                                  batches_per_epoch,
                                  epochs,
                                  base_lr=lr,
                                  warmup_epochs=args.warmup_epochs)

    # dropblock_scheduler = DropBlockScheduler(model, batches_per_epoch, epochs)

    if args.lookahead:
        optimizer = Lookahead(optimizer)
        logger.info('Use lookahead optimizer.')

    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)
    args.num_workers = int(
        (args.num_workers + ngpus_per_node - 1) / ngpus_per_node)

    if args.mix_precision_training and is_first_rank:
        logger.info('Train with FP16.')

    scaler = GradScaler(enabled=args.mix_precision_training)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    Loss = nn.CrossEntropyLoss().cuda(args.gpu) if not args.label_smoothing else \
        LabelSmoothingLoss(args.classes, smoothing=0.1).cuda(args.gpu)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if args.autoaugment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            ImageNetPolicy(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            # Cutout(),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.ToTensor(),
            normalize,
        ])

    val_transform = transforms.Compose([
        transforms.Resize(int(input_size / 0.875)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = ImageNet(args.data_path,
                         split='train',
                         transform=train_transform)
    val_set = ImageNet(args.data_path, split='val', transform=val_transform)

    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set,
                              args.batch_size,
                              False,
                              pin_memory=True,
                              num_workers=args.num_workers,
                              drop_last=True,
                              sampler=train_sampler)
    val_loader = DataLoader(val_set,
                            args.batch_size,
                            False,
                            pin_memory=True,
                            num_workers=args.num_workers,
                            drop_last=False)

    if resume_epoch > 0:
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.resume_param, map_location=loc)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scaler.load_state_dict(checkpoint['scaler'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        print("Finish loading resume param.")

    torch.backends.cudnn.benchmark = True

    top1_acc = metric.Accuracy(name='Top1 Accuracy')
    top5_acc = metric.TopKAccuracy(top=5, name='Top5 Accuracy')
    loss_record = metric.NumericalCost(name='Loss')

    for epoch in range(resume_epoch, epochs):
        tic = time.time()
        train_sampler.set_epoch(epoch)
        if not args.mixup:
            train_one_epoch(model, train_loader, Loss, optimizer, epoch,
                            lr_scheduler, logger, top1_acc, loss_record,
                            scaler, args)
        else:
            train_one_epoch_mixup(model, train_loader, Loss, optimizer, epoch,
                                  lr_scheduler, logger, loss_record, scaler,
                                  args)
        train_speed = int(args.num_training_samples // (time.time() - tic))
        if is_first_rank:
            logger.info(
                'Finish one epoch speed: {} samples/s'.format(train_speed))
        test(model, val_loader, Loss, epoch, logger, top1_acc, top5_acc,
             loss_record, args)

        if args.rank % ngpus_per_node == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
            }
            torch.save(
                checkpoint, '{}/{}_{}_{:.5}.pt'.format(args.save_dir,
                                                       args.model, epoch,
                                                       top1_acc.get()))
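
The train_one_epoch / train_one_epoch_mixup helpers called above are not part of this excerpt; a minimal sketch of the AMP step they would presumably perform with the scaler (the signature and loop body below are assumptions, not the original code):

def train_one_epoch_sketch(model, train_loader, Loss, optimizer, lr_scheduler,
                           scaler, args):
    from torch.cuda.amp import autocast  # assumed import

    model.train()
    for data, target in train_loader:
        data = data.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)
        optimizer.zero_grad()
        # run the forward pass and loss in mixed precision when enabled
        with autocast(enabled=args.mix_precision_training):
            output = model(data)
            loss = Loss(output, target)
        # scale the loss, backprop, then let the scaler step and update
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()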
Example #7
    if args.chkpoint:
        chk = torch.load(args.chkpoint, map_location=device)
    elif args.finetune:
        if args.chkpointft:
            chk = torch.load(args.chkpointft, map_location=device)
        else:
            sys.exit("Finetuning can't be performed if chkpointft is not supplied")
    else:
        chk = None
        start_epoch = 0
        best_loss = float('-inf') if IsNegLoss else float('inf')

    if chk is not None:
        model.load_state_dict(chk['state_dict'])
        optimizer.load_state_dict(chk['optimizer'])
        scaler.load_state_dict(chk['AMPScaler'])
        best_loss = chk['best_loss']
        start_epoch = chk['epoch'] + 1
        iterations = chk['iterations']
        main_train_epoch = (chk['main_train_epoch'] + 1) if 'main_train_epoch' in chk else start_epoch  # only used for finetuning

    if args.finetune:
        if args.fteprt:
            args.epochs = int(main_train_epoch * (1 + args.fteprt))
        else:
            args.iterations = int(iterations * args.ftitrt)
            n_ft_ep = int(args.iterations // len(train_loader))
            args.epochs = main_train_epoch + n_ft_ep

    if args.epochs is None:
        args.epochs = int(args.iterations // len(train_loader) + 1)
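
The checkpoint consumed at the top of this example is written elsewhere; a hedged sketch of a matching save step, using the same keys the loader reads (the save destination and timing are assumptions):

    # Sketch of a matching checkpoint save (keys mirror what is loaded above):
    torch.save({
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'AMPScaler': scaler.state_dict(),
        'best_loss': best_loss,
        'epoch': epoch,
        'iterations': iterations,
        'main_train_epoch': main_train_epoch,  # only relevant when finetuning later
    }, args.chkpoint)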
Example #8
class SelfSupervisionTask(ClassificationTask):
    """
    A task prepares and holds all the components of a training run, such as the optimizer,
    datasets, dataloaders, losses, meters etc. The task also holds variables such as the
    training iteration and epoch number that are updated during training.

    We prepare every single component according to the parameter settings the user wants
    and has specified in the yaml config file.

    The task also supports 2 additional things:
    1) converts the model BatchNorm layers to synchronized batchnorm
    2) sets mixed precision (apex and pytorch both supported)
    """

    def __init__(self, config: AttrDict):
        super().__init__()
        self.config = config
        self.checkpoint_path = None

        # Register the task to the proper device (cpu, gpu, ...)
        self.set_device()

        self.checkpoint_folder = None
        self.checkpoint = None
        self.available_splits = []
        self.base_loss = None
        self.meters = None
        self.datasets = None
        self.phases = []
        self.hooks = []
        self.base_model = None
        self.optimizer = None
        self.amp_args = None
        self.amp_type = None
        self.amp_grad_scaler = None
        self.data_and_label_keys = []
        self.set_amp_args()
        self._enable_manual_gradient_reduction = None
        # total number of parameter updates applied to the model by optimizer
        self.num_updates = 0
        # measure time of several training components (data, forward, backward etc..)
        self.perf_stats = None
        # total number of phases including test + train
        self.num_phases = -1  # set by the trainer
        # number of train only phases
        self.num_train_phases = -1  # set by the prepare method
        # number of train epochs; set to num_train_phases
        self.num_epochs = -1  # set by the trainer
        # total number of "training" iterations. Inferred from dataloader length and
        # num_train_phases
        self.max_iteration = -1  # set by trainer
        # Current phase id (includes train/test). Starts from 0
        self.phase_idx = -1
        # id of the current training phase training is at. Starts from 0
        self.train_phase_idx = -1  # set by trainer
        # metrics stored during the training.
        self.metrics = {}  # set by the trainer
        self.start_time = -1  # set by trainer
        # time of each batch in training and testing. This can be used to get average
        # batch time etc. batch_time is appended after every parameter update.
        self.batch_time = []  # set by trainer
        # we maintain and store the iteration in the state itself. It counts
        # total number of iterations we do in training phases. Updated
        # after every forward pass of training step.
        # Starts from 1
        self.iteration = 0
        # collect how many total iterations we make irrespective of train/test phase.
        # Useful for debugging purposes. Starts from 1.
        self.local_iteration_num = -1  # set by trainer
        # for every phase, record the start time. Reset at the beginning of each phase
        # by SetDataSamplerEpochHook hook.
        self.phase_start_time = -1  # set by the hook at start of each epoch or phase
        # for every phase, record the number of batches seen. Incremented after every
        # forward pass. Reset at the start of each phase by
        # SetDataSamplerEpochHook hook. Useful for debugging.
        self.batches = -1  # set by the hook at start of each epoch or phase
        # loss curve. Reset at start of each phase/epoch by SetDataSamplerEpochHook hook.
        self.losses = []  # set by the hook at start of each epoch or phase
        # set the bucket_cap_mb for gradient reduction. This can be tuned to overlap
        # communication as much as possible
        self.set_ddp_bucket_cap_mb()
        self.use_gpu = self.device.type == "cuda"
        # optionally save the exponential moving average (ema) of the base_model.
        # and/or run the meters on the ema of the base_model.
        self.ema_model = None
        self.ema_meters = []

    def set_device(self):
        """
        Set the training device: whether gpu or cpu. We use the self.device
        in the rest of the workflow to determine if we should do cpu only training
        or use gpu. set MACHINE.DEVICE = "gpu" or "cpu"
        """
        try:
            self.device = torch.device(
                "cuda" if self.config.MACHINE.DEVICE == "gpu" else "cpu"
            )
        except AttributeError:
            self.device = torch.device("cuda")

    def set_ddp_bucket_cap_mb(self):
        """
        PyTorch DDP supports setting the bucket_cap_mb for all reduce. Tuning
        this parameter can help with the speed of the model. We use the default
        pytorch value of 25MB.
        """
        self.ddp_bucket_cap_mb = self.config.DATA.DDP_BUCKET_CAP_MB
        assert self.ddp_bucket_cap_mb > 0, "bucket_cap_mb must be positive"

    def set_available_splits(self):
        """
        Given the data settings, we determine if we are using both train and test
        datasets. If TEST_MODEL=true, we add the test split to available_splits.
        If TEST_ONLY=false, we add the train split as well.
        """
        if self.config.TEST_MODEL:
            self.available_splits.append("TEST")
        if not self.config.TEST_ONLY:
            self.available_splits.append("TRAIN")
        return self

    def set_amp_args(self):
        """
        Two automatic mixed precision implementations are available: Apex's and PyTorch's.

        - If Apex's AMP is enabled, amp_args is a dictionary containing arguments
        to be passed to amp.initialize. Set to None to disable amp.
        To enable mixed precision training, pass amp_args={"opt_level": "O1"} here.
        See https://nvidia.github.io/apex/amp.html for more info.

        - If Pytorch's AMP is enabled, no arguments are needed.
        """

        if self.config.MODEL.AMP_PARAMS.USE_AMP:
            assert (
                self.device.type == "cuda"
            ), "Mixed precision is only available on CUDA devices for now"

            # This will rightly fail if the setting is not correct
            self.amp_type = AmpType[self.config.MODEL.AMP_PARAMS.AMP_TYPE.upper()]
            if self.amp_type == AmpType.APEX:
                self._init_apex_grad_scaler()
            elif self.amp_type == AmpType.PYTORCH:
                self._init_pytorch_grad_scaler()
            logging.info(f"Setting AMP: {self.amp_type} - args: {self.amp_args}")

        else:
            self.amp_args, self.amp_type = None, None
            logging.info("Not using Automatic Mixed Precision")

    def _init_apex_grad_scaler(self):
        # Check Apex availability
        if not is_apex_available():
            raise RuntimeError("Apex is not available. Can't use mixed precision")

        # "amp_args" are actually Apex Amp args
        self.amp_args = self.config.MODEL.AMP_PARAMS.AMP_ARGS
        logging.info(f"Setting AMP: using apex, args {self.amp_args}")

    def _init_pytorch_grad_scaler(self):
        if self.config["OPTIMIZER"]["name"] == "zero":
            assert is_fairscale_sharded_available(), (
                "To use ZeRO with PyTorch AMP, ShardedGradScaler() "
                "from fairscale is needed. Please upgrade fairscale"
            )
            from fairscale.optim.grad_scaler import ShardedGradScaler

            self.amp_grad_scaler = ShardedGradScaler()
            logging.info("Setting AMP: using sharded grad scaler")
        else:
            self.amp_grad_scaler = TorchGradScaler()
            logging.info("Setting AMP: using pytorch grad scaler")
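
    # NOTE: this excerpt only wires up the grad scaler; the training step that
    # consumes it is not shown here. The standard torch.cuda.amp pattern would be
    # scaler.scale(loss).backward(), scaler.step(optimizer), scaler.update().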

    def set_checkpoint_path(self, checkpoint_path: str):
        """
        Set the checkpoint path for the training
        """
        self.checkpoint_path = checkpoint_path

    def set_checkpoint_folder(self, checkpoint_folder: str):
        """
        Set the checkpoint folder for the training
        """
        self.checkpoint_folder = checkpoint_folder

    def set_iteration(self, iteration):
        """
        Set the iteration number.
        we maintain and store the iteration in the state itself. It counts
        total number of iterations we do in training phases. Updated
        after every forward pass of training step.
        Starts from 1
        """
        assert iteration >= 0, "Iteration number must be non-negative"
        self.iteration = iteration

    @property
    def enable_manual_gradient_reduction(self) -> bool:
        """
        Lazily initialize the enable flag once the model is not None.
        """
        if self._enable_manual_gradient_reduction is None and self.model is not None:
            self.set_manual_gradient_reduction()
        if self._enable_manual_gradient_reduction:
            return True
        return False

    def set_manual_gradient_reduction(self) -> None:
        """
        Called during __init__ to set a flag if manual gradient reduction is enabled.
        """
        assert self.model is not None
        self._enable_manual_gradient_reduction = manual_gradient_reduction(
            self.model, self.config["DISTRIBUTED"]["MANUAL_GRADIENT_REDUCTION"]
        )
        if self._enable_manual_gradient_reduction:
            logging.info("Enabling manual gradient reduction")

    @classmethod
    def from_config(cls, config):
        """
        Create the task from the yaml config input.
        """
        test_only = config.TEST_ONLY

        return (
            cls(config)
            .set_available_splits()
            .set_test_only(test_only)
            .set_epoch_phase_info()
        )

    def set_epoch_phase_info(self):
        # In case optimizer doesn't exist. E.g. for feature extraction.
        optimizer = getattr(self.config, "OPTIMIZER", {})
        self.num_epochs = getattr(optimizer, "num_epochs", 1)
        self.num_train_phases_per_epoch = getattr(
            self.config["DATA"]["TRAIN"], "TRAIN_PHASES_PER_EPOCH", 1
        )
        self.num_train_phases = (
            self.config["OPTIMIZER"]["num_epochs"] * self.num_train_phases_per_epoch
        )

        return self

    # We keep the function because this is used by hooks like checkpoint etc.
    def get_config(self):
        """
        Utility function to store and use the config that was used for the given
        training.
        """
        return {"config": self.config}

    def _build_phases(self):
        """
        Returns list of phases from config. These phases will look like:
        {
          train: is this a train or test phase (bool)?
        }
        If this is a test only run, then only test phases will be
        generated, if this is a training run, then #phases = #train-phases + #test-phases,
        interleaved. We also add a test phase every TEST_EVERY_NUM_EPOCH if
        we don't want the test to run after every train phase.
        """
        if not self.config["TEST_ONLY"]:
            phases = [{"train": True} for _ in range(self.num_train_phases)]
            # whether the model is train or test only. If the model is not test
            # only, then whether we do test as well or not, is decided from the
            # config file.
            test_every = (
                self.config.get("TEST_EVERY_NUM_EPOCH", 1)
                * self.num_train_phases_per_epoch
            )
            output_phases = []
            for idx, phase in enumerate(phases):
                output_phases.append(phase)

                if idx % test_every == 0 or idx == (len(phases) - 1):
                    output_phases.append({"train": False})
            # we do a little surgery here. Either the phases are test only or
            # [train + test] both interleaved. If we don't want the model to be tested
            # at all (which is sometimes the case in self-supervised learning), we
            # remove the test phases.
            if not self.config["TEST_MODEL"]:
                output_phases = [phase for phase in output_phases if phase["train"]]
        else:
            output_phases = [{"train": False} for _ in range(self.num_train_phases)]
        return output_phases

    def build_datasets(self, current_train_phase_idx=0):
        """
        Get the datasets for the data splits we will use in the training. The
        set_available_splits variable determines the splits used in the training.
        """
        datasets, data_and_label_keys = {}, {}
        for split in self.available_splits:
            datasets[split.lower()] = build_dataset(
                cfg=self.config,
                split=split,
                current_train_phase_idx=current_train_phase_idx,
            )
            data_and_label_keys["input"] = self.config.DATA[split].INPUT_KEY_NAMES
            data_and_label_keys["target"] = self.config.DATA[split].TARGET_KEY_NAMES

        return datasets, data_and_label_keys

    def build_dataloaders(
        self, pin_memory: bool, current_train_phase_idx=0
    ) -> torch.utils.data.DataLoader:
        """
        Build PyTorch dataloaders for all the available_splits. By default, we construct the
        standard PyTorch Dataloader and allow setting all dataloader options.
        """
        # Gives sampler same seed for entire distributed group as per pytorch documentation.
        sampler_seed = self.config["SEED_VALUE"]

        loaders = {
            split.lower(): build_dataloader(
                dataset=self.datasets[split.lower()],
                dataset_config=self.config["DATA"][split],
                num_dataloader_workers=self.config.DATA.NUM_DATALOADER_WORKERS,
                pin_memory=pin_memory,
                multi_processing_method=self.config.MULTI_PROCESSING_METHOD,
                device=self.device,
                sampler_seed=sampler_seed,
                split=split.lower(),
            )
            for split in self.available_splits
        }

        return loaders

    def get_global_batchsize(self):
        """
        Return global batchsize used in the training across all the trainers.
        We check what phase we are in (train or test) and get the dataset
        used in that phase. We call get_global_batchsize() of the dataset.
        """
        for phase_type in self.datasets:
            if phase_type.lower() == self.phase_type.lower():
                return self.datasets[phase_type].get_global_batchsize()
        raise ValueError(f"{self.phase_type} not found in self.datasets")

    def _build_optimizer(self):
        """
        Build optimizers using the optimizer settings specified by user.
        For SGD, we support LARC as well. In order to use LARC, Apex must
        be installed.
        """
        optimizer_config = self.config["OPTIMIZER"]
        if optimizer_config.use_larc and optimizer_config.name != "sgd_fsdp":
            assert is_apex_available(), "Apex must be available to use LARC"
        optim = build_optimizer(optimizer_config)
        return optim

    def _build_optimizer_schedulers(self):
        """
        Build the param schedulers to be used in training.
        """
        return build_optimizer_schedulers(self.config["OPTIMIZER"])

    def _build_loss(self):
        """
        Build the loss used in training. Supports all PyTorch losses
        and custom defined losses.

        For some losses that require memory banks (for example in info_nce loss),
        we need to store the size of data as we use it to allocate memory.
        Since dataset size is not known at the time of config parsing, we set
        the data size parameter here.
        """
        # in some cases like memory bank, we need to store the size of data
        # as we use it to allocate memory. Hence we set that parameter here.
        logging.info("Building loss...")
        loss_name = self.config.LOSS["name"]
        assert loss_name in list(self.config.LOSS.keys()), (
            f"Loss {loss_name} params unknown. The loss name and the param dict "
            f"key name should match. Known: {list(self.config.LOSS.keys())}"
        )
        loss_config = self.config.LOSS[loss_name]
        if "num_train_samples" in loss_config.keys():
            for split in self.available_splits:
                if split == "TRAIN":
                    loss_config["num_train_samples"] = len(self.datasets["train"])
                if split == "TEST":
                    loss_config["num_train_samples"] = len(self.datasets["test"])
        loss_config["name"] = loss_name
        loss = build_loss(loss_config)
        return loss

    def _build_meters(self):
        """
        Returns meters for task.
        """
        meter_names = self.config["METERS"].get("names", [])

        if not meter_names:
            return []

        meters = []
        for meter_name in meter_names:
            meter_params = self.config["METERS"][meter_name]
            meter_config = {"name": meter_name, **meter_params}
            meters.append(build_meter(meter_config))

        return meters

    def _restore_model_weights(self, model, strict: bool = False):
        """
        If using a weights file to initialize the model, we load the weights
        and initialize the model. Since the weights file specified
        by user might not be VISSL trained weights, we expose several config
        options like APPEND_PREFIX, etc to allow successful loading of the weights.
        See MODEL.WEIGHTS_INIT description in vissl/config/defaults.yaml for details.
        """
        params_from_file = self.config["MODEL"]["WEIGHTS_INIT"]
        init_weights_path = params_from_file["PARAMS_FILE"]
        assert init_weights_path, "Shouldn't call this when init_weight_path is empty"
        logging.info(f"Initializing model from: {init_weights_path}")

        if g_pathmgr.exists(init_weights_path):
            checkpoint = CheckpointLoader.load_and_broadcast_init_weights(
                checkpoint_path=init_weights_path, device=torch.device("cpu")
            )
            logging.info(f"Checkpoint loaded: {init_weights_path}...")
            model.init_model_from_weights_params_file(
                self.config, checkpoint, strict=strict
            )
        return model

    def _build_model(self, strict_load: bool = False):
        """
        - Builds and returns model used for task. The returned model is not copied to
          gpu yet (if using gpu) nor wrapped with DDP yet. This is done later
          by self.prepare()

        - We also convert the model BatchNorm layers to SyncBatchNorm if user
          has set the config option. We support PyTorch and Apex SyncBatchNorms
          both.

        - If the model is set to be in evaluation mode and the full model must be frozen,
          we freeze the model.

        - If the model must be initialized from a checkpoint or user passed weights file
          we initialize the model from the checkpoint or the weights.
        """
        logging.info("Building model....")

        # Instantiate the raw model as specified
        model = build_model(self.config["MODEL"], self.config["OPTIMIZER"])

        # Convert the BatchNorm layers to SyncBatchNorm if needed
        # Both Apex and Pytorch SyncBatchNorms are GPU only
        if (
            self.config["MODEL"]["SYNC_BN_CONFIG"]["CONVERT_BN_TO_SYNC_BN"]
            and self.config["MACHINE"]["DEVICE"] == "gpu"
        ):
            model = convert_sync_bn(self.config, model)

        # Enforce eval mode, no matter what the prior transforms have done.
        # For instance apex converts batch-norms and sets `requires_grad` to True
        if self.config["MODEL"]["FEATURE_EVAL_SETTINGS"]["EVAL_MODE_ON"]:
            if self.config["MODEL"]["FEATURE_EVAL_SETTINGS"]["FREEZE_TRUNK_ONLY"]:
                logging.info(
                    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=True, "
                    "will freeze trunk..."
                )
                model.freeze_trunk()
            elif self.config["MODEL"]["FEATURE_EVAL_SETTINGS"]["FREEZE_TRUNK_AND_HEAD"]:
                logging.info(
                    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=True, will "
                    "freeze trunk and head..."
                )
                model.freeze_head_and_trunk()

        # assert that if the user set the PARAMS_FILE, it must exist and be valid.
        if (
            self.checkpoint_path is None
            and self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
        ):
            assert g_pathmgr.exists(
                self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
            ), "Specified PARAMS_FILE does NOT exist"
        # If we want to initialize the model in case of finetuning or evaluation,
        # we do it here. But we check that there is no checkpoint existing before
        # This is important in cases when the model training dies.
        if (
            self.checkpoint_path is None
            and self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
            and g_pathmgr.exists(self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"])
        ):
            model = self._restore_model_weights(model, strict=strict_load)

        return model

    def init_distributed_data_parallel_model(self):
        """
        This method overloads the ClassificationTask class's method from ClassyVision.
        """
        if not is_distributed_training_run():
            return

        for module in self.base_model.modules():
            if isinstance(module, FullyShardedDataParallel):
                raise ValueError(
                    "DistributedDataParallel should not be used"
                    "with a FullyShardedDataParallel model.\n"
                    "Please set config.TRAINER.TASK_NAME='self_supervision_fsdp_task'"
                )

        super().init_distributed_data_parallel_model()

    def set_epoch(
        self, phase_type: str, epoch: int, start_iter: int, train_phase_idx: int
    ):
        if hasattr(self.dataloaders[phase_type], "sampler"):
            sampler = self.dataloaders[phase_type].sampler
            # (Re-)Shuffle data: set epoch of distributed (or fairstore) sampler
            # Resume from the iteration if valid
            self.set_train_epoch_start_iter(sampler, epoch, start_iter, train_phase_idx)
            print_sampler_config(sampler)

        # call set_epoch and set_start_iter for AirstoreDataset since it handles
        # shuffle and sample skipping behavior internally
        dataset = self.datasets[phase_type]
        if hasattr(dataset, "data_objs"):
            for data_obj in dataset.data_objs:
                self.set_train_epoch_start_iter(
                    data_obj, epoch, start_iter, train_phase_idx
                )

    def set_train_epoch_start_iter(
        self, dataset_or_sampler, epoch: int, start_iter: int, train_phase_idx: int
    ):
        # (Re-)Shuffle data: set epoch of distributed (or fairstore) sampler
        if hasattr(dataset_or_sampler, "set_epoch"):
            dataset_or_sampler.set_epoch(epoch)
        # Resume from the iteration if valid
        if hasattr(dataset_or_sampler, "set_start_iter"):
            dataset_or_sampler.set_start_iter(start_iter)

        if hasattr(dataset_or_sampler, "set_train_phase_idx"):
            dataset_or_sampler.set_train_phase_idx(train_phase_idx)

    def num_phase_samples(self, phase_type: str) -> int:
        """
        Number of samples in a phase.
        """
        dataset = self.datasets[phase_type.lower()]
        return dataset.num_samples()

    def _compute_start_iter_from_checkpoint(self, phase_type) -> int:
        # used for calculating the start iteration (count from current epoch) when resuming
        # from checkpoint
        if self.checkpoint is None or self.checkpoint["iteration"] <= 0:
            return 0

        num_iters_in_epochs = len(self.dataloaders[phase_type])
        num_epochs = self.checkpoint["train_phase_idx"] + 1
        num_train_iters_done = num_epochs * num_iters_in_epochs
        return self.checkpoint["iteration"] - num_train_iters_done

    def recreate_data_iterator(
        self,
        phase_type: str,
        epoch: int,
        compute_start_iter: bool,
        train_phase_idx: int,
    ):
        """
        Recreate data iterator (including multiprocessing workers) and destroy the
        previous iterators.

        This is called when we load a new checkpoint or when phase changes during
        the training (one epoch to the next).
        DataSampler may need to be informed on those events to update the
        epoch and start_iteration so that the data is deterministically shuffled,
        so we call them here.
        """
        start_iter = 0
        if compute_start_iter:
            start_iter = self._compute_start_iter_from_checkpoint(phase_type)

        self.set_epoch(phase_type, epoch, start_iter, train_phase_idx)

        # Gives sampler same seed for entire distributed group as per pytorch documentation.
        sampler_seed = self.config["SEED_VALUE"]
        dataset = self.datasets[phase_type]

        # For OSS, this will always return false.
        # Otherwise, we will rebuild the dataloader after every phase.
        if dataset.rebuild_dataloader():
            dataloader = build_dataloader(
                dataset=dataset,
                dataset_config=self.config.DATA[phase_type.upper()],
                num_dataloader_workers=self.config.DATA.NUM_DATALOADER_WORKERS,
                pin_memory=self.config.DATA.PIN_MEMORY,
                multi_processing_method=self.config.MULTI_PROCESSING_METHOD,
                device=self.device,
                sampler_seed=sampler_seed,
                split=phase_type,
            )

            # delete old dataloader and reset it.
            del self.dataloaders[phase_type]
            gc.collect()
            self.dataloaders[phase_type] = dataloader

        # delete old dataiterator and reset it.
        del self.data_iterator
        gc.collect()
        self.data_iterator = iter(self.dataloaders[phase_type])

    def _set_classy_state(self, state):
        """
        We load/set the model state setting here to resume correctly from the
        specified state. Usually called when resuming training from a previous
        model checkpoint.
        We set the model phase (train or eval), model weights,
        copy the model to the correct device, initialize meters, initialize optimizers,
        initialize amp state, set loss state, set the train phase number, iteration,
        recreate data iterators, etc.
        """
        logging.info("=======Updating classy state_dict from checkpoint=======")
        # here we load the state specific things only. The other extra variables
        # are init from the checkpoint in the trainer step.
        self.train = state["train"]
        self.base_model.set_classy_state(state["base_model"])
        # We need to set the model on correct device here unlike in the case of
        # training from scratch. The optimizer looks at the model parameters like
        # momentum etc. for getting the device info. Since in case of scratch
        # training, we don't have those and the optimizer just gets the inputs
        # as cuda inputs from the model, it can work. However, when we load from
        # a checkpoint, we already have these parameters and the type is CPU
        # (since the model isn't copied to gpu yet). The copy_model_to_gpu()
        # doesn't modify the device of the optimizer params. The optimizer is
        # constructed with CPU inputs, whereas the model sends CUDA tensors at runtime.
        self.base_model.to(self.device)

        self._set_ema_model_state(state)

        for meter, meter_state in zip(self.meters, state["meters"]):
            meter.set_classy_state(meter_state)
        self.optimizer.set_classy_state(state["optimizer"])

        # restore amp state. It's called after amp.initialize is done.
        if "amp" in state:
            if self.amp_type == AmpType.APEX:
                if is_apex_available():
                    apex.amp.load_state_dict(state["amp"])
                else:
                    logging.warning(
                        "Loading a checkpoint which has amp state but apex isn't available now"
                    )
            else:
                self.amp_grad_scaler.load_state_dict(state["amp"])
        self.phase_idx = state["phase_idx"]
        self.train_phase_idx = state["train_phase_idx"]
        self.num_updates = state["num_updates"]
        self.losses = state["losses"]

        phase_type = "train" if self.train else "test"
        phase = self.phases[self.phase_idx]

        # Re-create the data iterator.
        # We are restoring from a checkpoint, which means we need to
        #   (1) set the right epoch
        #   (2) set the right start_iter
        # epoch number is `phase_idx + 1` since checkpoint's value is the epoch finished.
        # start_iter is computed in recreate_data_iterator based on iteration
        # number from the checkpoint state.
        self.recreate_data_iterator(
            phase_type,
            epoch=self.phase_idx + 1,
            compute_start_iter=True,
            train_phase_idx=self.train_phase_idx + 1,
        )

        # set the model to train or eval depending on what phase we are in
        self.base_model.train(phase["train"])

        if self.train and self.train_phase_idx >= 0:
            self.optimizer.on_epoch(self.where)

    def _set_ema_model_state(self, state):
        """
        Only used if EmaMetersHook is enabled.
        """
        if self.ema_model is not None:
            logging.info("Loading ema model")
            self.ema_model.module.set_classy_state(state["ema_model"])
            for meter, meter_state in zip(self.ema_meters, state["ema_meters"]):
                meter.set_classy_state(meter_state)

    def _update_classy_state(self, state_dict=None):
        """
        Updates classy state with the provided state dict from a checkpoint.
        state_dict = checkpoint loaded state
        """
        if state_dict is not None:
            try:
                self._set_classy_state(state_dict)
                success = True
            except Exception as e:
                logging.exception(f"Could not load the checkpoint: {e}")
                success = False
            assert success, "Update classy state from checkpoint failed."
        return self

    def _set_ddp_options(self):
        """
        set DDP options if the user has supplied them
        """
        broadcast_buffers = self.config["DISTRIBUTED"]["BROADCAST_BUFFERS"]
        if broadcast_buffers:
            logging.info(
                "Broadcast model BN buffers from primary on every forward pass"
            )
            broadcast_buffers_enum_mode = BroadcastBuffersMode.FORWARD_PASS
            self.set_distributed_options(
                broadcast_buffers_mode=broadcast_buffers_enum_mode
            )  # NOQA

    def run_hooks(self, hook_function_name: str):
        """
        Override the ClassyTask run_hook function and run the hooks whenever called
        """
        for hook in self.hooks:
            getattr(hook, hook_function_name, ClassyHook._noop)(self)

    def prepare_optimizer(self):
        """
        Constructs the optimizer using the user defined settings in the yaml config.
        The model must be on the correct device (cuda or cpu) by this point.
        """
        param_groups = get_optimizer_param_groups(
            model=self.base_model,
            model_config=self.config["MODEL"],
            optimizer_config=self.config["OPTIMIZER"],
            optimizer_schedulers=self.optimizer_schedulers,
        )
        self.optimizer.set_param_groups(param_groups)

    def prepare(self, pin_memory: bool = False):
        """
        Prepares the task:
        - dataloaders
        - model
        - copy model to correct device
        - meters
        - loss
        - optimizer
        - LR schedulers
        - AMP state
        - resume from a checkpoint if available
        """
        self.phases = self._build_phases()
        self.num_phases = len(self.phases)
        self.base_model = self._build_model()
        self._set_ddp_options()
        self.meters = self._build_meters()
        self.optimizer = self._build_optimizer()
        self.optimizer_schedulers = self._build_optimizer_schedulers()

        if self.device.type == "cuda":
            self.base_model = copy_model_to_gpu(self.base_model)

        # initialize the pytorch optimizer now since the model has been moved to
        # the appropriate device.
        self.prepare_optimizer()

        # Enable mixed precision grad scalers
        if self.amp_type == AmpType.APEX:
            # Allow Apex Amp to perform casts as specified by the amp_args.
            # This updates the model and the PyTorch optimizer (which is wrapped
            # by the ClassyOptimizer in self.optimizer).
            # NOTE: this must happen before loading the checkpoint. See
            # https://nvidia.github.io/apex/amp.html#checkpointing for more details.
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, **self.amp_args
            )

        # Create EMA average of the model if hook is specified.
        ema_config = self.config["HOOKS"]["EMA_MODEL"]
        if ema_config["ENABLE_EMA_METERS"] or ema_config["SAVE_EMA_MODEL"]:
            self._create_ema_model()

        # Restore a checkpoint, if one is available
        vissl_state_dict = None
        if self.checkpoint_path is not None:
            self.checkpoint = CheckpointLoader.load_and_broadcast_checkpoint(
                checkpoint_folder=self.checkpoint_folder,
                checkpoint_path=self.checkpoint_path,
                device=torch.device("cpu"),
            )
            if self.checkpoint is not None:
                self.iteration = self.checkpoint["iteration"]
                self.local_iteration_num = self.checkpoint["iteration_num"]
                vissl_state_dict = self.checkpoint.get("classy_state_dict")
            else:
                raise ValueError(f"Could not load checkpoint: {self.checkpoint_path}")

        current_train_phase_idx = (
            vissl_state_dict["train_phase_idx"] + 1 if vissl_state_dict else 0
        )

        self.datasets, self.data_and_label_keys = self.build_datasets(
            current_train_phase_idx
        )

        # set dataset state before building dataloader, in order to capture checkpoint info.
        if vissl_state_dict and "train" in self.datasets:
            self.datasets["train"].set_classy_state(
                vissl_state_dict.get("train_dataset_iterator")
            )

        self.dataloaders = self.build_dataloaders(
            pin_memory=pin_memory, current_train_phase_idx=current_train_phase_idx
        )

        # Build base loss, move to device, and load from checkpoint if applicable
        self.base_loss = self._build_loss()
        self.base_loss = self.base_loss.to(self.device)
        if self.checkpoint and "loss" in self.checkpoint:
            self.base_loss.load_state_dict(self.checkpoint["loss"])
            logging.info("======Loaded loss state from checkpoint======")

        return self._update_classy_state(vissl_state_dict)

    def prepare_extraction(self, pin_memory: bool = False):
        """
        Prepares a light-weight task for feature extraction on multi-gpu. The model
        runs in eval mode only.
        """
        self.datasets, self.data_and_label_keys = self.build_datasets()
        self.dataloaders = self.build_dataloaders(pin_memory=pin_memory)
        # build the meters in case the extraction is for predictions.
        self.meters = self._build_meters()
        self.base_model = self._build_model(strict_load=True)
        if self.device.type == "cuda":
            self.base_model = copy_model_to_gpu(self.base_model)
        return self

    def add_dummy_layer(self):
        """
        In feature evaluation mode, if both the trunk and the head are frozen,
        DDP won't work because the model has no trainable parameters. Adding a
        dummy head would corrupt the extracted features, so we add a dummy
        layer to the model instead and use DDP. The model is copied to the GPU
        (if GPUs are used) after the dummy layer is added.
        """
        fully_frozen_model = self.base_model.is_fully_frozen_model()
        if fully_frozen_model:
            self.base_model.dummy_layer = torch.nn.Linear(4, 4)
            if self.device.type == "cuda":
                self.base_model = copy_model_to_gpu(self.base_model)

    def _create_ema_model(self):
        logging.info("Building the EMA model.")
        ema_model = build_model(self.config["MODEL"], self.config["OPTIMIZER"])
        self.ema_model = ModelEmaV2(
            ema_model,
            decay=self.config["HOOKS"]["EMA_MODEL"]["DECAY"],
            device=self.config["HOOKS"]["EMA_MODEL"]["EMA_DEVICE"],
        )
        self.ema_model.set(self.base_model)
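
The example above restores the AMP state only after the model and the optimizer have been rebuilt (and, for Apex, after apex.amp.initialize has wrapped them), and it moves the model to the target device before the optimizer state is loaded. A minimal sketch of the same save/restore ordering with the native torch.cuda.amp.GradScaler follows; save_checkpoint, load_checkpoint, and their arguments are illustrative placeholders, not part of VISSL's API.

import torch
from torch.cuda.amp import GradScaler

def save_checkpoint(path, model, optimizer, scaler, epoch):
    # The scaler is stateful (current scale, growth tracker), so it is
    # saved next to the model and optimizer state.
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
    }, path)

def load_checkpoint(path, model, optimizer, scaler, device):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    # Move the model to the target device *before* the optimizer state is
    # used, mirroring the device note in the example above.
    model.to(device)
    optimizer.load_state_dict(ckpt["optimizer"])
    if ckpt.get("scaler") is not None:
        scaler.load_state_dict(ckpt["scaler"])
    return ckpt["epoch"]
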
Example #9
class DeepvacTrain(Deepvac):
    def __init__(self, deepvac_config):
        super(DeepvacTrain, self).__init__(deepvac_config)
        self.initTrainParameters()
        self.initTrainContext()

    def setTrainContext(self):
        self.is_train = True
        self.is_val = False
        self.phase = 'TRAIN'
        self.dataset = self.train_dataset
        self.loader = self.train_loader
        self.batch_size = self.conf.train.batch_size
        self.net.train()
        if self.qat_net_prepared:
            self.qat_net_prepared.train()

    def setValContext(self):
        self.is_train = False
        self.is_val = True
        self.phase = 'VAL'
        self.dataset = self.val_dataset
        self.loader = self.val_loader
        self.batch_size = self.conf.val.batch_size
        self.net.eval()
        if self.qat_net_prepared:
            self.qat_net_prepared.eval()

    def initTrainContext(self):
        self.scheduler = None
        self.initOutputDir()
        self.initSummaryWriter()
        self.initCriterion()
        self.initOptimizer()
        self.initScheduler()
        self.initCheckpoint()
        self.initTrainLoader()
        self.initValLoader()

    def initTrainParameters(self):
        self.dataset = None
        self.loader = None
        self.target = None
        self.epoch = 0
        self.step = 0
        self.iter = 0
        # Creates a GradScaler once at the beginning of training.
        self.scaler = GradScaler()
        self.train_time = AverageMeter()
        self.load_data_time = AverageMeter()
        self.data_cpu2gpu_time = AverageMeter()
        self._mandatory_member_name = [
            'train_dataset', 'val_dataset', 'train_loader', 'val_loader',
            'net', 'criterion', 'optimizer'
        ]

    def initOutputDir(self):
        if self.conf.output_dir != 'output' and self.conf.output_dir != './output':
            LOG.logW(
                "According to the deepvac standard, you should save model files to the [output] directory."
            )

        self.output_dir = '{}/{}'.format(self.conf.output_dir, self.branch)
        LOG.logI('model save dir: {}'.format(self.output_dir))
        #for DDP race condition
        os.makedirs(self.output_dir, exist_ok=True)

    def initSummaryWriter(self):
        event_dir = "{}/{}".format(self.conf.log_dir, self.branch)
        self.writer = SummaryWriter(event_dir)
        if not self.conf.tensorboard_port:
            return
        from tensorboard import program
        tensorboard = program.TensorBoard()
        self.conf.tensorboard_ip = '0.0.0.0' if self.conf.tensorboard_ip is None else self.conf.tensorboard_ip
        tensorboard.configure(argv=[
            None, '--host',
            str(self.conf.tensorboard_ip), '--logdir', event_dir, "--port",
            str(self.conf.tensorboard_port)
        ])
        try:
            url = tensorboard.launch()
            LOG.logI('Tensorboard at {} '.format(url))
        except Exception as e:
            LOG.logE(str(e))

    def initCriterion(self):
        self.criterion = torch.nn.CrossEntropyLoss()
        LOG.logW(
            "You should reimplement initCriterion() to initialize self.criterion, unless CrossEntropyLoss() is exactly what you need"
        )

    def initCheckpoint(self):
        if not self.conf.checkpoint_suffix or self.conf.checkpoint_suffix == "":
            LOG.logI('Skipping checkpoint loading since no checkpoint_suffix was specified...')
            return
        LOG.logI('Load checkpoint from {} folder'.format(self.output_dir))
        self.net.load_state_dict(
            torch.load(self.output_dir +
                       '/model__{}'.format(self.conf.checkpoint_suffix),
                       map_location=self.device))
        state_dict = torch.load(
            self.output_dir +
            '/checkpoint__{}'.format(self.conf.checkpoint_suffix),
            map_location=self.device)
        self.optimizer.load_state_dict(state_dict['optimizer'])
        if self.scheduler:
            self.scheduler.load_state_dict(state_dict['scheduler'])
        if self.conf.amp:
            LOG.logI(
                "Will load scaler from checkpoint since you enabled amp, make sure the checkpoint was saved with amp enabled."
            )
            try:
                self.scaler.load_state_dict(state_dict["scaler"])
            except Exception:
                LOG.logI(
                    "The checkpoint was saved without amp enabled, so a fresh GradScaler will be used instead."
                )
                self.scaler = GradScaler()

        self.epoch = state_dict['epoch']

    def initScheduler(self):
        if isinstance(self.conf.lr_step, list):
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer, self.conf.lr_step, self.conf.lr_factor)
        elif isinstance(self.conf.lr_step, FunctionType):
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=self.conf.lr_step)
        else:
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, self.conf.lr_step, self.conf.lr_factor)
        LOG.logW(
            "You should reimplement initScheduler() to initialize self.scheduler, unless lr_scheduler.StepLR() or lr_scheduler.MultiStepLR() is exactly what you need"
        )

    def initTrainLoader(self):
        self.train_loader = None
        LOG.logE(
            "You must reimplement initTrainLoader() to initialize self.train_loader",
            exit=True)

    def initValLoader(self):
        self.val_loader = None
        LOG.logE(
            "You must reimplement initTrainLoader() to initialize self.val_loader",
            exit=True)

    def initOptimizer(self):
        self.initSgdOptimizer()
        LOG.logW(
            "You should reimplement initOptimizer() to initialize self.optimizer, unless SGD is exactly what you need"
        )

    def initSgdOptimizer(self):
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.conf.lr,
                                   momentum=self.conf.momentum,
                                   weight_decay=self.conf.weight_decay,
                                   nesterov=self.conf.nesterov)

    def initAdamOptimizer(self):
        self.optimizer = optim.Adam(
            self.net.parameters(),
            lr=self.conf.lr,
        )
        for group in self.optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])

    def initRmspropOptimizer(self):
        self.optimizer = optim.RMSprop(
            self.net.parameters(),
            lr=self.conf.lr,
            momentum=self.conf.momentum,
            weight_decay=self.conf.weight_decay,
            # alpha=self.conf.rmsprop_alpha,
            # centered=self.conf.rmsprop_centered
        )

    def addScalar(self, tag, value, step):
        self.writer.add_scalar(tag, value, step)

    def addImage(self, tag, image, step):
        self.writer.add_image(tag, image, step)

    @syszux_once
    def addGraph(self, input):
        self.writer.add_graph(self.net, input)

    @syszux_once
    def smokeTestForExport3rd(self):
        #exportONNX must come before exportNCNN
        self.exportONNX()
        self.exportNCNN()
        self.exportCoreML()
        #export TorchScript via trace; only here do we have access to self.sample
        self.exportTorchViaTrace()
        #compile pytorch state dict to TorchScript
        self.exportTorchViaScript()
        self.exportDynamicQuant()
        self.exportStaticQuant(prepare=True)

    def earlyIter(self):
        start = time.time()
        self.sample = self.sample.to(self.device)
        self.target = self.target.to(self.device)
        if not self.is_train:
            return
        self.data_cpu2gpu_time.update(time.time() - start)
        try:
            self.addGraph(self.sample)
        except:
            LOG.logW(
                "Tensorboard addGraph failed. You network foward may have more than one parameters?"
            )
            LOG.logW("Seems you need reimplement preIter function.")

    def preIter(self):
        pass

    def postIter(self):
        pass

    def preEpoch(self):
        pass

    def postEpoch(self):
        pass

    def doForward(self):
        self.output = self.net(self.sample)

    def doCalibrate(self):
        if self.static_quantized_net_prepared is None:
            return
        self.static_quantized_net_prepared(self.sample)

    def doLoss(self):
        self.loss = self.criterion(self.output, self.target)

    def doBackward(self):
        if self.conf.amp:
            self.scaler.scale(self.loss).backward()
        else:
            self.loss.backward()

    def doOptimize(self):
        if self.iter % self.conf.nominal_batch_factor != 0:
            return
        if self.conf.amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.optimizer.zero_grad()

    def doLog(self):
        if self.step % self.conf.log_every != 0:
            return
        self.addScalar('{}/Loss'.format(self.phase), self.loss.item(),
                       self.iter)
        self.addScalar('{}/LoadDataTime(secs/batch)'.format(self.phase),
                       self.load_data_time.val, self.iter)
        self.addScalar('{}/DataCpu2GpuTime(secs/batch)'.format(self.phase),
                       self.data_cpu2gpu_time.val, self.iter)
        self.addScalar('{}/TrainTime(secs/batch)'.format(self.phase),
                       self.train_time.val, self.iter)
        LOG.logI('{}: [{}][{}/{}] [Loss:{}  Lr:{}]'.format(
            self.phase, self.epoch, self.step, self.loader_len,
            self.loss.item(), self.optimizer.param_groups[0]['lr']))

    def saveState(self, current_time):
        file_partial_name = '{}__acc_{}__epoch_{}__step_{}__lr_{}'.format(
            current_time, self.accuracy, self.epoch, self.step,
            self.optimizer.param_groups[0]['lr'])
        state_file = '{}/model__{}.pth'.format(self.output_dir,
                                               file_partial_name)
        checkpoint_file = '{}/checkpoint__{}.pth'.format(
            self.output_dir, file_partial_name)
        output_trace_file = '{}/trace__{}.pt'.format(self.output_dir,
                                                     file_partial_name)
        output_script_file = '{}/script__{}.pt'.format(self.output_dir,
                                                       file_partial_name)
        output_onnx_file = '{}/onnx__{}.onnx'.format(self.output_dir,
                                                     file_partial_name)
        output_ncnn_file = '{}/ncnn__{}.bin'.format(self.output_dir,
                                                    file_partial_name)
        output_coreml_file = '{}/coreml__{}.mlmodel'.format(
            self.output_dir, file_partial_name)
        output_dynamic_quant_file = '{}/dquant__{}.pt'.format(
            self.output_dir, file_partial_name)
        output_static_quant_file = '{}/squant__{}.pt'.format(
            self.output_dir, file_partial_name)
        output_qat_file = '{}/qat__{}.pt'.format(self.output_dir,
                                                 file_partial_name)
        #save state_dict
        torch.save(self.net.state_dict(), state_file)
        #save checkpoint
        torch.save(
            {
                'optimizer': self.optimizer.state_dict(),
                'epoch': self.epoch,
                'scheduler':
                self.scheduler.state_dict() if self.scheduler else None,
                'scaler': self.scaler.state_dict() if self.conf.amp else None
            }, checkpoint_file)

        #convert for quantize, must before trace and script!!!
        self.exportDynamicQuant(output_dynamic_quant_file)
        self.exportStaticQuant(output_quant_file=output_static_quant_file)
        self.exportQAT(output_quant_file=output_qat_file)
        #save pt via trace
        self.exportTorchViaTrace(self.sample, output_trace_file)
        #save pt via script
        self.exportTorchViaScript(output_script_file)
        #save onnx
        self.exportONNX(output_onnx_file)
        #save ncnn
        self.exportNCNN(output_ncnn_file)
        #save coreml
        self.exportCoreML(output_coreml_file)
        #tensorboard
        self.addScalar('{}/Accuracy'.format(self.phase), self.accuracy,
                       self.iter)

    def processTrain(self):
        self.setTrainContext()
        self.step = 0
        LOG.logI('Phase {} started...'.format(self.phase))
        self.loader_len = len(self.loader)
        save_every = self.loader_len // self.conf.save_num
        save_list = list(range(0, self.loader_len + 1, save_every))
        self.save_list = save_list[1:-1]
        LOG.logI('Model will be saved at steps {} and at the end of the epoch.'.format(
            self.save_list))
        self.addScalar('{}/LR'.format(self.phase),
                       self.optimizer.param_groups[0]['lr'], self.epoch)
        self.preEpoch()
        self.train_time.reset()
        self.load_data_time.reset()
        self.data_cpu2gpu_time.reset()

        start = time.time()
        for i, (sample, target) in enumerate(self.loader):
            self.load_data_time.update(time.time() - start)
            self.step = i
            self.target = target
            self.sample = sample
            self.preIter()
            self.earlyIter()
            with autocast(enabled=bool(self.conf.amp)):
                self.doForward()
                self.doLoss()
            self.doBackward()
            self.doOptimize()
            self.doLog()
            self.postIter()
            self.iter += 1
            self.train_time.update(time.time() - start)
            if self.step in self.save_list:
                self.processVal()
                self.setTrainContext()
            start = time.time()

        self.addScalar('{}/TrainTime(hours/epoch)'.format(self.phase),
                       round(self.train_time.sum / 3600, 2), self.epoch)
        self.addScalar(
            '{}/AverageBatchTrainTime(secs/epoch)'.format(self.phase),
            self.train_time.avg, self.epoch)
        self.addScalar(
            '{}/AverageBatchLoadDataTime(secs/epoch)'.format(self.phase),
            self.load_data_time.avg, self.epoch)
        self.addScalar(
            '{}/AverageBatchDataCpu2GpuTime(secs/epoch)'.format(self.phase),
            self.data_cpu2gpu_time.avg, self.epoch)

        self.postEpoch()
        if self.scheduler:
            self.scheduler.step()

    def processVal(self, smoke=False):
        self.setValContext()
        LOG.logI('Phase {} started...'.format(self.phase))
        with torch.no_grad():
            self.preEpoch()
            for i, (sample, target) in enumerate(self.loader):
                self.target = target
                self.sample = sample
                self.preIter()
                self.earlyIter()
                self.doForward()
                #calibrate only for quantization.
                self.doCalibrate()
                self.doLoss()
                self.smokeTestForExport3rd()
                LOG.logI('{}: [{}][{}/{}]'.format(self.phase, self.epoch, i,
                                                  len(self.loader)))
                self.postIter()
                if smoke:
                    break
            self.postEpoch()
        self.saveState(self.getTime())

    def processAccept(self):
        self.setValContext()

    def process(self):
        self.auditConfig()
        self.iter = 0
        epoch_start = self.epoch
        self.processVal(smoke=True)
        self.optimizer.zero_grad()
        for epoch in range(epoch_start, self.conf.epoch_num):
            self.epoch = epoch
            LOG.logI('Epoch {} started...'.format(self.epoch))
            self.processTrain()
            self.processVal()
            self.processAccept()

    def __call__(self):
        self.process()
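
DeepvacTrain above creates a single GradScaler for the whole run, runs the forward pass and loss under autocast, calls scaler.scale(loss).backward() on every iteration, and only steps the optimizer every nominal_batch_factor iterations. A minimal standalone sketch of that autocast / scale / step pattern with gradient accumulation follows; net, loader, criterion, and optimizer are placeholders, and dividing the loss by accumulation_steps is an optional normalization that DeepvacTrain itself does not apply.

import torch
from torch.cuda.amp import GradScaler, autocast

def train_amp_with_accumulation(net, loader, criterion, optimizer,
                                accumulation_steps=4, device="cuda"):
    # One GradScaler for the whole run, mirroring initTrainParameters() above.
    scaler = GradScaler()
    net.train()
    optimizer.zero_grad()
    for step, (sample, target) in enumerate(loader):
        sample, target = sample.to(device), target.to(device)
        with autocast():
            output = net(sample)
            # Divide so the accumulated gradient matches a large-batch step.
            loss = criterion(output, target) / accumulation_steps
        scaler.scale(loss).backward()
        if (step + 1) % accumulation_steps == 0:
            scaler.step(optimizer)   # unscales and skips the step on inf/NaN
            scaler.update()
            optimizer.zero_grad()
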
Example #10
class CustomMTSAC(MTSAC):
    def __init__(
        self,
        policy,
        qf1,
        qf2,
        replay_buffer,
        env_spec,
        sampler,
        train_task_sampler,
        *,
        num_tasks,
        gradient_steps_per_itr,
        task_update_frequency=1,
        max_episode_length_eval=None,
        fixed_alpha=None,
        target_entropy=None,
        initial_log_entropy=0.,
        discount=0.99,
        buffer_batch_size=64,
        min_buffer_size=10000,
        target_update_tau=5e-3,
        policy_lr=3e-4,
        qf_lr=3e-4,
        reward_scale=1.0,
        optimizer=torch.optim.Adam,
        num_evaluation_episodes=5,
        # added
        fp16=False,
        log_per_task=False,
        share_train_eval_env=False
    ):

        super().__init__(
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            replay_buffer=replay_buffer,
            env_spec=env_spec,
            sampler=sampler,
            test_sampler=sampler,  # not used, for compatibility
            train_task_sampler=train_task_sampler,
            num_tasks=num_tasks,
            gradient_steps_per_itr=gradient_steps_per_itr,
            max_episode_length_eval=max_episode_length_eval,
            fixed_alpha=fixed_alpha,
            target_entropy=target_entropy,
            initial_log_entropy=initial_log_entropy,
            discount=discount,
            buffer_batch_size=buffer_batch_size,
            min_buffer_size=min_buffer_size,
            target_update_tau=target_update_tau,
            policy_lr=policy_lr,
            qf_lr=qf_lr,
            reward_scale=reward_scale,
            optimizer=optimizer,
            steps_per_epoch=1,
            num_evaluation_episodes=num_evaluation_episodes,
        )
        self._train_task_sampler = train_task_sampler
        self._task_update_frequency = task_update_frequency
        self._fp16 = fp16
        self._log_per_task = log_per_task
        self._total_envsteps = 0

        # scalers for fp16
        # TODO: don't initialize gradscalers if not using fp16
        # Also don't save and/or restore
        self._gs_qf1 = GradScaler()
        self._gs_qf2 = GradScaler()
        self._gs_policy = GradScaler()
        self._gs_alpha = GradScaler()

        # get updates for evaluation
        self.eval_env_updates = self.resample_environment(force_update=True)
        self.share_train_eval_env = share_train_eval_env
        if self.share_train_eval_env:
            logging.warn("WARNING: Sharing train and eval environments")

        # Fix a bug with the alpha optimizer
        self._use_automatic_entropy_tuning = fixed_alpha is None
        if self._use_automatic_entropy_tuning:
            self._alpha_optimizer = optimizer([self._log_alpha], lr=self._policy_lr)

    def state_dict(self):
        return {
            # parameters
            "policy": self.policy.state_dict(),
            "qf1": self._qf1.state_dict(),
            "qf2": self._qf2.state_dict(),
            "target_qf1": self._target_qf1.state_dict(),
            "target_qf2": self._target_qf2.state_dict(),
            "log_alpha": self._log_alpha,

            # scalers
            "gs_qf1": self._gs_qf1.state_dict(),
            "gs_qf2": self._gs_qf2.state_dict(),
            "gs_policy": self._gs_policy.state_dict(),
            "gs_alpha": self._gs_alpha.state_dict(),

            # optimizers
            "policy_optimizer": self._policy_optimizer.state_dict(),
            "qf1_optimizer": self._qf1_optimizer.state_dict(),
            "qf2_optimizer": self._qf2_optimizer.state_dict(),
            "alpha_optimizer": self._alpha_optimizer.state_dict(),

            # other variables
            "replay_buffer": self.replay_buffer,
            "total_envsteps": self._total_envsteps,
        }

    def load_env_state(self, env_state):
        self.eval_env_updates = env_state

    def load_state(self, state):
        # parameters
        self.policy.load_state_dict(state["policy"])
        self._qf1.load_state_dict(state["qf1"])
        self._qf2.load_state_dict(state["qf2"])
        self._target_qf1.load_state_dict(state["target_qf1"])
        self._target_qf2.load_state_dict(state["target_qf2"])
        self._log_alpha.data = state["log_alpha"]

        # scalers
        self._gs_qf1.load_state_dict(state["gs_qf1"])
        self._gs_qf2.load_state_dict(state["gs_qf2"])
        self._gs_policy.load_state_dict(state["gs_policy"])
        self._gs_alpha.load_state_dict(state["gs_alpha"])

        # optimizers
        self._policy_optimizer.load_state_dict(state["policy_optimizer"])
        self._qf1_optimizer.load_state_dict(state["qf1_optimizer"])
        self._qf2_optimizer.load_state_dict(state["qf2_optimizer"])
        self._alpha_optimizer.load_state_dict(state["alpha_optimizer"])

        # other variables
        self.replay_buffer = state["replay_buffer"]
        self._total_envsteps = state["total_envsteps"]

    def get_updated_policy(self, policy_hook=None):
        with torch.no_grad():
            updated_policy = copy.deepcopy(self.policy)
        updated_policy.eval()
        # attach hooks
        if policy_hook:
            policy_hook(updated_policy)

        return updated_policy

    def update_buffer(self, trajectories):
        """Update Buffer"""

        self._total_envsteps += sum(trajectories.lengths)
        path_returns = []
        for path in trajectories.to_list():
            self.replay_buffer.add_path(dict(
                observation=path["observations"],
                action=path["actions"],
                reward=path["rewards"].reshape(-1, 1),
                next_observation=path["next_observations"],
                terminal=np.array([
                    step_type == StepType.TERMINAL
                    for step_type in path["step_types"]
                ]).reshape(-1, 1)
            ))
            path_returns.append(sum(path["rewards"]))

        self.episode_rewards.append(np.mean(path_returns))

    def resample_environment(self, epoch=0, force_update=False):
        """
        TODO: fix env update in sampler

        Intended behavior:
        if epoch % self._task_update_frequency == 0 or force_update:
            return self._train_task_sampler.sample(self._num_tasks)
        """
        # TODO: remove first line to allow force update
        if epoch % self._task_update_frequency == 0 or force_update:
            return self._train_task_sampler.sample(self._num_tasks)

    def run_epoch(self, epoch, env_steps_per_epoch):
        """
        Run one epoch, which is composed of N sample collections and N training
        steps. Each training step is in turn composed of M gradient steps of
        batch size B.

        The total number of samples used by the algorithm in an epoch is
        N * M * B (steps * gradient_steps * batch_size).

        Collected samples are only used to update the buffer; they have no
        direct influence on the number of gradient steps or the batch size.

        Returns:
            float: The average return in last epoch cycle.

        """
        t0 = time()

        env_updates = (
            self.eval_env_updates if self.share_train_eval_env
            else self.resample_environment(epoch)
        )

        new_trajectories = self._sampler.obtain_samples(
            num_samples=env_steps_per_epoch,
            agent_update=self.get_updated_policy(),
            env_updates=env_updates,
        )
        self.update_buffer(new_trajectories)
        t1 = time()
        total_losses = self.run_step()
        time_to_collect_samples = t1 - t0
        time_to_update_gradient = time() - t1

        log_dict = self._log_statistics(*total_losses)

        # TODO: switch to logger.debug once logger is fixed
        logging.warn(f"Time to collect samples: {time_to_collect_samples:.2f}")
        logging.warn(f"Time to update gradient: {time_to_update_gradient:.2f}")

        return log_dict

    def run_step(self):
        """
        Run one training step, which is composed of M gradient steps

        For M gradient steps:
        - sample a batch from the buffer
        - perform one gradient step on all three networks (policy, qf1 and qf2)
        """

        total_losses = [0, 0, 0]
        for _ in range(self._gradient_steps):
            if self.replay_buffer.n_transitions_stored >= self._min_buffer_size:
                samples = as_torch_dict(self.replay_buffer.sample_transitions(
                    self._buffer_batch_size
                ))
                policy_loss, qf1_loss, qf2_loss = self.optimize_policy(samples)
                total_losses[0] += policy_loss
                total_losses[1] += qf1_loss
                total_losses[2] += qf2_loss
                self._update_targets()

        # Normalize losses by total of gradient updates
        total_losses = [loss / self._gradient_steps for loss in total_losses]

        return total_losses

    def _evaluate_policy(self, epoch, policy_hook=None):
        """Evaluate the performance of the policy via deterministic sampling.

            Statistics such as (average) discounted return and success rate are
            recorded.

        Args:
            epoch (int): The current training epoch.

        Returns:
            float: The average return across self._num_evaluation_episodes
                episodes

        """
        t0 = time()

        # Collect episodes for evaluation
        eval_trajectories, policy_hook_data = self._sampler.obtain_exact_episodes(
            n_eps_per_worker=self._num_evaluation_episodes,
            agent_update=self.get_updated_policy(policy_hook=policy_hook),
            env_updates=self.eval_env_updates,
        )

        # Log performance
        undiscounted_returns, log_dict = log_multitask_performance(
            epoch,
            batch=eval_trajectories,
            discount=self._discount,
            log_per_task=self._log_per_task
        )
        log_dict["average_return"] = np.mean(undiscounted_returns)

        logging.warn(f"Time to evaluate policy: {time()-t0:.2f}")

        return undiscounted_returns, log_dict, policy_hook_data

    def _log_statistics(self, policy_loss, qf1_loss, qf2_loss):
        """Record training statistics to dowel such as losses and returns.

        Args:
            policy_loss (torch.Tensor): loss from actor/policy network.
            qf1_loss (torch.Tensor): loss from 1st qf/critic network.
            qf2_loss (torch.Tensor): loss from 2nd qf/critic network.

        """
        log_dict = {}
        with torch.no_grad():
            log_dict["AlphaTemperature/mean"] = self._log_alpha.exp().mean().item()
        log_dict["Policy/Loss"] = policy_loss.item()
        log_dict["QF/{}".format("Qf1Loss")] = float(qf1_loss)
        log_dict["QF/{}".format("Qf2Loss")] = float(qf2_loss)
        log_dict["ReplayBuffer/buffer_size"] = self.replay_buffer.n_transitions_stored
        log_dict["Average/TrainAverageReturn"] = np.mean(self.episode_rewards)
        log_dict["TotalEnvSteps"] = self._total_envsteps

        return log_dict

    def _get_log_alpha(self, samples_data):
        """Return the value of log_alpha.
        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.
        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`
        Raises:
            ValueError: If the number of tasks, num_tasks passed to
                this algorithm doesn't match the length of the task
                one-hot id in the observation vector.
        Returns:
            torch.Tensor: log_alpha. shape is (1, self.buffer_batch_size)
        """
        obs = samples_data["observation"]
        log_alpha = self._log_alpha
        one_hots = obs[:, -self._num_tasks:]

        if (log_alpha.shape[0] != one_hots.shape[1]
                or one_hots.shape[1] != self._num_tasks
                or log_alpha.shape[0] != self._num_tasks):
            raise ValueError(
                "The number of tasks in the environment does "
                "not match self._num_tasks. Are you sure that you passed "
                "The correct number of tasks?")

        with autocast(enabled=self._fp16):
            return torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()

    def _temperature_objective(self, log_pi, samples_data):
        """Compute the temperature/alpha coefficient loss.
        Args:
            log_pi(torch.Tensor): log probability of actions that are sampled
                from the replay buffer. Shape is (1, buffer_batch_size).
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.
        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`
        Returns:
            torch.Tensor: the temperature/alpha coefficient loss.
        """
        alpha_loss = 0

        with autocast(enabled=self._fp16):
            if self._use_automatic_entropy_tuning:
                alpha_loss = (-(self._get_log_alpha(samples_data)) *
                            (log_pi.detach() + self._target_entropy)).mean()

            return alpha_loss

    def _actor_objective(self, samples_data, new_actions, log_pi_new_actions):
        """Compute the Policy/Actor loss.
        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.
            new_actions (torch.Tensor): Actions resampled from the policy
                based on the observations, obs, which were sampled from the
                replay buffer. Shape is (action_dim, buffer_batch_size).
            log_pi_new_actions (torch.Tensor): Log probability of the new
                actions on the TanhNormal distributions that they were sampled
                from. Shape is (1, buffer_batch_size).
        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`
        Returns:
            torch.Tensor: loss from the Policy/Actor.
        """
        obs = samples_data["observation"]

        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data).exp()

        with autocast(enabled=self._fp16):
            min_q_new_actions = torch.min(self._qf1(obs, new_actions),
                                          self._qf2(obs, new_actions))

            policy_objective = ((alpha * log_pi_new_actions) -
                                min_q_new_actions.flatten()).mean()

            return policy_objective

    def _critic_objective(self, samples_data):
        """Compute the Q-function/critic loss.
        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.
        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`
        Returns:
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.
        """
        obs = samples_data["observation"]
        actions = samples_data["action"]
        rewards = samples_data["reward"].flatten()
        terminals = samples_data["terminal"].flatten()
        next_obs = samples_data["next_observation"]

        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data).exp()

        with autocast(enabled=self._fp16):
            q1_pred = self._qf1(obs, actions)
            q2_pred = self._qf2(obs, actions)

            new_next_actions_dist = self.policy(next_obs)[0]
            new_next_actions_pre_tanh, new_next_actions = (
                new_next_actions_dist.rsample_with_pre_tanh_value())
            new_log_pi = new_next_actions_dist.log_prob(
                value=new_next_actions,
                pre_tanh_value=new_next_actions_pre_tanh
            )

            target_q_values = torch.min(
                self._target_qf1(next_obs, new_next_actions),
                self._target_qf2(next_obs, new_next_actions)
            ).flatten() - (alpha * new_log_pi)

            with torch.no_grad():
                q_target = rewards * self._reward_scale + (
                    1. - terminals) * self._discount * target_q_values

            qf1_loss = F.mse_loss(q1_pred.flatten(), q_target)
            qf2_loss = F.mse_loss(q2_pred.flatten(), q_target)

            return qf1_loss, qf2_loss

    def optimize_policy(self, samples_data):
        """Optimize the policy q_functions, and temperature coefficient. Rezero
        model weights (if applicable) after each optimizer step.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from actor/policy network after optimization.
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        if self._fp16:
            return self.optimize_policy_with_autocast(samples_data)

        obs = samples_data["observation"]
        qf1_loss, qf2_loss = self._critic_objective(samples_data)

        self._qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self._qf1_optimizer.step()
        self._qf1.apply(rezero_weights)

        self._qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self._qf2_optimizer.step()
        self._qf2.apply(rezero_weights)

        action_dists = self.policy(obs)[0]
        new_actions_pre_tanh, new_actions = (
            action_dists.rsample_with_pre_tanh_value())
        log_pi_new_actions = action_dists.log_prob(
            value=new_actions, pre_tanh_value=new_actions_pre_tanh)

        policy_loss = self._actor_objective(samples_data, new_actions,
                                            log_pi_new_actions)
        self._policy_optimizer.zero_grad()
        policy_loss.backward()

        self._policy_optimizer.step()
        self.policy.apply(rezero_weights)

        if self._use_automatic_entropy_tuning:
            alpha_loss = self._temperature_objective(log_pi_new_actions,
                                                     samples_data)
            self._alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self._alpha_optimizer.step()

        return policy_loss, qf1_loss, qf2_loss

    def optimize_policy_with_autocast(self, samples_data):
        """Optimize the policy q_functions, and temperature coefficient. Rezero
        model weights (if applicable) after each optimizer step.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from actor/policy network after optimization.
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        obs = samples_data["observation"]

        qf1_loss, qf2_loss = self._critic_objective(samples_data)

        self._qf1_optimizer.zero_grad()
        self._gs_qf1.scale(qf1_loss).backward()
        self._gs_qf1.step(self._qf1_optimizer)
        self._gs_qf1.update()
        self._qf1.apply(rezero_weights)

        self._qf2_optimizer.zero_grad()
        self._gs_qf2.scale(qf2_loss).backward()
        self._gs_qf2.step(self._qf2_optimizer)
        self._gs_qf2.update()
        self._qf2.apply(rezero_weights)

        with autocast():
            action_dists = self.policy(obs)[0]
            new_actions_pre_tanh, new_actions = (
                action_dists.rsample_with_pre_tanh_value()
            )
            log_pi_new_actions = action_dists.log_prob(
                value=new_actions, pre_tanh_value=new_actions_pre_tanh)

        policy_loss = self._actor_objective(samples_data, new_actions,
                                            log_pi_new_actions)

        self._policy_optimizer.zero_grad()
        self._gs_policy.scale(policy_loss).backward()
        self._gs_policy.step(self._policy_optimizer)
        self._gs_policy.update()
        self.policy.apply(rezero_weights)

        if self._use_automatic_entropy_tuning:
            alpha_loss = self._temperature_objective(log_pi_new_actions,
                                                     samples_data)

            self._alpha_optimizer.zero_grad()
            self._gs_alpha.scale(alpha_loss).backward()
            self._gs_alpha.step(self._alpha_optimizer)
            self._gs_alpha.update()

        return policy_loss, qf1_loss, qf2_loss

    def shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self._sampler.shutdown_worker()
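
CustomMTSAC above keeps a separate GradScaler for each loss it optimizes (qf1, qf2, policy, alpha), so each optimizer keeps its own loss scale. A minimal sketch of that one-scaler-per-optimizer pattern with two critic-style networks follows; qf1, qf2, opt1, opt2, and the MSE targets are illustrative placeholders, not part of the class above.

import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

def step_two_critics(qf1, qf2, opt1, opt2, scaler1, scaler2, batch, targets):
    # Forward passes under autocast; the two losses have independent graphs
    # because they only share the (non-differentiable) input batch.
    with autocast():
        loss1 = F.mse_loss(qf1(batch).flatten(), targets)
        loss2 = F.mse_loss(qf2(batch).flatten(), targets)

    # Each optimizer gets its own scaler, mirroring _gs_qf1 / _gs_qf2 above.
    opt1.zero_grad()
    scaler1.scale(loss1).backward()
    scaler1.step(opt1)
    scaler1.update()

    opt2.zero_grad()
    scaler2.scale(loss2).backward()
    scaler2.step(opt2)
    scaler2.update()
    return loss1.detach(), loss2.detach()
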
Example #11
def prepare_optimizers(args, model, checkpoint, global_steps):
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.lr_decay == 'poly':
        Scheduler = PolyWarmUpScheduler
    elif args.lr_decay == 'linear':
        Scheduler = LinearWarmUpScheduler
    else:
        raise ValueError('Unknown lr decay "{}"'.format(args.lr_decay))

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)

    if checkpoint is not None:
        if args.resume_step >= args.previous_phase_end_step:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_steps
            for i, item in enumerate(checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][i][
                    'step'] = global_steps
                checkpoint['optimizer']['param_groups'][i][
                    't_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][i][
                    'warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][i][
                    'lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])

    lr_schedulers = [
        Scheduler(optimizer,
                  warmup=args.warmup_proportion,
                  total_steps=args.max_steps)
    ]

    scaler = None
    if args.fp16:
        scaler = GradScaler()
        if checkpoint is not None and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])

    preconditioner = None
    if args.kfac:
        preconditioner = kfac.KFAC(
            model,
            lr=args.learning_rate,
            factor_decay=args.kfac_stat_decay,
            damping=args.kfac_damping,
            kl_clip=args.kfac_kl_clip,
            factor_update_freq=args.kfac_factor_interval,
            inv_update_freq=args.kfac_inv_interval,
            # Skip TrainingHeads, which contains the decoder, a Linear module
            # with shape (seq_len, vocab_size) that is too large to invert
            skip_layers=args.kfac_skip_layers,
            # BERT calls KFAC very infrequently so no need to optimize for
            # communication. Optimize for memory instead.
            comm_method=kfac.CommMethod.HYBRID_OPT,
            grad_worker_fraction=0.5,
            inv_dtype=torch.float16,
            # Compute the factors and update the running averages during the
            # forward/backward pass because we are using grad accumulation but
            # are not accumulating the input/output data
            accumulate_data=False,
            compute_factor_in_hook=True,
            distribute_layer_factors=False,
            grad_scaler=scaler,
        )

        lrs = Scheduler(preconditioner,
                        warmup=args.warmup_proportion,
                        total_steps=args.max_steps)
        lr_schedulers.append(lrs)

        if checkpoint is not None and 'preconditioner' in checkpoint:
            preconditioner.load_state_dict(checkpoint['preconditioner'])

        if is_main_process():
            logger.info(preconditioner)

    return optimizer, preconditioner, lr_schedulers, scaler
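
Besides the GradScaler handling, prepare_optimizers above splits the model parameters into a weight-decay group and a no-decay group (bias/LayerNorm-style parameters). A minimal sketch of that split follows; build_param_groups is an illustrative helper, and torch.optim.AdamW stands in for FusedLAMB.

import torch

def build_param_groups(model, weight_decay=0.01,
                       no_decay=("bias", "gamma", "beta", "LayerNorm")):
    # Parameters whose names contain any of the `no_decay` substrings are
    # excluded from weight decay, as in the grouping above.
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

# Usage with a plain optimizer (AdamW stands in for FusedLAMB here):
# optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-4)
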
Example #12
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)
    valloader = DataLoader(valdataset,
                           batch_size=args.batch_size,
                           shuffle=False,
                           num_workers=args.num_workers)

    scaler = GradScaler()

    if args.resume:
        ckpt = torch.load(os.path.join(args.save_dir, 'recorder_2.pt'))
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])
        scaler.load_state_dict(ckpt['scaler'])

    if args.resume:
        best_loss = scheduler.best
    else:
        best_loss = np.inf

    save_recorder = 5

    for epoch in range(args.epochs):

        print(f'Epoch {epoch+1}/{args.epochs}')

        train_loss, train_acc = train_one_epoch(trainloader, model, criterion,
                                                optimizer, scaler, device,
                                                args, epoch)
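
The body of train_one_epoch is not included in this fragment. A minimal sketch of what such a function could look like, matching the call above and using the GradScaler that the fragment creates and restores, follows; the loss/accuracy bookkeeping is illustrative, and args and epoch are accepted only to match the call signature.

import torch
from torch.cuda.amp import autocast

def train_one_epoch(loader, model, criterion, optimizer, scaler, device,
                    args, epoch):
    # Hypothetical body for the call above; args and epoch are unused here.
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * inputs.size(0)
        correct += (outputs.argmax(dim=1) == targets).sum().item()
        total += inputs.size(0)
    return running_loss / total, correct / total
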
Example #13
class Trainer:
    """Model trainer

    Args:
        model: model to train
        loss_fn: loss function
        optimizer: model optimizer
        generator: pretrained generator
        projector: pretrained projector
        device: device to train the model on
        batch_size: number of batch elements
        iterations: number of iterations
        scheduler: learning rate scheduler
        grad_clip_max_norm: gradient clipping max norm (disabled if None)
        writer: writer which logs metrics to TensorBoard (disabled if None)
        save_path: folder in which to save models (disabled if None)
        checkpoint_path: path to model checkpoint, to resume training
        mixed_precision: enable mixed precision training

    """
    def __init__(
        self,
        model: torch.nn.Module,
        loss_fn: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        generator: Generator,
        projector: torch.nn.Module,
        batch_size: int,
        iterations: int,
        device: torch.device,
        eval_freq: int = 1000,
        eval_iters: int = 100,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        grad_clip_max_norm: Optional[float] = None,
        writer: Optional[SummaryWriter] = None,
        save_path: Optional[str] = None,
        checkpoint_path: Optional[str] = None,
        mixed_precision: bool = False,
        train_projector: bool = True,
        feed_layers: Optional[List[int]] = None,
    ) -> None:

        # Logging
        self.logger = logging.getLogger()
        self.writer = writer

        # Saving
        self.save_path = save_path

        # Device
        self.device = device

        # Model
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.generator = generator
        self.projector = projector
        self.train_projector = train_projector
        self.feed_layers = feed_layers

        #  Eval
        self.eval_freq = eval_freq
        self.eval_iters = eval_iters

        # Scheduler
        self.scheduler = scheduler
        self.grad_clip_max_norm = grad_clip_max_norm

        # Batch & Iteration
        self.batch_size = batch_size
        self.iterations = iterations
        self.start_iteration = 0

        # Floating-point precision
        self.mixed_precision = mixed_precision and self.device.type == "cuda"
        self.scaler = GradScaler() if self.mixed_precision else None

        if checkpoint_path:
            self._load_from_checkpoint(checkpoint_path)

        # Metrics
        self.train_acc_metric = LossMetric()
        self.train_loss_metric = LossMetric()

        self.val_acc_metric = LossMetric()
        self.val_loss_metric = LossMetric()

        # Best
        self.best_loss = -1

    def train(self) -> None:
        """Trains the model"""
        self.logger.info("Beginning training")
        start_time = time.time()

        epoch = 0
        iteration = self.start_iteration
        while iteration < self.iterations:
            if iteration + self.eval_freq < self.iterations:
                num_iters = self.eval_freq
            else:
                num_iters = self.iterations - iteration

            start_epoch_time = time.time()
            if self.mixed_precision:
                self._train_loop_amp(epoch, num_iters)
            else:
                self._train_loop(epoch, num_iters)

            self._val_loop(epoch, self.eval_iters)

            epoch_time = time.time() - start_epoch_time
            self._end_loop(epoch, epoch_time, iteration)

            iteration += num_iters
            epoch += 1

        train_time_h = (time.time() - start_time) / 3600
        self.logger.info(f"Finished training! Total time: {train_time_h:.2f}h")
        if self.save_path is not None:
            self._save_model(os.path.join(self.save_path, "final_model.pt"),
                             self.iterations)

    def _train_loop(self, epoch: int, iterations: int) -> None:
        """
        Regular train loop

        Args:
            epoch: current epoch
            iterations: iterations to run model
        """
        # Progress bar
        pbar = tqdm.tqdm(total=iterations, leave=False)
        pbar.set_description(f"Epoch {epoch} | Train")

        # Set to train
        self.model.train()

        # Set to eval
        self.generator.eval()

        if self.train_projector:
            self.projector.train()
        else:
            self.projector.eval()

        for i in range(iterations):
            # To device
            z = self.generator.sample_latent(self.batch_size)
            z = z.to(self.device)
            z_orig = z

            # Original features
            with torch.no_grad():
                orig_feats = self.generator.get_features(z)
                orig_feats = self.projector(orig_feats)

            # Apply Directions
            self.optimizer.zero_grad()
            z = self.model(z)

            # Forward
            features = []
            for j in range(z.shape[0] // self.batch_size):
                # Prepare batch
                start, end = j * self.batch_size, (j + 1) * self.batch_size
                z_batch = z[start:end, ...]

                # Manipulate only asked layers
                if self.feed_layers is not None:
                    n_latent = self.generator.n_latent()

                    z_batch_layers = []
                    for layer_idx in range(n_latent):
                        if layer_idx in self.feed_layers:
                            z_batch_layers.append(z_batch)
                        else:
                            z_batch_layers.append(z_orig)
                    z_batch = z_batch_layers

                # Get features
                feats = self.generator.get_features(z_batch)
                feats = self.projector(feats)

                # Take feature divergence
                feats = feats - orig_feats
                feats = feats / torch.reshape(torch.norm(feats, dim=1),
                                              (-1, 1))

                features.append(feats)
            features = torch.cat(features, dim=0)

            # Loss
            acc, loss = self.loss_fn(features)
            loss.backward()

            if self.grad_clip_max_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.grad_clip_max_norm)

            self.optimizer.step()
            self.scheduler.step()

            # Update metrics
            self.train_acc_metric.update(acc.item(), z.shape[0])
            self.train_loss_metric.update(loss.item(), z.shape[0])

            # Update progress bar
            pbar.update()
            pbar.set_postfix_str(
                f"Acc: {acc.item():.3f} Loss: {loss.item():.3f}",
                refresh=False)

        pbar.close()

    def _train_loop_amp(self, epoch: int, iterations: int) -> None:
        """
        Train loop with Automatic Mixed Precision

        Args:
            epoch: current epoch
            iterations: iterations to run model
        """
        # Progress bar
        pbar = tqdm.tqdm(total=iterations, leave=False)
        pbar.set_description(f"Epoch {epoch} | Train")

        # Set to train
        self.model.train()

        # Loop
        for i in range(iterations):
            # To device
            z = self.generator.sample_latent(self.batch_size)
            z = z.to(self.device)

            # Forward + backward
            self.optimizer.zero_grad()

            # Use amp in forward pass
            with autocast():
                # Original features
                with torch.no_grad():
                    orig_feats = self.generator.get_features(z)
                    orig_feats = self.projector(orig_feats)

                # Apply Directions
                z = self.model(z)

                # Forward
                features = []
                for j in range(z.shape[0] // self.batch_size):
                    # Prepare batch
                    start, end = j * self.batch_size, (j + 1) * self.batch_size

                    # Get features
                    feats = self.generator.get_features(z[start:end, ...])
                    feats = self.projector(feats)

                    # Take feature divergence
                    feats = feats - orig_feats
                    feats = feats / torch.reshape(torch.norm(feats, dim=1),
                                                  (-1, 1))

                    features.append(feats)
                features = torch.cat(features, dim=0)

                # Loss
                acc, loss = self.loss_fn(features)

            # Backward pass with scaler
            self.scaler.scale(loss).backward()

            # Unscale before gradient clipping
            self.scaler.unscale_(self.optimizer)

            if self.grad_clip_max_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.grad_clip_max_norm)

            # Update optimizer and scaler
            self.scaler.step(self.optimizer)
            self.scaler.update()

            self.scheduler.step()

            # Update metrics
            self.train_acc_metric.update(acc.item(), z.shape[0])
            self.train_loss_metric.update(loss.item(), z.shape[0])

            # Update progress bar
            pbar.update()
            pbar.set_postfix_str(
                f"Acc: {acc.item():.3f} Loss: {loss.item():.3f}",
                refresh=False)

        pbar.close()

    def _val_loop(self, epoch: int, iterations: int) -> None:
        """
        Standard validation loop

        Args:
            epoch: current epoch
            iterations: iterations to run model
        """
        # Progress bar
        pbar = tqdm.tqdm(total=iterations, leave=False)
        pbar.set_description(f"Epoch {epoch} | Validation")

        # Set to eval
        self.model.eval()
        self.generator.eval()
        self.projector.eval()

        # Loop
        for i in range(iterations):
            with torch.no_grad():
                # To device
                z = self.generator.sample_latent(self.batch_size)
                z = z.to(self.device)

                # Original features
                orig_feats = self.generator.get_features(z)
                orig_feats = self.projector(orig_feats)

                # Apply Directions
                z = self.model(z)

                # Forward
                features = []
                for j in range(z.shape[0] // self.batch_size):
                    # Prepare batch
                    start, end = j * self.batch_size, (j + 1) * self.batch_size

                    # Get features
                    feats = self.generator.get_features(z[start:end, ...])
                    feats = self.projector(feats)

                    # Take feature divergence
                    feats = feats - orig_feats
                    feats = feats / torch.reshape(torch.norm(feats, dim=1),
                                                  (-1, 1))

                    features.append(feats)
                features = torch.cat(features, dim=0)

                # Loss
                acc, loss = self.loss_fn(features)
                self.val_acc_metric.update(acc.item(), z.shape[0])
                self.val_loss_metric.update(loss.item(), z.shape[0])

                # Update progress bar
                pbar.update()
                pbar.set_postfix_str(
                    f"Acc: {acc.item():.3f} Loss: {loss.item():.3f}",
                    refresh=False)

        pbar.close()

    def _end_loop(self, epoch: int, epoch_time: float, iteration: int):
        # Print epoch results
        self.logger.info(self._epoch_str(epoch, epoch_time))

        # Write to tensorboard
        if self.writer is not None:
            self._write_to_tb(epoch)

        # Save model
        if self.save_path is not None:
            self._save_model(os.path.join(self.save_path, "most_recent.pt"),
                             iteration)

        eval_loss = self.val_loss_metric.compute()
        if self.best_loss == -1 or eval_loss < self.best_loss:
            self.best_loss = eval_loss
            if self.save_path is not None:
                self._save_model(
                    os.path.join(self.save_path, "best_model.pt"), iteration)

        # Clear metrics
        self.train_loss_metric.reset()
        self.train_acc_metric.reset()
        self.val_loss_metric.reset()
        self.val_acc_metric.reset()

    def _epoch_str(self, epoch: int, epoch_time: float):
        s = f"Epoch {epoch} "
        s += f"| Train acc: {self.train_acc_metric.compute():.3f} "
        s += f"| Train loss: {self.train_loss_metric.compute():.3f} "
        s += f"| Val acc: {self.val_acc_metric.compute():.3f} "
        s += f"| Val loss: {self.val_loss_metric.compute():.3f} "
        s += f"| Epoch time: {epoch_time:.1f}s"

        return s

    def _write_to_tb(self, iteration):
        self.writer.add_scalar("Loss/train", self.train_loss_metric.compute(),
                               iteration)
        self.writer.add_scalar("Acc/train", self.train_acc_metric.compute(),
                               iteration)
        self.writer.add_scalar("Loss/val", self.val_loss_metric.compute(),
                               iteration)
        self.writer.add_scalar("Acc/val", self.val_acc_metric.compute(),
                               iteration)

    def _save_model(self, path, iteration):
        obj = {
            "iteration": iteration + 1,
            "optimizer": self.optimizer.state_dict(),
            "model": self.model.state_dict(),
            "projector": self.projector.state_dict(),
            "scheduler": (self.scheduler.state_dict()
                          if self.scheduler is not None else None),
            "scaler": (self.scaler.state_dict()
                       if self.mixed_precision else None),
        }
        # Note: path is already a full path assembled by the caller
        torch.save(obj, path)

    def _load_from_checkpoint(self, checkpoint_path: str) -> None:
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        self.projector.load_state_dict(checkpoint["projector"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])

        self.start_iteration = checkpoint["iteration"]

        if self.scheduler:
            self.scheduler.load_state_dict(checkpoint["scheduler"])

        if self.mixed_precision and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        if self.start_iteration > self.iterations:
            raise ValueError(
                "Starting iteration is larger than total iterations")

        self.logger.info(
            f"Checkpoint loaded, resuming from iteration {self.start_iteration}"
        )
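
# A minimal, self-contained sketch of the AMP pattern used by the trainer
# above: scale the loss, backward, unscale before clipping, then step and
# update, and checkpoint the scaler alongside the model. The model, data and
# hyperparameter names below are illustrative placeholders, not part of the
# trainer's API.
import torch
from torch.cuda.amp import GradScaler, autocast


def amp_train_step(model, loss_fn, optimizer, scaler, x, y,
                   grad_clip_max_norm=None):
    optimizer.zero_grad()
    # Run the forward pass in mixed precision
    with autocast():
        loss = loss_fn(model(x), y)
    # Scale the loss so small gradients do not underflow in FP16
    scaler.scale(loss).backward()
    if grad_clip_max_norm is not None:
        # Gradients must be unscaled before clipping so the norm is measured
        # in true (unscaled) units
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_max_norm)
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()


def save_amp_checkpoint(path, model, optimizer, scaler):
    # The GradScaler is stateful, so it is saved and restored with the rest of
    # the checkpoint (as _save_model / _load_from_checkpoint do above)
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
    }, path)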
Example #14
0
class ClassificationTask(ClassyTask):
    """Basic classification training task.

    This task encapsulates all of the components and steps needed to
    train a classifier using a :class:`classy_vision.trainer.ClassyTrainer`.

    Assumes a train / test phase per epoch and that the datasets
    have the same API as the map-style Dataset class in
    `torch.utils.data.dataset <https://pytorch.org/docs/stable/data.html
    #torch.utils.data.Dataset>`_ (in particular, this task makes use of
    ``len``).  If you are using an `IterableDataset <https://pytorch.org/docs/
    stable/data.html#torch.utils.data.IterableDataset>`_ then a custom task
    may be appropriate.


    :var loss: Loss (see :class:`classy_vision.losses.ClassyLoss`) function used
        for computing the loss in each forward pass
    :var datasets: Mapping from a ``phase_type`` in ["train", "test"]
        to dataset used for training (or testing)
    :var meters: List of meters (see :class:`classy_vision.meters.ClassyMeter`)
        to calculate during training
    :var num_epochs: Number of epochs (passes over dataset) to train
    :var test_only: Used to only run the test phase
    :var base_model: Model to be trained, unwrapped in DDP or DP wrappers
    :var optimizer: Optimizer used in train step
    :var optimizer_schedulers: Dictionary. Key is the name of the optimizer
        option (e.g. lr), value is a ClassyParamScheduler
    :var checkpoint: Serializable dict which represents state in training
    :var phases: List of phase specific information, e.g. if phase is
        train / test.
    :var hooks: List of hooks to apply during training
    :var train: Phase type, if true it means we are training,
        false means testing
    :var distributed_model: Base model, but wrapped in DDP (DistributedDataParallel)
    :var phase_idx: Current phase id, first phase is 0, if task has not started
        training then returns -1
    :var train_phase_idx: Only counts train phases
    :var num_updates: Number of total parameter updates applied to model
        by the optimizer
    :var data_iterator: Iterator which can be used to obtain batches
    :var losses: Loss curve
    :var perf_log: list of training speed measurements, to be logged
    :var clip_grad_norm: maximum gradient norm (default None)
    :var simulated_global_batchsize: batch size simulated via gradient accumulation
    :var optimizer_period: apply optimizer after this many steps; derived from
        simulated_global_batchsize, default 1.
    """
    def __init__(self):
        """Constructs a ClassificationTask"""
        super().__init__()

        self.base_loss = None
        self.datasets = {}
        self.meters = []
        self.num_epochs = 1
        self.test_phase_period = 1
        self.train_phases_per_epoch = 0
        self.test_only = False
        self.base_model = None
        self.optimizer = None
        self.optimizer_schedulers = {}
        self.checkpoint_dict = None
        self.checkpoint_path = None
        self.checkpoint_load_strict = True
        self.phases = []
        self.hooks = []
        self.train = True
        self.distributed_model = None
        self.distributed_loss = None
        self.phase_idx = -1
        self.train_phase_idx = -1
        self.num_updates = 0
        self.dataloader = None
        self.data_iterator = None
        self.losses = []
        self.broadcast_buffers_mode: BroadcastBuffersMode = (
            BroadcastBuffersMode.BEFORE_EVAL)
        self.amp_args = None
        self.amp_type = None
        self.amp_grad_scaler = None
        self.mixup_transform = None
        self.perf_log = []
        self.last_batch = None
        self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
        self.find_unused_parameters = False
        self.use_gpu = torch.cuda.is_available()
        self.dataloader_mp_context = "spawn"
        self.bn_weight_decay = False
        self._train_only = True
        self.clip_grad_norm = None
        self.simulated_global_batchsize = None
        self.optimizer_period = 1
        self.ddp_bucket_cap_mb = 25
        self.use_sharded_ddp = False
        self.fp16_grad_compress = False

    def set_use_sharded_ddp(self, use_sharded_ddp: bool):
        self.use_sharded_ddp = use_sharded_ddp
        if self.use_sharded_ddp:
            logging.info("Using Sharded DDP")
        return self

    def set_use_gpu(self, use_gpu: bool):
        self.use_gpu = use_gpu

        assert (not self.use_gpu
                or torch.cuda.is_available()), "CUDA required to train on GPUs"

        return self

    def set_clip_grad_norm(self, clip_grad_norm: Optional[float]):
        """Sets maximum gradient norm.

        None means gradient clipping is disabled. Defaults to None."""
        self.clip_grad_norm = clip_grad_norm
        if clip_grad_norm is None:
            logging.info("Disabled gradient norm clipping.")
        else:
            logging.info(
                f"Enabled gradient norm clipping with threshold: {clip_grad_norm}"
            )
        return self

    def set_simulated_global_batchsize(
            self, simulated_global_batchsize: Optional[int]):
        """Sets a simulated batch size by gradient accumulation.

        Gradient accumulation adds up gradients from multiple minibatches and
        steps the optimizer every N train_steps, where N is optimizer_period.
        When enabled, the very last train_steps might end up not updating the
        model, depending on the number of total steps. None means gradient
        accumulation is disabled. Defaults to None."""
        self.simulated_global_batchsize = simulated_global_batchsize
        return self
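
    # Illustrative example (hypothetical numbers): with a global batch size of
    # 256 and set_simulated_global_batchsize(1024), prepare() derives
    # optimizer_period = 1024 // 256 = 4, so gradients are accumulated over 4
    # train steps before each optimizer step.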

    def set_checkpoint(self, checkpoint_path: str):
        """Sets checkpoint on task.

        Args:
            checkpoint_path: The path to load the checkpoint from. Can be a file or a
            directory. See :func:`load_checkpoint` for more information.
        """
        self.checkpoint_path = checkpoint_path
        return self

    def set_checkpoint_load_strict(self, checkpoint_load_strict: bool):
        """Sets checkpoint on task.

        Args:
            checkpoint_load_strict: Whether to use load_strict when copying model weights
        """
        self.checkpoint_load_strict = checkpoint_load_strict
        return self

    def _set_checkpoint_dict(self, checkpoint_dict: Dict[str, Any]):
        """Sets the checkpoint dict in the task. Only used for testing.

        Args:
            checkpoint_dict: A serializable dict representing current task state
        """
        self.checkpoint_dict = checkpoint_dict
        return self

    def set_num_epochs(self, num_epochs: Union[int, float]):
        """Set number of epochs to be run.

        Args:
           num_epochs: Number of epochs to run task
        """
        self.num_epochs = num_epochs
        return self

    def set_test_phase_period(self, test_phase_period: int):
        """Set the period of test phase.

        Args:
            test_phase_period: The period of test phase
        """
        self.test_phase_period = test_phase_period
        return self

    def set_dataset(self, dataset: ClassyDataset, phase_type: str):
        """Set dataset for phase type on task

        Args:
            dataset: ClassyDataset for returning samples.
            phase_type: str must be one of "train" or "test"
        """
        assert phase_type in [
            "train",
            "test",
        ], "phase_type must be in ['train', 'test']"
        self.datasets[phase_type] = dataset
        if phase_type == "train":
            self.train_phases_per_epoch = getattr(dataset, "phases_per_epoch",
                                                  1)
        else:
            self._train_only = False
        return self

    def set_dataloader_mp_context(self, dataloader_mp_context: Optional[str]):
        """Set the multiprocessing context used by the dataloader.

        The context can be either 'spawn', 'fork', 'forkserver' or None (uses the
        default context). See
        https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_context
        for more details."""

        self.dataloader_mp_context = dataloader_mp_context
        return self

    def set_optimizer(self, optimizer: ClassyOptimizer):
        """Set optimizer for task

        Args:
            optimizer: optimizer for task
        """
        self.optimizer = optimizer
        return self

    def set_loss(self, loss: ClassyLoss):
        """Set loss function for task

        Args:
            loss: loss for task
        """
        self.base_loss = loss
        return self

    def set_meters(self, meters: List["ClassyMeter"]):
        """Set meters for task

        Args:
            meters: list of meters to compute during training
        """
        self.meters = meters
        return self

    def set_distributed_options(
        self,
        broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.
        BEFORE_EVAL,
        batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED,
        batch_norm_sync_group_size: int = 0,
        find_unused_parameters: bool = False,
        bucket_cap_mb: int = 25,
        fp16_grad_compress: bool = False,
    ):
        """Set distributed options.

        Args:
            broadcast_buffers_mode: Broadcast buffers mode. See
                :class:`BroadcastBuffersMode` for options.
            batch_norm_sync_mode: Batch normalization synchronization mode. See
                :class:`BatchNormSyncMode` for options.
            batch_norm_sync_group_size: Group size to use for synchronized batch norm.
                0 means that the stats are synchronized across all replicas. For
                efficient synchronization, set it to the number of GPUs in a node (
                usually 8).
            find_unused_parameters: See
                :class:`torch.nn.parallel.DistributedDataParallel` for information.
            bucket_cap_mb: See
                :class:`torch.nn.parallel.DistributedDataParallel` for information.
        Raises:
            RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex
                is not installed.
        """
        self.broadcast_buffers_mode = broadcast_buffers_mode

        if batch_norm_sync_group_size > 0:
            if not batch_norm_sync_mode == BatchNormSyncMode.APEX:
                # this should ideally work with PyTorch Sync BN as well, but it
                # fails while initializing DDP for some reason.
                raise ValueError(
                    "batch_norm_sync_group_size can be > 0 only when "
                    "Apex Synchronized Batch Normalization is being used.")
        self.batch_norm_sync_group_size = batch_norm_sync_group_size

        if batch_norm_sync_mode == BatchNormSyncMode.DISABLED:
            logging.info("Synchronized Batch Normalization is disabled")
        else:
            if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available:
                raise RuntimeError("apex is not installed")
            msg = f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}"
            if self.batch_norm_sync_group_size > 0:
                msg += f" and group size {batch_norm_sync_group_size}"
            logging.info(msg)
        self.batch_norm_sync_mode = batch_norm_sync_mode

        if find_unused_parameters:
            logging.info("Enabling find_unused_parameters in DDP")

        self.find_unused_parameters = find_unused_parameters
        self.ddp_bucket_cap_mb = bucket_cap_mb

        if fp16_grad_compress:
            if get_torch_version() < [1, 8]:
                raise RuntimeError(
                    "FP16 grad compression is only supported since PyTorch 1.8"
                )
            logging.info("Enabling FP16 grad compression")
        self.fp16_grad_compress = fp16_grad_compress

        return self
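
    # Usage sketch (hypothetical values; from_config below maps the
    # "distributed" config section onto these arguments):
    #   task.set_distributed_options(
    #       broadcast_buffers_mode=BroadcastBuffersMode.BEFORE_EVAL,
    #       batch_norm_sync_mode=BatchNormSyncMode.PYTORCH,
    #       batch_norm_sync_group_size=8,
    #       fp16_grad_compress=True,
    #   )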

    def set_hooks(self, hooks: List["ClassyHook"]):
        """Set hooks for task

        Args:
            hooks: List of hooks to apply during training
        """
        from classy_vision.hooks import ClassyHook

        assert isinstance(hooks, list)
        assert all(isinstance(hook, ClassyHook) for hook in hooks)
        assert len({
            hook.name()
            for hook in hooks
        }) == len(hooks), "Cannot have repeated hooks of the same class"
        # TODO (zyan3): we move checkpoint hook to the end of the list because some hooks
        # may change the state of the model, and we want to save changed state in the checkpoint.
        # This is temporary fix.
        non_checkpoint_hooks = [
            hook for hook in hooks if not isinstance(hook, CheckpointHook)
        ]
        checkpoint_hooks = [
            hook for hook in hooks if isinstance(hook, CheckpointHook)
        ]
        hooks = non_checkpoint_hooks + checkpoint_hooks
        self.hooks = hooks
        return self

    def set_model(self, model: ClassyModel):
        """Set model for task

        Args:
            model: Model to be trained
        """
        self.base_model = model
        return self

    def set_test_only(self, test_only: bool):
        """Set test only flag

        Args:
            test_only: If true, only test phases will be run
        """
        self.test_only = test_only
        return self

    def set_bn_weight_decay(self, bn_weight_decay: bool):
        assert type(bn_weight_decay) == bool

        self.bn_weight_decay = bn_weight_decay
        return self

    def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
        """Disable / enable apex.amp and set the automatic mixed precision parameters.

        apex.amp can be utilized for mixed / half precision training.

        Args:
            amp_args: Dictionary containing arguments to be passed to
            amp.initialize. Set to None to disable amp.  To enable mixed
            precision training, pass amp_args={"opt_level": "O1"} here.
            See https://nvidia.github.io/apex/amp.html for more info.

        Raises:
            RuntimeError: If opt_level is not None and apex is not installed.

        Warning: apex needs to be installed to utilize this feature.
        """
        self.amp_args = amp_args

        if amp_args is None:
            logging.info("AMP disabled")
        else:
            # Check that the requested AMP type is known
            try:
                self.amp_type = AmpType[self.amp_args["amp_type"].upper()]
            except KeyError:
                logging.info("AMP type not specified, defaulting to Apex")
                self.amp_type = AmpType.APEX

            # Check for CUDA availability, required for both Apex and Pytorch AMP
            if not torch.cuda.is_available():
                raise RuntimeError(
                    "AMP is required but CUDA is not supported, cannot enable AMP"
                )

            # Check for Apex availability
            if self.amp_type == AmpType.APEX and not apex_available:
                raise RuntimeError(
                    "Apex AMP is required but Apex is not installed, cannot enable AMP"
                )

            if self.use_sharded_ddp:
                if self.amp_type == AmpType.APEX:
                    raise RuntimeError(
                        "ShardedDDP has been requested, which is incompatible with Apex AMP"
                    )

                if not fairscale_available:
                    raise RuntimeError(
                        "ShardedDDP has been requested, but fairscale is not installed in the current environment"
                    )

            # Set Torch AMP grad scaler, used to prevent gradient underflow
            if self.amp_type == AmpType.PYTORCH:

                if self.use_sharded_ddp:
                    logging.info(
                        "Using ShardedGradScaler to manage Pytorch AMP")
                    self.amp_grad_scaler = ShardedGradScaler()
                else:
                    self.amp_grad_scaler = TorchGradScaler()

            logging.info(f"AMP enabled with args {amp_args}")
        return self
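
    # Usage sketch (hypothetical configs): Apex AMP at optimization level O1 is
    # enabled with
    #   task.set_amp_args({"opt_level": "O1"})
    # while PyTorch-native AMP is selected with
    #   task.set_amp_args({"amp_type": "pytorch"})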

    def set_mixup_transform(self, mixup_transform: Optional["MixupTransform"]):
        """Disable / enable mixup transform for data augmentation

        Args:
            mixup_transform: a callable object which performs mixup data augmentation
        """
        self.mixup_transform = mixup_transform
        if mixup_transform is None:
            logging.info("mixup disabled")
        else:
            logging.info("mixup enabled")
        return self

    def set_optimizer_schedulers(self, schedulers):
        self.optimizer_schedulers = schedulers
        return self

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
        """Instantiates a ClassificationTask from a configuration.

        Args:
            config: A configuration for a ClassificationTask.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            A ClassificationTask instance.
        """
        test_only = config.get("test_only", False)
        if not test_only:
            # TODO Make distinction between epochs and phases in optimizer clear
            train_phases_per_epoch = config["dataset"]["train"].get(
                "phases_per_epoch", 1)

            optimizer_config = config["optimizer"]
            optimizer_config["num_epochs"] = (config["num_epochs"] *
                                              train_phases_per_epoch)
            optimizer = build_optimizer(optimizer_config)
            param_schedulers = build_optimizer_schedulers(optimizer_config)

        datasets = {}
        phase_types = ["train", "test"]
        for phase_type in phase_types:
            if phase_type in config["dataset"]:
                datasets[phase_type] = build_dataset(
                    config["dataset"][phase_type])
        loss = build_loss(config["loss"])
        amp_args = config.get("amp_args")
        meters = build_meters(config.get("meters", {}))
        model = build_model(config["model"])

        mixup_transform = None
        if config.get("mixup") is not None:
            assert "alpha" in config[
                "mixup"], "key alpha is missing in mixup dict"
            mixup_transform = MixupTransform(
                config["mixup"]["alpha"],
                num_classes=config["mixup"].get("num_classes"),
                cutmix_alpha=config["mixup"].get("cutmix_alpha", 0),
                cutmix_minmax=config["mixup"].get("cutmix_minmax"),
                mix_prob=config["mixup"].get("mix_prob", 1.0),
                switch_prob=config["mixup"].get("switch_prob", 0.5),
                mode=config["mixup"].get("mode", "batch"),
                label_smoothing=config["mixup"].get("label_smoothing", 0.0),
            )

        # hooks config is optional
        hooks_config = config.get("hooks")
        hooks = []
        if hooks_config is not None:
            hooks = build_hooks(hooks_config)

        distributed_config = config.get("distributed", {})
        distributed_options = {
            "broadcast_buffers_mode":
            BroadcastBuffersMode[distributed_config.get(
                "broadcast_buffers", "before_eval").upper()],
            "batch_norm_sync_mode":
            BatchNormSyncMode[distributed_config.get("batch_norm_sync_mode",
                                                     "disabled").upper()],
            "batch_norm_sync_group_size":
            distributed_config.get("batch_norm_sync_group_size", 0),
            "find_unused_parameters":
            distributed_config.get("find_unused_parameters", False),
            "bucket_cap_mb":
            distributed_config.get("bucket_cap_mb", 25),
            "fp16_grad_compress":
            distributed_config.get("fp16_grad_compress", False),
        }

        task = (
            cls()
            .set_num_epochs(config["num_epochs"])
            .set_test_phase_period(config.get("test_phase_period", 1))
            .set_loss(loss)
            .set_test_only(test_only)
            .set_model(model)
            .set_meters(meters)
            .set_amp_args(amp_args)
            .set_mixup_transform(mixup_transform)
            .set_distributed_options(**distributed_options)
            .set_hooks(hooks)
            .set_bn_weight_decay(config.get("bn_weight_decay", False))
            .set_clip_grad_norm(config.get("clip_grad_norm"))
            .set_simulated_global_batchsize(
                config.get("simulated_global_batchsize"))
            .set_use_sharded_ddp(config.get("use_sharded_ddp", False))
        )

        if not test_only:
            task.set_optimizer(optimizer)
            task.set_optimizer_schedulers(param_schedulers)

        use_gpu = config.get("use_gpu")
        if use_gpu is not None:
            task.set_use_gpu(use_gpu)

        for phase_type in datasets:
            task.set_dataset(datasets[phase_type], phase_type)

        # NOTE: this is a private member and only meant to be used for
        # logging/debugging purposes. See __repr__ implementation
        task._config = config

        return task

    @property
    def num_batches_per_phase(self):
        """Returns number of batches in current phase iterator"""
        return len(self.data_iterator)

    @property
    def model(self):
        """Returns model used in training (can be wrapped with DDP)"""
        return (self.distributed_model
                if is_distributed_training_run() else self.base_model)

    @property
    def loss(self):
        """Returns loss used in training (can be wrapped with DDP)"""
        return self.distributed_loss if self.distributed_loss else self.base_loss

    @property
    def phase_type(self):
        """Returns current phase type. String with value "train" or "test" """
        return "train" if self.train else "test"

    @property
    def eval_phase_idx(self):
        """Returns current evaluation phase"""
        return self.phase_idx - self.train_phase_idx - 1

    def get_total_training_phases(self):
        """
        Returns the total number of "train" phases in the task
        """
        num_training_phases = 0
        for phase in self.phases:
            if phase["train"] is True:
                num_training_phases += 1
        return num_training_phases

    def get_total_test_phases(self):
        """
        Returns the total number of "test" phases in the task
        """
        num_test_phases = 0
        for phase in self.phases:
            if phase["train"] is False:
                num_test_phases += 1
        return num_test_phases

    def _build_phases(self):
        """Returns list of phases from config.

        These phases will look like:
        {
          train: is this a train or test phase?
          optimizer: optimizer settings
        }

        - If this is a test only run, then only test phases will be
        generated
        - If this is a training run with both train and test datasets, then x phases =
          x train phases + x test phases, interleaved. If test_phase_period > 1, test
          phases are only added after test_phase_period train phases. The last phase is
          always a test phase.
        - If this is a training run with only a train dataset, then x phases = x train
          phases.
        """
        if not self.test_only:
            phases = [{
                "train": True
            } for _ in range(
                math.ceil(self.train_phases_per_epoch * self.num_epochs))]

            if self._train_only:
                return phases

            final_phases = []
            for i, phase in enumerate(phases):
                final_phases.append(phase)
                if (i + 1) % self.test_phase_period == 0:
                    final_phases.append({"train": False})
            if final_phases[-1]["train"]:
                final_phases.append({"train": False})
            return final_phases

        return [{"train": False} for _ in range(self.num_epochs)]

    def build_dataloader_from_dataset(self, dataset, **kwargs):
        """Builds a dataloader from the provided dataset

        Args:
            dataset: A ClassyDataset
            kwargs: Additional kwargs to pass during dataloader construction for
                derived classes
        """
        return dataset.iterator(
            phase_type=self.phase_type,
            current_phase_id=self.train_phase_idx if self.train else 0,
            pin_memory=self.use_gpu and torch.cuda.device_count() > 1,
            multiprocessing_context=mp.get_context(self.dataloader_mp_context),
            **kwargs,
        )

    def build_dataloaders_for_current_phase(self):
        """Builds dataloader(s) for the current phase.

        Deriving classes can override this method to support custom behavior, like
        supporting multiple dataloaders in parallel.
        """
        self.dataloader = self.build_dataloader_from_dataset(
            self.datasets[self.phase_type])

    def prepare_optimizer(self, optimizer, model, loss=None):
        bn_params, other_params = split_batchnorm_params(model)
        if loss is not None:
            bn_params_loss, params_loss = split_batchnorm_params(loss)
            bn_params = bn_params + bn_params_loss
            other_params = other_params + params_loss

        bn_schedulers = self.optimizer_schedulers.copy()
        if not self.bn_weight_decay:
            bn_schedulers["weight_decay"] = 0

        param_groups = [{"params": other_params, **self.optimizer_schedulers}]
        if len(bn_params) > 0:
            param_groups.append({"params": bn_params, **bn_schedulers})
        self.optimizer.set_param_groups(param_groups)
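
    # Resulting param groups: non-batch-norm parameters use the configured
    # optimizer schedulers as-is, while batch-norm parameters go into a second
    # group whose weight_decay is overridden to 0 unless bn_weight_decay is
    # enabled.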

    def prepare(self):
        """Prepares task for training, populates all derived attributes"""

        self.phases = self._build_phases()
        self.train = False if self.test_only else self.train

        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.base_model)
        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
            sync_bn_process_group = apex.parallel.create_syncbn_process_group(
                self.batch_norm_sync_group_size)
            self.base_model = apex.parallel.convert_syncbn_model(
                self.base_model, process_group=sync_bn_process_group)

        # move the model and loss to the right device
        if self.use_gpu:
            self.base_model, self.base_loss = copy_model_to_gpu(
                self.base_model, self.base_loss)
        else:
            self.base_loss.cpu()
            self.base_model.cpu()

        if self.optimizer is not None:
            self.prepare_optimizer(optimizer=self.optimizer,
                                   model=self.base_model,
                                   loss=self.base_loss)

        if self.amp_args is not None:
            if self.amp_type == AmpType.APEX:
                # Initialize apex.amp. This updates the model and the PyTorch optimizer (
                # if training, which is wrapped by the ClassyOptimizer in self.optimizer).
                # Please note this must happen before loading the checkpoint,
                # because there's amp state to be restored.
                if self.optimizer is None:
                    self.base_model = apex.amp.initialize(self.base_model,
                                                          optimizers=None,
                                                          **self.amp_args)
                else:
                    self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                        self.base_model, self.optimizer.optimizer,
                        **self.amp_args)

        if self.simulated_global_batchsize is not None:
            if self.simulated_global_batchsize % self.get_global_batchsize(
            ) != 0:
                raise ValueError(
                    f"Global batch size ({self.get_global_batchsize()}) must divide "
                    f"simulated_global_batchsize ({self.simulated_global_batchsize})"
                )
        else:
            self.simulated_global_batchsize = self.get_global_batchsize()

        self.optimizer_period = (self.simulated_global_batchsize //
                                 self.get_global_batchsize())
        if self.optimizer_period > 1:
            logging.info(
                f"Using gradient accumulation with a period of {self.optimizer_period}"
            )

        if self.checkpoint_path:
            self.checkpoint_dict = load_and_broadcast_checkpoint(
                self.checkpoint_path)

        classy_state_dict = (None if self.checkpoint_dict is None else
                             self.checkpoint_dict["classy_state_dict"])

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (state_load_success
                    ), "Update classy state from checkpoint was unsuccessful."

        self.init_distributed_data_parallel_model()

    def init_distributed_data_parallel_model(self):
        """
        Initialize
        `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/
        docs/stable/nn.html#distributeddataparallel>`_.

        Needed for distributed training. This is where a model should be wrapped by DDP.
        """
        if not is_distributed_training_run():
            return
        assert (self.distributed_model is
                None), "init_ddp_non_elastic must only be called once"

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS)

        if self.use_sharded_ddp:
            if not isinstance(self.optimizer, ZeRO):
                raise ValueError(
                    "ShardedDataParallel engine should only be used in conjunction with ZeRO optimizer"
                )
            from fairscale.nn.data_parallel import ShardedDataParallel

            # Replace the original DDP wrap by the shard-aware ShardedDDP
            self.distributed_model = ShardedDataParallel(
                module=self.base_model,
                sharded_optimizer=self.optimizer.optimizer,
                broadcast_buffers=broadcast_buffers,
            )
        else:
            self.distributed_model = init_distributed_data_parallel_model(
                self.base_model,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )
            if self.fp16_grad_compress:

                from torch.distributed.algorithms import ddp_comm_hooks

                # FP16 hook is stateless and only takes a process group as the state.
                # We use the default process group so we set the state to None.
                process_group = None
                self.distributed_model.register_comm_hook(
                    process_group,
                    ddp_comm_hooks.default_hooks.fp16_compress_hook)
        if (isinstance(self.base_loss, ClassyLoss)
                and self.base_loss.has_learned_parameters()):
            logging.info("Initializing distributed loss")
            self.distributed_loss = init_distributed_data_parallel_model(
                self.base_loss,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )

    @property
    def where(self):
        """Returns the proportion of training that has completed. If in test
        only mode, returns proportion of testing completed

        Returned value is a float in the range [0, 1)
        """
        current_step = self.num_updates / self.get_global_batchsize()
        num_phases = (self.get_total_test_phases()
                      if self.test_only else self.get_total_training_phases())

        if self.num_batches_per_phase <= 0:
            raise RuntimeError("No batches to read. Is the dataset empty?")

        num_steps = num_phases * self.num_batches_per_phase
        where = current_step / num_steps

        return where
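
    # Example (hypothetical numbers): with a global batch size of 32,
    # num_updates=3200 gives current_step=100; with 10 train phases of 50
    # batches each, num_steps=500, so where = 100 / 500 = 0.2.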

    def get_classy_state(self, deep_copy: bool = False):
        """Returns serialiable state of task

        Args:
            deep_copy: If true, does a deep copy of state before returning.
        """
        optimizer_state = {}
        if self.optimizer is not None:
            optimizer_state = self.optimizer.get_classy_state()

        classy_state_dict = {
            "train": self.train,
            "base_model": self.base_model.get_classy_state(),
            "meters": [meter.get_classy_state() for meter in self.meters],
            "optimizer": optimizer_state,
            "phase_idx": self.phase_idx,
            "train_phase_idx": self.train_phase_idx,
            "num_updates": self.num_updates,
            "losses": self.losses,
            "hooks":
            {hook.name(): hook.get_classy_state()
             for hook in self.hooks},
            "loss": {},
        }
        if "train" in self.datasets and self._is_checkpointable_dataset(
                self.datasets["train"]):
            classy_state_dict["train_dataset_iterator"] = self.datasets[
                "train"].get_classy_state()

        if isinstance(self.base_loss, ClassyLoss):
            classy_state_dict["loss"] = self.base_loss.get_classy_state()
        if self.amp_args is not None:
            if self.amp_type == AmpType.APEX:
                classy_state_dict["amp"] = apex.amp.state_dict()

            elif self.amp_grad_scaler is not None:
                classy_state_dict["amp"] = self.amp_grad_scaler.state_dict()

        if deep_copy:
            classy_state_dict = copy.deepcopy(classy_state_dict)
        return classy_state_dict

    def set_classy_state(self, state):
        """Set task state

        Args:
            state: Dict containing state of a task
        """
        self.train = False if self.test_only else state["train"]
        self.base_model.set_classy_state(state["base_model"])

        if self.test_only:
            # if we're only testing, just need the state of the model to be updated
            return

        self.phase_idx = state["phase_idx"]
        self.num_updates = state["num_updates"]
        self.train_phase_idx = state["train_phase_idx"]
        self.losses = state["losses"]
        for meter, meter_state in zip(self.meters, state["meters"]):
            meter.set_classy_state(meter_state)

        if self.optimizer is not None:
            self.optimizer.set_classy_state(state["optimizer"])
        if state.get("loss") and isinstance(self.base_loss, ClassyLoss):
            self.base_loss.set_classy_state(state["loss"])

        if "amp" in state:
            if self.amp_type == AmpType.APEX:
                apex.amp.load_state_dict(state["amp"])
            else:
                self.amp_grad_scaler.load_state_dict(state["amp"])

        for hook in self.hooks:
            # we still want to be able to run when new hooks are added or old
            # hooks are removed
            if hook.name() in state["hooks"]:
                hook.set_classy_state(state["hooks"][hook.name()])
            else:
                logging.warning(f"No state found for hook: {hook.name()}")

        if "train" in self.datasets and self._is_checkpointable_dataset(
                self.datasets["train"]):
            self.datasets["train"].set_classy_state(
                state.get("train_dataset_iterator"))

    @staticmethod
    def _is_checkpointable_dataset(dataset):
        return hasattr(dataset, "get_classy_state") and hasattr(
            dataset, "set_classy_state")

    def eval_step(self):
        self.last_batch = None

        # Process next sample
        with Timer() as timer:
            sample = next(self.data_iterator)

        assert isinstance(
            sample, dict) and "input" in sample and "target" in sample, (
                f"Returned sample [{sample}] is not a map with 'input' and" +
                "'target' keys")

        target = sample["target"]
        if self.use_gpu:
            sample = recursive_copy_to_gpu(sample, non_blocking=True)

        # Optional Pytorch AMP context
        torch_amp_context = (torch.cuda.amp.autocast() if self.amp_type
                             == AmpType.PYTORCH else contextlib.suppress())

        with torch.no_grad(), torch_amp_context:
            output = self.model(sample["input"])

            local_loss = self.compute_loss(output, sample)

            loss = local_loss.detach().clone()

            self.losses.append(loss.data.cpu().item())

            self.update_meters(output, sample)

        # Move some data to the task so hooks get a chance to access it
        self.last_batch = LastBatchInfo(
            loss=loss,
            output=output,
            target=target,
            sample=sample,
            step_data={"sample_fetch_time": timer.elapsed_time},
        )

    def check_inf_nan(self, loss):
        if loss == float("inf") or loss == float("-inf") or loss != loss:
            raise FloatingPointError(f"Loss is infinity or NaN: {loss}")

    def _should_do_step(self):
        """Tells if we will be performing an optimizer step.

        Returns True always if there is no gradient accumulation. With gradient
        accumulation returns True only when the gradients will be synchronized and we
        will be performing an optimizer step.
        """
        update_idx = self.num_updates // self.get_global_batchsize()
        return (update_idx %
                self.optimizer_period) == self.optimizer_period - 1
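
    # Example (hypothetical): with optimizer_period=4 the optimizer steps on
    # update_idx 3, 7, 11, ..., i.e. on the last train_step of every window of
    # 4 accumulated batches; with optimizer_period=1 every step returns True.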

    def train_step(self):
        """Train step to be executed in train loop."""

        self.last_batch = None

        # Process next sample
        with Timer() as timer:
            sample = next(self.data_iterator)

        assert isinstance(
            sample, dict) and "input" in sample and "target" in sample, (
                f"Returned sample [{sample}] is not a map with 'input' and" +
                "'target' keys")

        # Copy sample to GPU
        target = sample["target"]
        if self.use_gpu:
            sample = recursive_copy_to_gpu(sample, non_blocking=True)

        if self.mixup_transform is not None:
            sample = self.mixup_transform(sample)

        # Optional Pytorch AMP context
        torch_amp_context = (torch.cuda.amp.autocast() if self.amp_type
                             == AmpType.PYTORCH else contextlib.suppress())

        # only sync with DDP when we need to perform an optimizer step
        # an optimizer step can be skipped if gradient accumulation is enabled
        do_step = self._should_do_step()
        ctx_mgr_model = (self.distributed_model.no_sync()
                         if self.distributed_model is not None and not do_step
                         else contextlib.suppress())
        ctx_mgr_loss = (self.distributed_loss.no_sync()
                        if self.distributed_loss is not None and not do_step
                        else contextlib.suppress())

        with ctx_mgr_model, ctx_mgr_loss:
            # Forward pass
            with torch.enable_grad(), torch_amp_context:
                output = self.compute_model(sample)

                local_loss = self.compute_loss(output, sample)
                loss = local_loss.detach().clone()
                self.losses.append(loss.data.cpu().item())

                self.update_meters(output, sample)

            # Backwards pass + optimizer step
            self.run_optimizer(local_loss)

        self.num_updates += self.get_global_batchsize()

        # Move some data to the task so hooks get a chance to access it
        self.last_batch = LastBatchInfo(
            loss=loss,
            output=output,
            target=target,
            sample=sample,
            step_data={"sample_fetch_time": timer.elapsed_time},
        )

    def compute_model(self, sample):
        return self.model(sample["input"])

    def compute_loss(self, model_output, sample):
        return self.loss(model_output, sample["target"])

    def run_optimizer(self, loss):
        """Runs backwards pass and update the optimizer"""

        self.check_inf_nan(loss)

        # Gradient accumulation logic. We always set optimizer_period, even
        # if gradient accumulation is disabled. Assumes all batches have the
        # same size
        update_idx = self.num_updates // self.get_global_batchsize()
        do_zero_grad = (update_idx % self.optimizer_period) == 0
        do_step = self._should_do_step()

        if do_zero_grad:
            self.optimizer.zero_grad()

        if self.amp_type == AmpType.APEX:
            with apex.amp.scale_loss(loss,
                                     self.optimizer.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.amp_type == AmpType.PYTORCH:
            self.amp_grad_scaler.scale(loss).backward()
        else:
            loss.backward()

        if do_step:
            # Handle gradient accumulation related gradient rescaling
            if self.optimizer_period != 1:
                self._rescale_gradients(1 / self.optimizer_period)

            # Clipping must happen after grad accumulation
            if self.clip_grad_norm is not None:
                self._clip_gradients(self.clip_grad_norm)

            if self.amp_type == AmpType.PYTORCH:
                # If using mixed precision, handle underflow-related scaling
                # See https://pytorch.org/docs/stable/amp.html#gradient-scaling
                # for context
                self.amp_grad_scaler.step(self.optimizer, where=self.where)
                self.amp_grad_scaler.update()
            else:
                self.optimizer.step(where=self.where)

    def _rescale_gradients(self, scale):
        for param in master_params(self.optimizer):
            if param.grad is not None:
                param.grad.data.mul_(scale)

    def _clip_gradients(self, max_norm):
        nn.utils.clip_grad_norm_(master_params(self.optimizer), max_norm)

    def update_meters(self, model_output, sample):
        target = sample["target"].detach().cpu()
        model_output = model_output.detach().cpu()

        # Update meters
        for meter in self.meters:
            meter.update(model_output, target, is_train=self.train)

    def synchronize_losses(self):
        """Average the losses across the different replicas"""

        # Average losses across nodes
        losses_tensor = torch.tensor(self.losses)
        synchronized_losses_tensor = all_reduce_mean(losses_tensor)
        self.losses = synchronized_losses_tensor.tolist()

    def advance_phase(self):
        """Performs bookkeeping / task updates between phases

        Increments phase idx, resets meters, resets loss history,
        resets counters, shuffles dataset, rebuilds iterators, and
        sets the train / test state for phase.
        """
        logging.debug("Advancing phase")
        # Reset meters for next phase / epoch
        for meter in self.meters:
            meter.reset()

        # Reset loss history for next epoch
        self.losses = []

        # Setup new phase
        self.phase_idx += 1
        phase = self.phases[self.phase_idx]
        self.train = True if phase["train"] else False
        if self.train:
            self.train_phase_idx += 1

        # Re-build dataloader & re-create iterator anytime membership changes.
        self.build_dataloaders_for_current_phase()
        self.create_data_iterators()
        # Set up pytorch module in train vs eval mode, update optimizer.
        self._set_model_train_mode()

    def done_training(self):
        """Stop condition for training"""
        return self.phase_idx + 1 >= len(self.phases)

    def create_data_iterators(self):
        """Creates data iterator(s) for the current phase."""
        # Delete iterator explicitly so that all dataloader processes
        # are cleaned up.
        del self.data_iterator
        self.data_iterator = iter(self.dataloader)

    def _set_model_train_mode(self):
        """Set train mode for model"""
        phase = self.phases[self.phase_idx]
        self.base_model.train(phase["train"])
        self.base_loss.train(phase["train"])

        if (self.broadcast_buffers_mode == BroadcastBuffersMode.BEFORE_EVAL
                and not self.train):
            self._broadcast_buffers()

    def _broadcast_buffers(self):
        """Explicitly synchronize buffers across all devices."""
        if self.distributed_model is None:
            return
        buffers = list(self.base_model.buffers())
        if len(buffers) > 0:
            logging.info("Synchronizing buffers before evaluation.")
            for buffer in buffers:
                broadcast(buffer,
                          0,
                          group=self.distributed_model.process_group)

    # TODO: Functions below should be better abstracted into the dataloader
    # abstraction
    def get_batchsize_per_replica(self):
        """Return local replica's batchsize for dataset (e.g. batchsize per GPU)"""
        return self.datasets[self.phase_type].get_batchsize_per_replica()

    def get_global_batchsize(self):
        """Return global batchsize across all trainers"""
        return self.datasets[self.phase_type].get_global_batchsize()

    def on_start(self):
        for hook in self.hooks:
            hook.on_start(self)

    def on_phase_start(self):
        self.phase_start_time_total = time.perf_counter()

        self.advance_phase()

        for hook in self.hooks:
            hook.on_phase_start(self)

        self.phase_start_time_train = time.perf_counter()

    def on_phase_end(self):
        self.log_phase_end(self.phase_type)

        if self.train:
            self.optimizer.on_epoch(where=self.where)

        logging.debug("Syncing losses on phase end...")
        self.synchronize_losses()
        logging.debug("...losses synced")

        logging.debug("Syncing meters on phase end...")
        for meter in self.meters:
            meter.sync_state()
        logging.debug("...meters synced")
        barrier()

        for hook in self.hooks:
            hook.on_phase_end(self)
        self.perf_log = []

        self.log_phase_end(f"{self.phase_type}_total")

        if hasattr(self.datasets[self.phase_type], "on_phase_end"):
            self.datasets[self.phase_type].on_phase_end()

    def on_end(self):
        for hook in self.hooks:
            hook.on_end(self)

    def log_phase_end(self, tag):
        start_time = (self.phase_start_time_train if tag == self.phase_type
                      else self.phase_start_time_total)
        phase_duration = time.perf_counter() - start_time
        im_per_sec = (self.get_global_batchsize() *
                      self.num_batches_per_phase) / phase_duration
        self.perf_log.append({
            "tag": tag,
            "phase_idx": self.train_phase_idx,
            "im_per_sec": im_per_sec
        })

    def __repr__(self):
        if hasattr(self, "_config"):
            config = json.dumps(self._config, indent=4)
            return f"{super().__repr__()} initialized with config:\n{config}"

        return super().__repr__()
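
# A condensed, framework-free sketch of the gradient-accumulation schedule used
# by run_optimizer() / _should_do_step() above: zero the gradients at the start
# of each window, accumulate for optimizer_period steps, rescale the summed
# gradients by 1 / optimizer_period, optionally clip, and only then step the
# optimizer. All names below are illustrative placeholders.
import torch


def accumulation_train_step(step_idx, optimizer_period, model, optimizer,
                            loss, clip_grad_norm=None):
    # Start of a new accumulation window
    if step_idx % optimizer_period == 0:
        optimizer.zero_grad()

    # Gradients from this batch are added to the ones already accumulated
    loss.backward()

    # Last step of the window: rescale, clip and apply the update
    if step_idx % optimizer_period == optimizer_period - 1:
        for p in model.parameters():
            if p.grad is not None:
                p.grad.mul_(1.0 / optimizer_period)
        if clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        optimizer.step()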
Example #15
0
File: trainer.py Project: dmizr/phuber
class Trainer:
    """Model trainer

    Args:
        model: model to train
        loss_fn: loss function
        optimizer: model optimizer
        epochs: number of epochs
        device: device to train the model on
        train_loader: training dataloader
        val_loader: validation dataloader
        scheduler: learning rate scheduler
        update_sched_on_iter: whether to call the scheduler every iter or every epoch
        grad_clip_max_norm: gradient clipping max norm (disabled if None)
        writer: writer which logs metrics to TensorBoard (disabled if None)
        save_path: folder in which to save models (disabled if None)
        checkpoint_path: path to model checkpoint, to resume training
        mixed_precision: whether to train with automatic mixed precision (only enabled on CUDA devices)

    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss_fn: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        epochs: int,
        device: torch.device,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        update_sched_on_iter: bool = False,
        grad_clip_max_norm: Optional[float] = None,
        writer: Optional[SummaryWriter] = None,
        save_path: Optional[str] = None,
        checkpoint_path: Optional[str] = None,
        mixed_precision: bool = False,
    ) -> None:

        # Logging
        self.logger = logging.getLogger()
        self.writer = writer

        # Saving
        self.save_path = save_path

        # Device
        self.device = device

        # Data
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Model
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.update_sched_on_iter = update_sched_on_iter
        self.grad_clip_max_norm = grad_clip_max_norm
        self.epochs = epochs
        self.start_epoch = 0

        # Floating-point precision
        self.mixed_precision = mixed_precision and self.device.type == "cuda"
        self.scaler = GradScaler() if self.mixed_precision else None

        if checkpoint_path:
            self._load_from_checkpoint(checkpoint_path)

        # Metrics
        self.train_loss_metric = LossMetric()
        self.train_acc_metric = AccuracyMetric(k=1)

        self.val_loss_metric = LossMetric()
        self.val_acc_metric = AccuracyMetric(k=1)

    def train(self) -> None:
        """Trains the model"""
        self.logger.info("Beginning training")
        start_time = time.time()

        for epoch in range(self.start_epoch, self.epochs):
            start_epoch_time = time.time()
            if self.mixed_precision:
                self._train_loop_amp(epoch)
            else:
                self._train_loop(epoch)

            if self.val_loader is not None:
                self._val_loop(epoch)

            epoch_time = time.time() - start_epoch_time
            self._end_loop(epoch, epoch_time)

        train_time_h = (time.time() - start_time) / 3600
        self.logger.info(f"Finished training! Total time: {train_time_h:.2f}h")
        if self.save_path is not None:
            self._save_model(os.path.join(self.save_path, "final_model.pt"), self.epochs)

    def _train_loop(self, epoch: int) -> None:
        """
        Regular train loop

        Args:
            epoch: current epoch
        """
        # Progress bar
        pbar = tqdm.tqdm(total=len(self.train_loader), leave=False)
        pbar.set_description(f"Epoch {epoch} | Train")

        # Set to train
        self.model.train()

        # Loop
        for data, target in self.train_loader:
            # To device
            data, target = data.to(self.device), target.to(self.device)

            # Forward + backward
            self.optimizer.zero_grad()
            out = self.model(data)
            loss = self.loss_fn(out, target)
            loss.backward()

            if self.grad_clip_max_norm is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.grad_clip_max_norm
                )

            self.optimizer.step()

            # Update scheduler if it is iter-based
            if self.scheduler is not None and self.update_sched_on_iter:
                self.scheduler.step()

            # Update metrics
            self.train_loss_metric.update(loss.item(), data.shape[0])
            self.train_acc_metric.update(out, target)

            # Update progress bar
            pbar.update()
            pbar.set_postfix_str(f"Loss: {loss.item():.3f}", refresh=False)

        # Update scheduler if it is epoch-based
        if self.scheduler is not None and not self.update_sched_on_iter:
            self.scheduler.step()

        pbar.close()

    def _train_loop_amp(self, epoch: int) -> None:
        """
        Train loop with Automatic Mixed Precision

        Args:
            epoch: current epoch
        """
        # Progress bar
        pbar = tqdm.tqdm(total=len(self.train_loader), leave=False)
        pbar.set_description(f"Epoch {epoch} | Train")

        # Set to train
        self.model.train()

        # Loop
        for data, target in self.train_loader:
            # To device
            data, target = data.to(self.device), target.to(self.device)

            # Forward + backward
            self.optimizer.zero_grad()

            # Use amp in forward pass
            with autocast():
                out = self.model(data)
                loss = self.loss_fn(out, target)

            # Backward pass with scaler
            self.scaler.scale(loss).backward()

            # Unscale before gradient clipping
            self.scaler.unscale_(self.optimizer)

            if self.grad_clip_max_norm is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.grad_clip_max_norm
                )

            # Update optimizer and scaler
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Update scheduler if it is iter-based
            if self.scheduler is not None and self.update_sched_on_iter:
                self.scheduler.step()

            # Update metrics
            self.train_loss_metric.update(loss.item(), data.shape[0])
            self.train_acc_metric.update(out, target)

            # Update progress bar
            pbar.update()
            pbar.set_postfix_str(f"Loss: {loss.item():.3f}", refresh=False)

        # Update scheduler if it is epoch-based
        if self.scheduler is not None and not self.update_sched_on_iter:
            self.scheduler.step()

        pbar.close()

    def _val_loop(self, epoch: int) -> None:
        """
        Standard validation loop

        Args:
            epoch: current epoch
        """
        # Progress bar
        pbar = tqdm.tqdm(total=len(self.val_loader), leave=False)
        pbar.set_description(f"Epoch {epoch} | Validation")

        # Set to eval
        self.model.eval()

        # Loop
        for data, target in self.val_loader:
            with torch.no_grad():
                # To device
                data, target = data.to(self.device), target.to(self.device)

                # Forward
                out = self.model(data)
                loss = self.loss_fn(out, target)

                # Update metrics
                self.val_loss_metric.update(loss.item(), data.shape[0])
                self.val_acc_metric.update(out, target)

                # Update progress bar
                pbar.update()
                pbar.set_postfix_str(f"Loss: {loss.item():.3f}", refresh=False)

        pbar.close()

    def _end_loop(self, epoch: int, epoch_time: float):
        # Print epoch results
        self.logger.info(self._epoch_str(epoch, epoch_time))

        # Write to tensorboard
        if self.writer is not None:
            self._write_to_tb(epoch)

        # Save model
        if self.save_path is not None:
            self._save_model(os.path.join(self.save_path, "most_recent.pt"), epoch)

        # Clear metrics
        self.train_loss_metric.reset()
        self.train_acc_metric.reset()
        if self.val_loader is not None:
            self.val_loss_metric.reset()
            self.val_acc_metric.reset()

    def _epoch_str(self, epoch: int, epoch_time: float):
        s = f"Epoch {epoch} "
        s += f"| Train loss: {self.train_loss_metric.compute():.3f} "
        s += f"| Train acc: {self.train_acc_metric.compute():.3f} "
        if self.val_loader is not None:
            s += f"| Val loss: {self.val_loss_metric.compute():.3f} "
            s += f"| Val acc: {self.val_acc_metric.compute():.3f} "
        s += f"| Epoch time: {epoch_time:.1f}s"

        return s

    def _write_to_tb(self, epoch):
        self.writer.add_scalar("Loss/train", self.train_loss_metric.compute(), epoch)
        self.writer.add_scalar("Accuracy/train", self.train_acc_metric.compute(), epoch)

        if self.val_loader is not None:
            self.writer.add_scalar("Loss/val", self.val_loss_metric.compute(), epoch)
            self.writer.add_scalar("Accuracy/val", self.val_acc_metric.compute(), epoch)

    def _save_model(self, path, epoch):
        obj = {
            "epoch": epoch + 1,
            "optimizer": self.optimizer.state_dict(),
            "model": self.model.state_dict(),
            "scheduler": self.scheduler.state_dict()
            if self.scheduler is not None
            else None,
            "scaler": self.scaler.state_dict() if self.mixed_precision else None,
        }
        torch.save(obj, path)  # callers already join save_path into the path

    def _load_from_checkpoint(self, checkpoint_path: str) -> None:
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])

        self.start_epoch = checkpoint["epoch"]

        if self.scheduler:
            self.scheduler.load_state_dict(checkpoint["scheduler"])

        if self.mixed_precision and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scheduler"])

        if self.start_epoch > self.epochs:
            raise ValueError("Starting epoch is larger than total epochs")

        self.logger.info(f"Checkpoint loaded, resuming from epoch {self.start_epoch}")
Example #16
0
    optimizer = optim.Adam([v for v in model.parameters() if v.requires_grad],
                           lr=args.lr,
                           betas=(.5, .9),
                           eps=1e-6)
    scaler = GradScaler()

    # reload checkpoint parameters
    epoch = 0
    num_samples_treated = 0
    num_batches_treated = 0
    if args.base_model is not None:
        if os.path.isfile(args.base_model):
            checkpoint = torch.load(args.base_model)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scaler.load_state_dict(checkpoint["scaler"])
            epoch = checkpoint['epoch'] + 1  # we start a new epoch
            num_samples_treated = checkpoint['num_samples_treated']
            num_batches_treated = checkpoint['num_batches_treated']
        else:  # tf model
            # Imported here so that installing TensorFlow is not mandatory
            from load_tf_models import load_ssrn_from_tf, load_t2m_from_tf
            if args.net == "Text2Mel":
                load_t2m_from_tf(model, args.base_model)
            else:
                load_ssrn_from_tf(model, args.base_model)

    # Use 1e10 as the cap when we want to loop "indefinitely"
    max_num_samples_to_train_on = (num_samples_treated + args.max_num_samples_to_train_on
                                   if args.max_num_samples_to_train_on is not None else 1e10)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
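
For completeness, a sketch of the matching checkpoint save that the reload block above expects (the key names are taken from the loading code; the filename is illustrative, not from the original script):

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scaler': scaler.state_dict(),
    'epoch': epoch,
    'num_samples_treated': num_samples_treated,
    'num_batches_treated': num_batches_treated,
}, 'checkpoint_latest.pt')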
Example #17
0
class GenericTrainingManager:

    def __init__(self, params):
        self.type = None
        self.is_master = False
        self.params = params
        self.models = {}
        self.begin_time = None
        self.dataset = None
        self.paths = None
        self.latest_epoch = -1
        self.latest_batch = 0
        self.total_batch = 0
        self.latest_train_metrics = dict()
        self.latest_valid_metrics = dict()
        self.phase = None
        self.max_mem_usage_by_epoch = list()

        self.scaler = None
        self.optimizer = None
        self.lr_scheduler = None
        self.best = None
        self.writer = None

        reset_optimizer = "reset_optimizer" in self.params["training_params"] and self.params["training_params"]["reset_optimizer"]

        self.init_hardware_config()
        self.init_paths()
        self.load_dataset()
        self.load_model(reset_optimizer)

    def init_paths(self):
        ## Create output folders
        output_path = os.path.join("outputs", self.params["training_params"]["output_folder"])
        os.makedirs(output_path, exist_ok=True)
        checkpoints_path = os.path.join(output_path, "checkpoints")
        os.makedirs(checkpoints_path, exist_ok=True)
        results_path = os.path.join(output_path, "results")
        os.makedirs(results_path, exist_ok=True)

        self.paths = {
            "results": results_path,
            "checkpoints": checkpoints_path,
            "output_folder": output_path
        }

    def load_dataset(self):
        self.params["dataset_params"]["use_ddp"] = self.params["training_params"]["use_ddp"]
        self.params["dataset_params"]["batch_size"] = self.params["training_params"]["batch_size"]
        self.params["dataset_params"]["num_gpu"] = self.params["training_params"]["nb_gpu"]
        self.dataset = DatasetManager(self.params["dataset_params"])
        if self.dataset.charset:
            self.params["model_params"]["vocab_size"] = len(self.dataset.charset)

    def init_hardware_config(self):
        # Debug mode
        if self.params["training_params"]["force_cpu"]:
            self.params["training_params"]["use_ddp"] = False
            self.params["training_params"]["use_amp"] = False
        # Manage Distributed Data Parallel & GPU usage
        self.manual_seed = 1111 if "manual_seed" not in self.params["training_params"].keys() else \
        self.params["training_params"]["manual_seed"]
        self.ddp_config = {
            "master": self.params["training_params"]["use_ddp"] and self.params["training_params"]["ddp_rank"] == 0,
            "address": "localhost" if "ddp_addr" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_addr"],
            "port": "11111" if "ddp_port" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_port"],
            "backend": "nccl" if "ddp_backend" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_backend"],
            "rank": self.params["training_params"]["ddp_rank"],
        }
        self.is_master = self.ddp_config["master"] or not self.params["training_params"]["use_ddp"]
        if self.params["training_params"]["force_cpu"]:
            self.device = "cpu"
        else:
            if self.params["training_params"]["use_ddp"]:
                self.device = torch.device(self.ddp_config["rank"])
                self.params["dataset_params"]["ddp_rank"] = self.ddp_config["rank"]
                self.launch_ddp()
            else:
                self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Print GPU info
        # global
        if (self.params["training_params"]["use_ddp"] and self.ddp_config["master"]) or not self.params["training_params"]["use_ddp"]:
            print("##################")
            print("Available GPUS: {}".format(self.params["training_params"]["nb_gpu"]))
            for i in range(self.params["training_params"]["nb_gpu"]):
                print("Rank {}: {} {}".format(i, torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)))
            print("##################")
        # local
        print("Local GPU:")
        if self.device != "cpu":
            print("Rank {}: {} {}".format(self.params["training_params"]["ddp_rank"], torch.cuda.get_device_name(), torch.cuda.get_device_properties(self.device)))
        else:
            print("WORKING ON CPU !\n")
        print("##################")

    def load_model(self, reset_optimizer=False):
        self.params["model_params"]["use_amp"] = self.params["training_params"]["use_amp"]
        # Instantiate models
        for model_name in self.params["model_params"]["models"].keys():
            self.models[model_name] = self.params["model_params"]["models"][model_name](self.params["model_params"])
            self.models[model_name].to(self.device)  # To GPU or CPU

        # Instantiate optimizer
        self.reset_optimizer()
        if "lr_scheduler" in self.params["training_params"] and self.params["training_params"]["lr_scheduler"]:
            self.lr_scheduler = self.params["training_params"]["lr_scheduler"]["type"](self.optimizer, gamma=self.params["training_params"]["lr_scheduler"]["gamma"])

        self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"])

        # Load previous weights
        checkpoint = None
        if self.params["training_params"]["load_epoch"] in ("best", "last"):
            for filename in os.listdir(self.paths["checkpoints"]):
                # Continue training
                if self.params["training_params"]["load_epoch"] in filename:
                    checkpoint_path = os.path.join(self.paths["checkpoints"], filename)
                    checkpoint = torch.load(checkpoint_path)
                    self.load_save_info(checkpoint)
                    self.latest_epoch = checkpoint["epoch"]
                    self.best = checkpoint["best"]
                    self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
                    # Make model compatible with Distributed Data Parallel if used
                    if self.params["training_params"]["use_ddp"]:
                        for model_name in self.models.keys():
                            self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]])
                    # Load model weights from past training
                    for model_name in self.models.keys():
                        self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(model_name)])
                    # Load optimizer state from past training
                    if not reset_optimizer:
                        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                    # Load optimizer scheduler config from past training if used
                    if "lr_scheduler" in self.params["training_params"] and self.params["training_params"]["lr_scheduler"] and "lr_scheduler_state_dict" in checkpoint.keys():
                        self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
                    break

        # Print the number of epochs already trained with this model
        if self.is_master:
            print("LOADED EPOCH: {}\n".format(self.latest_epoch), flush=True)

        # New training
        if not checkpoint:
            # Weights initialization
            for model_name in self.models.keys():
                self.models[model_name].apply(self.weights_init)
            # Handle transfer learning instructions
            if self.params["model_params"]["transfer_learning"]:
                # Iterates over models
                for model_name in self.params["model_params"]["transfer_learning"].keys():
                    state_dict_name, path, learnable, strict = self.params["model_params"]["transfer_learning"][model_name]
                    # Loading pretrained weights file
                    checkpoint = torch.load(path)
                    try:
                        # Load pretrained weights for model
                        self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(state_dict_name)], strict=strict)
                        print("transfered weights for {}".format(state_dict_name), flush=True)
                    except RuntimeError as e:
                        print(e, flush=True)
                        # If loading fails, try to load each part of the model separately (useful when only a few layers differ)
                        for key in checkpoint["{}_state_dict".format(state_dict_name)].keys():
                            try:
                                self.models[model_name].load_state_dict({key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False)
                            except RuntimeError as e:
                                print(e, flush=True)
                    # Freeze parameters if they should not be trainable
                    if not learnable:
                        self.set_model_learnable(self.models[model_name], False)

            # make the model compatible with Distributed Data Parallel if used
            if self.params["training_params"]["use_ddp"]:
                for model_name in self.models.keys():
                    self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]])
            return

    @staticmethod
    def set_model_learnable(model, learnable=True):
        for p in list(model.parameters()):
            p.requires_grad = learnable

    def save_model(self, epoch, name, keep_weights=False):
        """
        Save model weights
        """
        if not self.is_master:
            return
        to_del = []
        for filename in os.listdir(self.paths["checkpoints"]):
            if name in filename:
                to_del.append(os.path.join(self.paths["checkpoints"], filename))
        path = os.path.join(self.paths["checkpoints"], "{}_{}.pt".format(name, epoch))
        content = {
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch,
            "scaler_state_dict": self.scaler.state_dict(),
            'best': self.best,
        }
        if self.lr_scheduler:
            content["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict()
        content = self.add_save_info(content)
        for model_name in self.models.keys():
            content["{}_state_dict".format(model_name)] = self.models[model_name].state_dict()
        torch.save(content, path)
        if not keep_weights:
            for path_to_del in to_del:
                if path_to_del != path:
                    os.remove(path_to_del)

    def reset_optimizer(self):
        """
        Re-create the optimizer, resetting its internal state and learning rate
        """
        parameters = list()
        for model_name in self.models.keys():
            parameters += list(self.models[model_name].parameters())
        self.optimizer = self.params["training_params"]["optimizer"]["class"]\
            (parameters, **self.params["training_params"]["optimizer"]["args"])


    @staticmethod
    def weights_init(m):
        """
        Weights initialization for model training from scratch
        """
        if isinstance(m, Conv2d) or isinstance(m, Linear):
            if m.weight is not None:
                kaiming_uniform_(m.weight, nonlinearity="relu")
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, InstanceNorm2d):
            if m.weight is not None:
                ones_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def save_params(self):
        """
        Output text file containing a summary of all hyperparameters chosen for the training
        """
        def compute_nb_params(module):
            return sum([np.prod(p.size()) for p in list(module.parameters())])

        def class_to_str_dict(my_dict):
            for key in my_dict.keys():
                if callable(my_dict[key]):
                    my_dict[key] = my_dict[key].__name__
                elif isinstance(my_dict[key], np.ndarray):
                    my_dict[key] = my_dict[key].tolist()
                elif isinstance(my_dict[key], dict):
                    my_dict[key] = class_to_str_dict(my_dict[key])
            return my_dict

        path = os.path.join(self.paths["results"], "params")
        if os.path.isfile(path):
            return
        params = copy.deepcopy(self.params)
        params = class_to_str_dict(params)
        total_params = 0
        for model_name in self.models.keys():
            current_params = compute_nb_params(self.models[model_name])
            params["model_params"]["models"][model_name] = [params["model_params"]["models"][model_name], "{:,}".format(current_params)]
            total_params += current_params
        params["model_params"]["total_params"] = "{:,}".format(total_params)

        params["hardware"] = dict()
        if self.device != "cpu":
            for i in range(self.params["training_params"]["nb_gpu"]):
                params["hardware"][str(i)] = "{} {}".format(torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i))
        else:
            params["hardware"]["0"] = "CPU"
        with open(path, 'w') as f:
            json.dump(params, f, indent=4)

    def update_memory_consumption(self):
        self.max_mem_usage_by_epoch.append(torch.cuda.max_memory_allocated())
        torch.cuda.reset_max_memory_allocated()
        with open(os.path.join(self.paths["results"], "memory.txt"), 'a') as f:
            current = round(self.max_mem_usage_by_epoch[-1]/1e9, 2)
            max_usage = round(np.max(self.max_mem_usage_by_epoch)/1e9, 2)
            min_usage = round(np.min(self.max_mem_usage_by_epoch)/1e9, 2)
            median_usage = round(np.median(self.max_mem_usage_by_epoch)/1e9, 2)
            mean_usage = round(np.mean(self.max_mem_usage_by_epoch)/1e9, 2)
            f.write("E{} - Current: {} GB - Max: {} GB - Min: {} GB - Mean: {} GB - Median: {} GB\n".format(
                self.latest_epoch, current, max_usage, min_usage, mean_usage, median_usage))

    @staticmethod
    def init_metrics(metrics_name):
        """
        Initialization of the metrics specified in metrics_name
        """
        metrics = {
            "nb_samples": 0,
            "weights": 0,
            "names": list(),
            "ids": list(),
        }
        for metric_name in metrics_name:
            if metric_name == "cer":
                metrics["nb_chars"] = 0
                metrics[metric_name] = list()
                continue
            elif metric_name == "wer":
                metrics["nb_words"] = 0
            elif metric_name in ["pred", "proba", "cer_force_len"]:
                metrics[metric_name] = list()
                continue
            elif metric_name == "diff_len":
                metrics[metric_name] = None
                continue
            metrics[metric_name] = 0
        return metrics

    @staticmethod
    def update_metrics(metrics, batch_metrics):
        """
        Add batch metrics to the metrics
        """
        for key in batch_metrics.keys():
            if key in ["diff_len", ]:
                if metrics[key] is None:
                    metrics[key] = batch_metrics[key]
                else:
                    metrics[key] = np.concatenate([metrics[key], batch_metrics[key]], axis=0)
            elif key in ["pred", ]:
                if len(metrics[key]) == 0:
                    metrics[key] = batch_metrics[key]
                else:
                    for i in range(len(metrics[key])):
                        metrics[key][i] += batch_metrics[key][i]
            else:
                metrics[key] += batch_metrics[key]
        return metrics

    def get_display_values(self, metrics, metrics_name, num_batch):
        """
        Format metric values for shell display purposes
        """
        display_values = {}
        for metric_name in metrics_name:
            if metric_name in ["cer", "cer_force_len", ]:
                edit = np.sum(metrics[metric_name])
                display_values[metric_name] = round(edit / metrics["nb_chars"], 4)
            elif metric_name == "wer":
                display_values[metric_name] = round(metrics[metric_name] / metrics["nb_words"], 4)
            elif metric_name in ["f_measure", "precision", "recall", "IoU", "mAP", "pp_f_measure", "pp_precision", "pp_recall", "pp_IoU", "pp_mAP"]:
                display_values[metric_name] = round(metrics[metric_name] / metrics["weights"], 4)
            elif metric_name in ["diff_len", ]:
                display_values[metric_name] = np.round(np.mean(np.abs(metrics[metric_name])), 3)
            elif metric_name in ["time", "pred", "probas", "nb_max_len", "worst_cer", ]:
                continue
            elif metric_name in ["loss", "loss_ctc", "loss_ce", "loss_ce_end", "loss_mse"]:
                display_values[metric_name] = round(metrics[metric_name] / self.latest_batch, 4)
            else:
                display_values[metric_name] = round(metrics[metric_name] / metrics["nb_samples"], 4)
        return display_values

    def backward_loss(self, loss, retain_graph=False):
        self.scaler.scale(loss).backward(retain_graph=retain_graph)

    def step_optimizer(self):
        self.scaler.step(self.optimizer)
        self.scaler.update()

    def train(self):
        # init tensorboard file and output param summary file
        if self.is_master:
            self.writer = SummaryWriter(self.paths["results"])
            self.save_params()
        # init variables
        self.begin_time = time()
        focus_metric_name = self.params["training_params"]["focus_metric"]
        nb_epochs = self.params["training_params"]["max_nb_epochs"]
        interval_save_weights = self.params["training_params"]["interval_save_weights"]
        metrics_name = self.params["training_params"]["train_metrics"]
        display_values = None
        # perform epochs
        for num_epoch in range(self.latest_epoch+1, nb_epochs):
            self.phase = "train"
            # Check maximum training time stop condition
            if self.params["training_params"]["max_training_time"] and time() - self.begin_time > self.params["training_params"]["max_training_time"]:
                break
            # set models trainable
            for model_name in self.models.keys():
                self.models[model_name].train()
            self.latest_epoch = num_epoch
            # init epoch metrics values
            metrics = self.init_metrics(metrics_name)
            t = tqdm(self.dataset.train_loader)
            t.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
            # iterates over mini-batch data
            for ind_batch, batch_data in enumerate(t):
                self.latest_batch = ind_batch + 1
                self.total_batch += 1
                # train on batch data and compute metrics
                batch_metrics = self.train_batch(batch_data, metrics_name)
                batch_metrics["names"] = batch_data["names"]
                batch_metrics["ids"] = batch_data["ids"]
                # Merge metrics if Distributed Data Parallel is used
                if self.params["training_params"]["use_ddp"]:
                    batch_metrics = self.merge_ddp_metrics(batch_metrics)
                # Update learning rate via scheduler if one is used
                if self.lr_scheduler and ind_batch % self.params["training_params"]["lr_scheduler"]["step_interval"] == 0:
                    self.lr_scheduler.step()
                # Add batch metrics values to epoch metrics values
                metrics = self.update_metrics(metrics, batch_metrics)
                display_values = self.get_display_values(metrics, metrics_name, ind_batch)
                t.set_postfix(values=str(display_values))
            # log metrics in tensorboard file
            if self.is_master:
                for key in display_values.keys():
                    self.writer.add_scalar('{}_{}'.format(self.params["dataset_params"]["train"]["name"], key), display_values[key], num_epoch)
            self.latest_train_metrics = display_values

            # evaluate and compute metrics for valid sets
            if self.params["training_params"]["eval_on_valid"] and num_epoch % self.params["training_params"]["eval_on_valid_interval"] == 0:
                for valid_set_name in self.dataset.valid_loaders.keys():
                    # evaluate set and compute metrics
                    eval_values = self.evaluate(valid_set_name)
                    self.latest_valid_metrics = eval_values
                    # log valid metrics in tensorboard file
                    if self.is_master:
                        for key in eval_values.keys():
                            self.writer.add_scalar('{}_{}'.format(valid_set_name, key), eval_values[key], num_epoch)
                        if valid_set_name == self.params["training_params"]["set_name_focus_metric"] and (self.best is None or \
                                (eval_values[focus_metric_name] < self.best and self.params["training_params"]["expected_metric_value"] == "low") or\
                                (eval_values[focus_metric_name] > self.best and self.params["training_params"]["expected_metric_value"] == "high")):
                            self.save_model(epoch=num_epoch, name="best")
                            self.best = eval_values[focus_metric_name]

            ## save model weights
            if self.is_master:
                self.save_model(epoch=num_epoch, name="last")
                self.update_memory_consumption()
                if interval_save_weights and num_epoch % interval_save_weights == 0:
                    self.save_model(epoch=num_epoch, name="weigths", keep_weights=True)
                self.writer.flush()

    def evaluate(self, set_name, **kwargs):
        self.phase = "eval"
        loader = self.dataset.valid_loaders[set_name]
        # Set models in eval mode
        for model_name in self.models.keys():
            self.models[model_name].eval()
        metrics_name = self.params["training_params"]["eval_metrics"]
        display_values = None
        # initialize epoch metrics
        metrics = self.init_metrics(metrics_name)
        t = tqdm(loader)
        t.set_description("Evaluation E{}".format(self.latest_epoch))
        with torch.no_grad():
            # iterate over batch data
            for ind_batch, batch_data in enumerate(t):
                self.latest_batch = ind_batch + 1
                # eval batch data and compute metrics
                batch_metrics = self.evaluate_batch(batch_data, metrics_name)
                batch_metrics["names"] = batch_data["names"]
                batch_metrics["ids"] = batch_data["ids"]
                # merge metrics values if Distributed Data Parallel is used
                if self.params["training_params"]["use_ddp"]:
                    batch_metrics = self.merge_ddp_metrics(batch_metrics)
                # add batch metrics to epoch metrics
                metrics = self.update_metrics(metrics, batch_metrics)
                display_values = self.get_display_values(metrics, metrics_name, ind_batch)
                t.set_postfix(values=str(display_values))
        return display_values

    def predict(self, custom_name, sets_list, metrics_name, output=False):
        self.phase = "predict"
        metrics_name = metrics_name.copy()
        self.dataset.generate_test_loader(custom_name, sets_list)
        loader = self.dataset.test_loaders[custom_name]
        # Set models in eval mode
        for model_name in self.models.keys():
            self.models[model_name].eval()
        pred_time_metric = False
        if "time" in metrics_name:
            metrics_name.remove("time")
            pred_time_metric = True
        # initialize epoch metrics
        metrics = self.init_metrics(metrics_name)
        t = tqdm(loader)
        t.set_description("Prediction")
        begin_time = time()
        with torch.no_grad():
            for ind_batch, batch_data in enumerate(t):
                # iterates over batch data
                self.latest_batch = ind_batch + 1
                # eval batch data and compute metrics
                batch_metrics = self.evaluate_batch(batch_data, metrics_name)
                batch_metrics["names"] = batch_data["names"]
                batch_metrics["ids"] = batch_data["ids"]
                # merge batch metrics if Distributed Data Parallel is used
                if self.params["training_params"]["use_ddp"]:
                    batch_metrics = self.merge_ddp_metrics(batch_metrics)
                # add batch metrics to epoch metrics
                metrics = self.update_metrics(metrics, batch_metrics)
                display_values = self.get_display_values(metrics, metrics_name, ind_batch)
                t.set_postfix(values=str(display_values))
        pred_time = time() - begin_time
        # add time metric values if requested
        if pred_time_metric:
            metrics["total_time"] = np.round(pred_time, 3)
            metrics["sample_time"] = np.round(pred_time / len(self.dataset.test_datasets[custom_name]), 4)
        # output metrics values if requested
        if output:
            for name in ["probas", ]:
                if name in metrics.keys():
                    path = os.path.join(self.paths["results"], "{}_{}_{}.txt".format(name, custom_name, self.latest_epoch))
                    info = "\n".join(metrics[name])
                    with open(path, "w") as f:
                        f.write(info)
                    del metrics[name]
            self.output(metrics, custom_name)

    def launch_ddp(self):
        """
        Initialize Distributed Data Parallel system
        """
        mp.set_start_method('fork', force=True)
        os.environ['MASTER_ADDR'] = self.ddp_config["address"]
        os.environ['MASTER_PORT'] = str(self.ddp_config["port"])
        dist.init_process_group(self.ddp_config["backend"], rank=self.ddp_config["rank"], world_size=self.params["training_params"]["nb_gpu"])
        torch.cuda.set_device(self.ddp_config["rank"])
        random.seed(self.manual_seed)
        np.random.seed(self.manual_seed)
        torch.manual_seed(self.manual_seed)
        torch.cuda.manual_seed(self.manual_seed)

    def merge_ddp_metrics(self, metrics):
        """
        Merge metrics when Distributed Data Parallel is used
        """
        for metric_name in metrics.keys():
            if metric_name in ["wer", "wer_force_len", "nb_samples", "nb_words", "nb_chars", "nb_max_len",
                               "f_measure", "precision", "recall", "IoU", "mAP", "pp_f_measure", "pp_precision", "pp_recall", "pp_IoU", "pp_mAP"]:
                metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name])
            elif metric_name in ["loss", "loss_ce", "loss_ctc", "loss_ce_end"]:
                metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name], average=True)
            elif metric_name in ["diff_len", "cer", "cer_force_len", "ids"]:
                metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name])
        return metrics

    def sum_ddp_metric(self, metric, average=False):
        """
        Sum metrics for Distributed Data Parallel
        """
        total = torch.tensor(metric).to(self.device)
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        if average:
            # true_divide is not in-place; keep the returned tensor
            total = total.true_divide(dist.get_world_size())
        return total.item()

    def cat_ddp_metric(self, metric):
        """
        Concatenate metrics for Distributed Data Parallel
        """
        tensor = torch.tensor(metric).unsqueeze(0).to(self.device)
        res = [torch.zeros(tensor.size()).long().to(self.device) for _ in range(dist.get_world_size())]
        dist.all_gather(res, tensor)
        return list(torch.cat(res, dim=0).flatten().cpu().numpy())

    @staticmethod
    def cleanup():
        dist.destroy_process_group()

    def train_batch(self, batch_data, metric_names):
        raise NotImplementedError

    def evaluate_batch(self, batch_data, metric_names):
        raise NotImplementedError

    def output_pred(self, pred, set_name):
        raise NotImplementedError

    def add_checkpoint_info(self, load_mode="last", **kwargs):
        for filename in os.listdir(self.paths["checkpoints"]):
            if load_mode in filename:
                checkpoint_path = os.path.join(self.paths["checkpoints"], filename)
                checkpoint = torch.load(checkpoint_path)
                for key in kwargs.keys():
                    checkpoint[key] = kwargs[key]
                torch.save(checkpoint, checkpoint_path)
                # Stop once the matching checkpoint has been updated
                return
        # No matching checkpoint was found: save one
        self.save_model(self.latest_epoch, "last")

    def output(self, metrics, set_name):
        """
        Output metrics in text file
        """
        path = os.path.join(self.paths["results"], "predict_{}_{}.txt".format(set_name, self.latest_epoch))
        with open(path, "w") as f:
            for metric_name in metrics.keys():
                if metric_name in ["cer", "cer_force_len"]:
                    edit = np.sum(metrics[metric_name])
                    value = round(edit / metrics["nb_chars"], 4)
                elif metric_name in ["wer", ]:
                    value = round(metrics[metric_name] / metrics["nb_words"], 4)
                elif metric_name in ["loss_ce", ]:
                    value = round(metrics[metric_name] / metrics["nb_samples"], 4)
                elif metric_name in ["total_time", "sample_time", "total_output_time", "sample_output_time"]:
                    value = metrics[metric_name]
                elif metric_name in ["nb_samples", "nb_words", "nb_chars", "nb_max_len"]:
                    value = metrics[metric_name]
                elif metric_name in ["diff_len", ]:
                    f.write("{}: {}\n".format(metric_name, sorted(list(metrics[metric_name]))))
                    f.write("{}-mean_abs: {}\n".format(metric_name, np.mean(np.abs(metrics[metric_name]))))
                    continue
                elif metric_name in ["worst_cer", ]:
                    m = metric_name.split("_")[-1]
                    value = [[c, id] for c, id in zip(metrics[m], metrics["ids"])]
                    value = sorted(value, key=lambda x: x[0], reverse=True)
                    value = value[:50]
                else:
                    continue
                f.write("{}: {}\n".format(metric_name, value))

    def load_save_info(self, info_dict):
        """
        Load curriculum info from saved model info
        """
        if "curriculum_config" in info_dict.keys():
            self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"]

    def add_save_info(self, info_dict):
        """
        Add curriculum info to model info to be saved
        """
        info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config
        return info_dict
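
As a usage illustration, here is a minimal hypothetical subclass showing how `train_batch` can plug into the `backward_loss`/`step_optimizer` helpers above. The batch keys, model name, and loss below are placeholders, not the real project's implementation, and it assumes `import torch` and `from torch.cuda.amp import autocast` at module level:

class ToyTrainingManager(GenericTrainingManager):

    def train_batch(self, batch_data, metric_names):
        x = batch_data["imgs"].to(self.device)       # placeholder batch key
        y = batch_data["labels"].to(self.device)     # placeholder batch key
        self.optimizer.zero_grad()
        # Forward pass under autocast, respecting the use_amp flag
        with autocast(enabled=self.params["training_params"]["use_amp"]):
            pred = self.models["classifier"](x)      # "classifier" is illustrative
            loss = torch.nn.functional.cross_entropy(pred, y)
        self.backward_loss(loss)    # scaler.scale(loss).backward()
        self.step_optimizer()       # scaler.step(optimizer); scaler.update()
        return {"loss": loss.item(), "nb_samples": y.size(0)}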
Example #18
0
def training(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(
        logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Data open
    write_log(logger, "Load data...")
    gc.disable()
    with open(os.path.join(args.preprocess_path, 'processed.pkl'), 'rb') as f:
        data_ = pickle.load(f)
        train_src_indices = data_['train_src_indices']
        valid_src_indices = data_['valid_src_indices']
        train_trg_indices = data_['train_trg_indices']
        valid_trg_indices = data_['valid_trg_indices']
        src_word2id = data_['src_word2id']
        trg_word2id = data_['trg_word2id']
        src_vocab_num = len(src_word2id)
        trg_vocab_num = len(trg_word2id)
        del data_
    gc.enable()
    write_log(logger, "Finished loading data!")

    # 2) Dataloader setting
    dataset_dict = {
        'train':
        CustomDataset(train_src_indices,
                      train_trg_indices,
                      min_len=args.min_len,
                      src_max_len=args.src_max_len,
                      trg_max_len=args.trg_max_len),
        'valid':
        CustomDataset(valid_src_indices,
                      valid_trg_indices,
                      min_len=args.min_len,
                      src_max_len=args.src_max_len,
                      trg_max_len=args.trg_max_len),
    }
    dataloader_dict = {
        'train':
        DataLoader(dataset_dict['train'],
                   drop_last=True,
                   batch_size=args.batch_size,
                   shuffle=True,
                   pin_memory=True,
                   num_workers=args.num_workers),
        'valid':
        DataLoader(dataset_dict['valid'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers)
    }
    write_log(
        logger,
        f"Total number of trainingsets  iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Train setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, 'Instantiating model...')
    model = Transformer(
        src_vocab_num=src_vocab_num,
        trg_vocab_num=trg_vocab_num,
        pad_idx=args.pad_id,
        bos_idx=args.bos_id,
        eos_idx=args.eos_id,
        d_model=args.d_model,
        d_embedding=args.d_embedding,
        n_head=args.n_head,
        dim_feedforward=args.dim_feedforward,
        num_common_layer=args.num_common_layer,
        num_encoder_layer=args.num_encoder_layer,
        num_decoder_layer=args.num_decoder_layer,
        src_max_len=args.src_max_len,
        trg_max_len=args.trg_max_len,
        dropout=args.dropout,
        embedding_dropout=args.embedding_dropout,
        trg_emb_prj_weight_sharing=args.trg_emb_prj_weight_sharing,
        emb_src_trg_weight_sharing=args.emb_src_trg_weight_sharing,
        parallel=args.parallel)
    model.train()
    model = model.to(device)
    tgt_mask = model.generate_square_subsequent_mask(args.trg_max_len - 1,
                                                     device)

    # 2) Optimizer & Learning rate scheduler setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    scaler = GradScaler()

    # 3) Model resume
    start_epoch = 0
    if args.resume:
        write_log(logger, 'Resume model...')
        checkpoint = torch.load(
            os.path.join(args.save_path, 'checkpoint.pth.tar'))
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_acc = 0

    write_log(logger, 'Training start!')

    for epoch in range(start_epoch + 1, args.num_epochs + 1):
        start_time_e = time()
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            if phase == 'valid':
                write_log(logger, 'Validation start...')
                val_loss = 0
                val_acc = 0
                model.eval()
            for i, (src, trg) in enumerate(
                    tqdm(dataloader_dict[phase],
                         bar_format='{l_bar}{bar:30}{r_bar}{bar:-2b}')):

                # Optimizer setting
                optimizer.zero_grad(set_to_none=True)

                # Input, output setting
                src = src.to(device, non_blocking=True)
                trg = trg.to(device, non_blocking=True)

                trg_sequences_target = trg[:, 1:]
                non_pad = trg_sequences_target != args.pad_id
                trg_sequences_target = trg_sequences_target[
                    non_pad].contiguous().view(-1)

                # Train
                if phase == 'train':

                    # Loss calculate
                    with autocast():
                        predicted = model(src,
                                          trg[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(predicted,
                                                    trg_sequences_target,
                                                    args.pad_id)

                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                    # Print loss value only training
                    if i == 0 or freq == args.print_freq or i + 1 == len(
                            dataloader_dict['train']):
                        acc = (predicted.max(dim=1)[1] == trg_sequences_target
                               ).sum() / len(trg_sequences_target)
                        iter_log = "[Epoch:%03d][%03d/%03d] train_loss:%03.3f | train_acc:%03.2f%% | learning_rate:%1.6f | spend_time:%02.2fmin" % \
                            (epoch, i, len(dataloader_dict['train']),
                            loss.item(), acc*100, optimizer.param_groups[0]['lr'],
                            (time() - start_time_e) / 60)
                        write_log(logger, iter_log)
                        freq = 0
                    freq += 1

                # Validation
                if phase == 'valid':
                    with torch.no_grad():
                        predicted = model(src,
                                          trg[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        loss = F.cross_entropy(predicted, trg_sequences_target)
                    val_loss += loss.item()
                    val_acc += (predicted.max(dim=1)[1] == trg_sequences_target
                                ).sum() / len(trg_sequences_target)
                    if args.scheduler == 'reduce_valid':
                        scheduler.step(val_loss)
                    if args.scheduler == 'lambda':
                        scheduler.step()

            if phase == 'valid':
                val_loss /= len(dataloader_dict[phase])
                val_acc /= len(dataloader_dict[phase])
                write_log(logger, 'Validation Loss: %3.3f' % val_loss)
                write_log(logger,
                          'Validation Accuracy: %3.2f%%' % (val_acc * 100))
                if val_acc > best_val_acc:
                    write_log(logger, 'Checkpoint saving...')
                    torch.save(
                        {
                            'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'scaler': scaler.state_dict()
                        }, f'checkpoint_{args.parallel}.pth.tar')
                    best_val_acc = val_acc
                    best_epoch = epoch
                else:
                    else_log = f'Still {best_epoch} epoch accuracy({round(best_val_acc.item()*100, 2)})% is better...'
                    write_log(logger, else_log)

    # 3) Print results
    print(f'Best Epoch: {best_epoch}')
    print(f'Best Accuracy: {round(best_val_acc.item(), 2)}')
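
The unscale-then-clip ordering used in the training loop above is the part that is easy to get wrong, so here is a self-contained sketch of just that pattern (a toy model and data, requiring a CUDA device; it is not taken from the script above):

import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

x, y = torch.randn(4, 8, device="cuda"), torch.randn(4, 1, device="cuda")
with autocast():
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)       # bring gradients back to their real scale
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip true gradients
scaler.step(optimizer)           # step is skipped if gradients contain inf/NaN
scaler.update()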
Example #19
0
File: optimizer.py Project: grimoire/mmcv
    class Fp16OptimizerHook(OptimizerHook):
        """FP16 optimizer hook (using PyTorch's implementation).

        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
        to take care of the optimization procedure.

        Args:
            loss_scale (float | str | dict): Scale factor configuration.
                If loss_scale is a float, static loss scaling will be used with
                the specified scale. If loss_scale is a string, it must be
                'dynamic', then dynamic loss scaling will be used.
                It can also be a dict containing arguments of GradScaler.
                Defaults to 512. For PyTorch >= 1.6, mmcv uses the official
                implementation of GradScaler. If you use a dict version of
                loss_scale to create GradScaler, please refer to:
                https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
                for the parameters.

        Examples:
            >>> loss_scale = dict(
            ...     init_scale=65536.0,
            ...     growth_factor=2.0,
            ...     backoff_factor=0.5,
            ...     growth_interval=2000
            ... )
            >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
        """
        def __init__(self,
                     grad_clip=None,
                     coalesce=True,
                     bucket_size_mb=-1,
                     loss_scale=512.,
                     distributed=True):
            self.grad_clip = grad_clip
            self.coalesce = coalesce
            self.bucket_size_mb = bucket_size_mb
            self.distributed = distributed
            self._scale_update_param = None
            if loss_scale == 'dynamic':
                self.loss_scaler = GradScaler()
            elif isinstance(loss_scale, float):
                self._scale_update_param = loss_scale
                self.loss_scaler = GradScaler(init_scale=loss_scale)
            elif isinstance(loss_scale, dict):
                self.loss_scaler = GradScaler(**loss_scale)
            else:
                raise ValueError('loss_scale must be of type float, dict, or '
                                 f'"dynamic", got {loss_scale}')

        def before_run(self, runner):
            """Preparing steps before Mixed Precision Training."""
            # wrap model mode to fp16
            wrap_fp16_model(runner.model)
            # resume from state dict
            if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
                scaler_state_dict = runner.meta['fp16']['loss_scaler']
                self.loss_scaler.load_state_dict(scaler_state_dict)

        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
            """Copy gradients from fp16 model to fp32 weight copy."""
            for fp32_param, fp16_param in zip(fp32_weights,
                                              fp16_net.parameters()):
                if fp16_param.grad is not None:
                    if fp32_param.grad is None:
                        fp32_param.grad = fp32_param.data.new(
                            fp32_param.size())
                    fp32_param.grad.copy_(fp16_param.grad)

        def copy_params_to_fp16(self, fp16_net, fp32_weights):
            """Copy updated params from fp32 weight copy to fp16 model."""
            for fp16_param, fp32_param in zip(fp16_net.parameters(),
                                              fp32_weights):
                fp16_param.data.copy_(fp32_param.data)

        def after_train_iter(self, runner):
            """Backward optimization steps for Mixed Precision Training. For
            dynamic loss scaling, please refer to
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.

            1. Scale the loss by a scale factor.
            2. Backward the loss to obtain the gradients.
            3. Unscale the optimizer’s gradient tensors.
            4. Call optimizer.step() and update scale factor.
            5. Save loss_scaler state_dict for resume purpose.
            """
            # clear grads of last iteration
            runner.model.zero_grad()
            runner.optimizer.zero_grad()

            self.loss_scaler.scale(runner.outputs['loss']).backward()
            self.loss_scaler.unscale_(runner.optimizer)
            # grad clip
            if self.grad_clip is not None:
                grad_norm = self.clip_grads(runner.model.parameters())
                if grad_norm is not None:
                    # Add grad norm to the logger
                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                             runner.outputs['num_samples'])
            # optimizer step and scaler update (backward was already done above)
            self.loss_scaler.step(runner.optimizer)
            self.loss_scaler.update(self._scale_update_param)

            # save state_dict of loss_scaler
            runner.meta.setdefault(
                'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
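
# The hook above wires the docstring's scale -> backward -> unscale -> clip -> step -> update
# sequence into mmcv's Runner. As a minimal, self-contained sketch of the same recipe
# (the toy model, optimizer and data below are illustrative placeholders, not part of the
# original example; a CUDA device is assumed):
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler(init_scale=512.)

for _ in range(10):
    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randint(0, 4, (8,), device="cuda")
    optimizer.zero_grad()
    with autocast():
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)   # unscale first so clipping sees the true gradient values
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)       # skips the update if inf/NaN gradients were found
    scaler.update()              # adjusts the scale factor for the next iteration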
예제 #20
0
class Trainer:

    def __init__(self, config: DictConfig, model: FlyModel, name: str = "task1", *args, **kwargs):
        """
        Args:
            config: FlyConfig dictionary
            model: must be FlyModel
            dataloader_fn: a Callable function which returns dataloaders
        """
        logger.info("TrainerLoop is initializing!")
        if not isinstance(model, FlyModel):
            logger.warning("model is not an instance of FlyModel")
        self.config = config
        self.model = model
        self.name = name

        # class properties
        self.rank = None
        self.local_rank = None
        self.node_rank = None
        self.world_size = None
        self.distributed_training = None
        self.device = None
        self.fp16 = config.fp16
        self.gradient_accumulation_batches = config.gradient_accumulation_batches
        self.callback_handler = None
        self.optimizers = []
        self.schedulers = []

        self.init_distributed_environment()

        # Model is sent to GPU or CPU
        self.init_device()
        # self.optimizers, self.schedulers = self.configure_optimizers()

        self.model = move_to_device(self.model, self.device)
        self.model.device = self.device
        self.init_fp16()

        if self.distributed_training:
            self.init_distributed_model(self.model)

        # make sure the model has access to trainer info
        self.model.set_trainer(self)

        self.callback_handler = CallbackHandler(config,
                                                trainer=self,
                                                callbacks=[],
                                                verbose=config.logging.level == "DEBUG")

        # Configure all callbacks
        self.configure_callbacks()
        self.callback_handler.fire_event(Events.INITIALIZE)

    def init_distributed_environment(self):
        # For distributed
        self.rank = int(os.environ.get("RANK", 0))
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.distributed_training = (self.world_size > 1)

        # TODO: add error message when num_gpus is set, but distributed training is False here

        if self.distributed_training and not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            assert torch.distributed.is_initialized()

        if self.distributed_training and torch.distributed.is_initialized():
            self.node_rank = os.environ.get("NODE_RANK", "N/A")
            logger.info(
                f"Initialized Rank:{torch.distributed.get_rank()} Local-rank: {self.local_rank} on Node:{self.node_rank} Node-name:{socket.gethostname()}"
            )

    def init_device(self):
        # set cuda device
        if self.config.num_gpus_per_node > 0:
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device("cuda", self.local_rank)
        else:
            self.device = torch.device("cpu")

    def init_fp16(self):
        if not self.fp16:
            return
        if self.config.num_gpus_per_node == 0:
            raise NotImplementedError("For mixed precision training, you need to use GPU!")
        self.loss_scaler = GradScaler()

    def init_training_constants(self):
        self.total_num_update_steps = int(self.config.total_num.update_steps)
        self.total_num_batches = self.total_num_update_steps * int(self.gradient_accumulation_batches)
        self.total_num_epochs = int(self.config.total_num.epochs)

        # check if training in epoch or update_steps
        if self.total_num_update_steps < 0 and self.total_num_epochs < 0:
            raise NotImplementedError(
                "Either config.total_num.update_steps or config.total_num.epochs must be larger than 0")
        elif self.total_num_update_steps > 0 and self.total_num_epochs > 0:
            raise NotImplementedError(
                "Please only set either config.total_num.update_steps or config.total_num.epochs greater than 0")
        elif self.total_num_update_steps > 0 and self.total_num_epochs < 0:
            self.training_in_epoch = False
        elif self.total_num_update_steps < 0 and self.total_num_epochs > 0:
            self.training_in_epoch = True

        # get the number of batches in the dataloader for one epoch
        try:
            self.epoch_num_batches = len(self.train_dataloader)
        except TypeError:
            logger.warning("Cannot determine the length of train_dataloader!")
            self.epoch_num_batches = None

        if self.training_in_epoch:
            if self.epoch_num_batches is not None:
                self.total_num_batches = self.epoch_num_batches * self.total_num_epochs
                self.total_num_update_steps = self.total_num_batches // self.gradient_accumulation_batches
                self.epoch_num_update_steps = self.epoch_num_batches // self.gradient_accumulation_batches
            else:
                # dataloader length unknown: set a huge step budget so the epoch counter alone decides when training stops
                self.total_num_update_steps = sys.maxsize

    def configure_optimizers(self, total_num_update_steps=None, optimizers=None, schedulers=None):
        if optimizers is not None and schedulers is not None:
            self.optimizers, self.schedulers = optimizers, schedulers
        elif total_num_update_steps is not None:
            self.optimizers, self.schedulers = self.model.configure_optimizers(total_num_update_steps)
        else:
            raise ValueError("Please provide the correct argument!")
        return self.optimizers, self.schedulers

    def configure_callbacks(self):
        # Resume callback runs for all ranks
        if self.config.resume.enabled:
            self.resume_callback = Resume(self.config)
            self.add_callback(self.resume_callback)

        self.log_callback = TrainLogger(self.config)
        self.add_callback(self.log_callback)

        self.eval_callback = Evaluation(self.config)
        self.add_callback(self.eval_callback)

        # For logging and inference, use rank 0 by default
        if self.rank == 0:
            if self.config.console:
                self.console_callback = Console(self.config)
                self.add_callback(self.console_callback)

            if self.config.checkpointing.enabled:
                self.checkpoint_callback = Checkpoint(self.config)
                self.add_callback(self.checkpoint_callback)

    def init_distributed_model(self, model):
        """
        Default distributed training uses reducer for simplicity. 
        """
        # Distributed training (should be after apex fp16 initialization)
        self.reducer = Reducer(model)
        # for param in self.model.parameters():
        #     dist.broadcast(param.data, 0)
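        # Note: Reducer is assumed (its definition is not shown in this example) to
        # all-reduce/average gradients across ranks; step_update() calls
        # self.reducer.reduce() right before unscaling, clipping and optimizer.step(),
        # which is why no DistributedDataParallel wrapper is used here.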

    def train(self,
              train_dataloader,
              validation_dataloader=None,
              test_dataloader=None,
              configure_optimizers=True,
              name=None,
              *args,
              **kwargs):
        self.total_num_update_steps = 0
        self.total_num_batches = 0
        self.total_num_epochs = 0
        self.epoch_num_batches = 0
        self.global_batch_count = 0
        self.global_step_count = 0
        self.epochs_trained = 0
        self.local_step_count = 0

        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.test_dataloader = test_dataloader

        self.init_training_constants()

        if configure_optimizers or len(self.optimizers) == 0:
            self.configure_optimizers(self.total_num_update_steps)

        if name is not None:
            self.name = name

        # Training begins
        self.callback_handler.fire_event(Events.TRAIN_BEGIN)

        while True:
            self.callback_handler.fire_event(Events.EPOCH_BEGIN)
            self.train_epoch()
            self.callback_handler.fire_event(Events.EPOCH_END)
            self.epochs_trained += 1

            if self.training_in_epoch:
                if self.epochs_trained >= self.total_num_epochs:
                    break
            else:
                if self.global_step_count < self.total_num_update_steps:
                    continue
                else:
                    break

        # Training ends
        self.callback_handler.fire_event(Events.TRAIN_END)

    def train_epoch(self):
        self.optimizer = self.optimizers[0]
        self.scheduler = self.schedulers[0]

        self.local_step_count = 0

        if self.train_dataloader is None:
            return

        for batch in self.train_dataloader:
            self.callback_handler.fire_event(Events.BATCH_BEGIN)

            batch = move_to_device(batch, self.device)
            output = self.backward_batch(batch)

            # Update the model
            if (self.global_batch_count + 1) % self.gradient_accumulation_batches == 0:
                # Update the model with optimizer
                self.step_update(self.model, self.optimizer, self.scheduler)
                self.global_step_count += 1
                self.local_step_count += 1

            self.callback_handler.fire_event(Events.BATCH_END)

            if self.global_step_count >= self.total_num_update_steps:
                break

            self.global_batch_count += 1

    def backward_batch(self, batch):
        self.model.train()
        with torch.cuda.amp.autocast(self.fp16):
            output = self.model(batch)

        # get the loss from output
        if hasattr(output, "loss"):
            loss = output.loss
        elif isinstance(output, dict):
            loss = output["loss"]
        else:
            raise TypeError("model output must expose a `loss` attribute or a 'loss' key")

        if self.gradient_accumulation_batches > 1:
            loss = loss / self.gradient_accumulation_batches

        self.loss_backward(loss)
        return output

    def step_update(self, model, optimizer, scheduler=None):
        """
            self.loss_scaler is defined in `configure_fp16`
        """
        self.callback_handler.fire_event(Events.STEP_BEGIN)
        # collect gradient
        if self.distributed_training:
            self.reducer.reduce()

        gradient_clip = self.config.optimization.max_gradient_norm
        # Gradient Clipping
        if gradient_clip > 0:
            if self.fp16:
                self.loss_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        # Update the model
        if self.fp16:
            self.loss_scaler.step(optimizer)
            self.loss_scaler.update()
        else:
            optimizer.step()
        # Step learning rate
        if scheduler:
            scheduler.step()
        # Gradient to zero
        optimizer.zero_grad()
        self.callback_handler.fire_event(Events.STEP_END)

    def loss_backward(self, loss):
        self.callback_handler.fire_event(Events.BACKWARD_BEGIN)
        # Loss backward
        if self.fp16:
            self.loss_scaler.scale(loss).backward()
        else:
            loss.backward()
        self.callback_handler.fire_event(Events.BACKWARD_END)

    def validate(self, dataloader):
        # Start Validation
        self.model.reset_evaluation_metrics()
        self.callback_handler.fire_event(Events.VALIDATE_BEGIN)
        self.model.validation_loop(dataloader)
        self.callback_handler.fire_event(Events.VALIDATE_END)

    def test(self, dataloader):
        # Start Testing
        self.model.reset_evaluation_metrics()
        self.callback_handler.fire_event(Events.TEST_BEGIN)
        self.model.test_loop(dataloader)
        self.callback_handler.fire_event(Events.TEST_END)

    def set_model_state(self, model_state_dict):
        self.model.load_state_dict(model_state_dict)

    def get_model_state(self):
        return self.model.state_dict()

    def set_trainer_state(self, trainer_state_dict):
        self.epochs_trained = trainer_state_dict["epochs_trained"]
        self.global_step_count = trainer_state_dict["global_step_count"]
        self.local_step_count = trainer_state_dict["local_step_count"]

        # Resume the training state
        if self.config.resume.resume:
            # Scheduler States
            if self.config.resume.resume_scheduler:
                for idx, scheduler in enumerate(self.schedulers):
                    try:
                        scheduler.load_state_dict(trainer_state_dict["schedulers_state_dict"][idx])
                    except Exception:
                        if self.rank == 0:
                            logger.warning(f"Cannot Load Scheduler {idx}'s State!")

            if self.config.resume.resume_optimizer:
                for idx, optimizer in enumerate(self.optimizers):
                    try:
                        optimizer.load_state_dict(trainer_state_dict["optimizers_state_dict"][idx])
                    except Exception:
                        if self.rank == 0:
                            logger.warning(f"Cannot Load Optimizer {idx}'s State!")

            # load amp (GradScaler) state
            if self.fp16:
                self.loss_scaler.load_state_dict(trainer_state_dict["amp_state_dict"])

            # Random States
            if self.config.resume.resume_rng_state:
                torch.set_rng_state(trainer_state_dict["cpu_rng_state"])
                trainer_state_dict["cuda_rng_state"] = trainer_state_dict["cuda_rng_state"][:torch.cuda.device_count()]
                torch.cuda.set_rng_state_all(trainer_state_dict["cuda_rng_state"])

            # All Callbacks
            for callback in self.callback_handler.callbacks:
                try:
                    callback.load_state_dict(trainer_state_dict[str(type(callback))])
                except Exception:
                    logger.error(f"{type(callback)} seems not to exist in the checkpoint state!")

    def get_trainer_state(self):
        trainer_state_dict = {
            "epochs_trained": self.epochs_trained,
            "global_step_count": self.global_step_count,
            "local_step_count": self.local_step_count,
            "optimizers_state_dict": [optimizer.state_dict() for optimizer in self.optimizers],
            "schedulers_state_dict": [scheduler.state_dict() for scheduler in self.schedulers],
            "cpu_rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state_all(),
        }

        # save amp states
        if self.fp16:
            trainer_state_dict["amp_state_dict"] = self.loss_scaler.state_dict()

        # All Callbacks
        for callback in self.callback_handler.callbacks:
            trainer_state_dict[str(type(callback))] = callback.state_dict()

        return trainer_state_dict

    def add_callback(self, callback: Callback):
        self.callback_handler.add_callback(callback)


# def get_lr(optimizer):
#     for param_group in optimizer.param_groups:
#         return param_group['lr']

# def get_log_variable(x):
#     if isinstance(x, torch.Tensor):
#         x = x.detach()
#         return x.item()
#     else:
#         raise NotImplementedError
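
# Stripped of the Trainer/callback machinery, the gradient-accumulation pattern in
# backward_batch()/step_update() above reduces to the sketch below (toy model and random
# data; the accumulation count is an arbitrary example value; a CUDA device is assumed):
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(32, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler()
gradient_accumulation_batches = 4

for batch_idx in range(100):
    x = torch.randn(16, 32, device="cuda")
    y = torch.randint(0, 2, (16,), device="cuda")
    with autocast():
        # divide so the accumulated gradient matches one large batch
        loss = nn.functional.cross_entropy(model(x), y) / gradient_accumulation_batches
    scaler.scale(loss).backward()
    if (batch_idx + 1) % gradient_accumulation_batches == 0:
        scaler.unscale_(optimizer)   # unscale once per optimizer step, before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()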
예제 #21
0
def main(args):
    # ensures that weight initializations are all the same
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    logging = utils.Logger(args.global_rank, args.save)
    writer = utils.Writer(args.global_rank, args.save)

    # Get data loaders.
    train_queue, valid_queue, num_classes, _ = datasets.get_loaders(args)
    args.num_total_iter = len(train_queue) * args.epochs
    warmup_iters = len(train_queue) * args.warmup_epochs
    swa_start = len(train_queue) * (args.epochs - 1)

    arch_instance = utils.get_arch_cells(args.arch_instance)

    model = AutoEncoder(args, writer, arch_instance)
    model = model.cuda()

    logging.info('args = %s', args)
    logging.info('param size = %fM ', utils.count_parameters_in_M(model))
    logging.info('groups per scale: %s, total_groups: %d',
                 model.groups_per_scale, sum(model.groups_per_scale))

    if args.fast_adamax:
        # Fast adamax has the same functionality as torch.optim.Adamax, except it is faster.
        cnn_optimizer = Adamax(model.parameters(),
                               args.learning_rate,
                               weight_decay=args.weight_decay,
                               eps=1e-3)
    else:
        cnn_optimizer = torch.optim.Adamax(model.parameters(),
                                           args.learning_rate,
                                           weight_decay=args.weight_decay,
                                           eps=1e-3)

    cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        cnn_optimizer,
        float(args.epochs - args.warmup_epochs - 1),
        eta_min=args.learning_rate_min)
    grad_scalar = GradScaler(2**10)
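    # The positional argument is init_scale (default 65536.0 == 2**16);
    # 2**10 starts dynamic loss scaling from a more conservative value.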

    num_output = utils.num_output(args.dataset, args)
    bpd_coeff = 1. / np.log(2.) / num_output

    # if load
    checkpoint_file = os.path.join(args.save, 'checkpoint.pt')
    if args.cont_training:
        logging.info('loading the model.')
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        init_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        model = model.cuda()
        cnn_optimizer.load_state_dict(checkpoint['optimizer'])
        grad_scalar.load_state_dict(checkpoint['grad_scalar'])
        cnn_scheduler.load_state_dict(checkpoint['scheduler'])
        global_step = checkpoint['global_step']
    else:
        global_step, init_epoch = 0, 0

    for epoch in range(init_epoch, args.epochs):
        # update lrs.
        if args.distributed:
            train_queue.sampler.set_epoch(global_step + args.seed)
            valid_queue.sampler.set_epoch(0)

        if epoch > args.warmup_epochs:
            cnn_scheduler.step()

        # Logging.
        logging.info('epoch %d', epoch)

        # Training.
        train_nelbo, global_step = train(train_queue, model, cnn_optimizer,
                                         grad_scalar, global_step,
                                         warmup_iters, writer, logging)
        logging.info('train_nelbo %f', train_nelbo)
        writer.add_scalar('train/nelbo', train_nelbo, global_step)

        model.eval()
        # generate samples less frequently
        eval_freq = 1 if args.epochs <= 50 else 20
        if epoch % eval_freq == 0 or epoch == (args.epochs - 1):
            with torch.no_grad():
                num_samples = 16
                n = int(np.floor(np.sqrt(num_samples)))
                for t in [0.7, 0.8, 0.9, 1.0]:
                    logits = model.sample(num_samples, t)
                    output = model.decoder_output(logits)
                    output_img = output.mean if isinstance(
                        output, torch.distributions.bernoulli.Bernoulli
                    ) else output.sample(t)
                    output_tiled = utils.tile_image(output_img, n)
                    writer.add_image('generated_%0.1f' % t, output_tiled,
                                     global_step)

            valid_neg_log_p, valid_nelbo = test(valid_queue,
                                                model,
                                                num_samples=10,
                                                args=args,
                                                logging=logging)
            logging.info('valid_nelbo %f', valid_nelbo)
            logging.info('valid neg log p %f', valid_neg_log_p)
            logging.info('valid bpd elbo %f', valid_nelbo * bpd_coeff)
            logging.info('valid bpd log p %f', valid_neg_log_p * bpd_coeff)
            writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch)
            writer.add_scalar('val/nelbo', valid_nelbo, epoch)
            writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff,
                              epoch)
            writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch)

        save_freq = int(np.ceil(args.epochs / 100))
        if epoch % save_freq == 0 or epoch == (args.epochs - 1):
            if args.global_rank == 0:
                logging.info('saving the model.')
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': cnn_optimizer.state_dict(),
                        'global_step': global_step,
                        'args': args,
                        'arch_instance': arch_instance,
                        'scheduler': cnn_scheduler.state_dict(),
                        'grad_scalar': grad_scalar.state_dict()
                    }, checkpoint_file)

    # Final validation
    valid_neg_log_p, valid_nelbo = test(valid_queue,
                                        model,
                                        num_samples=1000,
                                        args=args,
                                        logging=logging)
    logging.info('final valid nelbo %f', valid_nelbo)
    logging.info('final valid neg log p %f', valid_neg_log_p)
    writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch + 1)
    writer.add_scalar('val/nelbo', valid_nelbo, epoch + 1)
    writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch + 1)
    writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch + 1)
    writer.close()
예제 #22
0
def main_worker(rank, size, args_in):
    global args
    args = args_in
    is_root = rank == 0

    dist.init_process_group(backend='nccl', init_method=f"tcp://localhost:{args.port}",
                            world_size=size, rank=rank)

    """ Config writer, seed and device """
    CheckpointFunction.use_amp = args.amp

    writer = config_summary_writer(is_root=is_root, output_dir=args.output_dir)
    seed = args.seed + rank
    seed_all(seed)
    device = config_device()

    torch.backends.cudnn.benchmark = True

    """ Load Dataloaders  """

    if args.ckpt is not None:
        gpt_ckpt = torch.load(args.ckpt, map_location=device)

        if is_root:
            print(f"Loading GPT from checkpoint {args.ckpt} with loss {gpt_ckpt['best_loss']}")
        dset_configs = gpt_ckpt['dset_configs']

        # overwrite
        args.dataset = dset_configs['dataset']
        args.resolution = dset_configs['resolution']
    else:
        gpt_ckpt = None
        dset_configs = dict(dataset=args.dataset, resolution=args.resolution,
                            n_frames=args.n_frames)

    train_loader, test_loader, dset = get_distributed_loaders(
        dset_configs=dset_configs, batch_size=args.batch_size, seed=seed
    )
    if is_root:
        print(f"dset loader n_batch: train = {len(train_loader)}, test = {len(test_loader)}")

    """ Load VQ-VAE """
    vqvae_ckpt = args.vqvae_ckpt if gpt_ckpt is None else gpt_ckpt['vqvae_ckpt']
    if is_root:
        print(f'Loading VQ-VAE from {vqvae_ckpt}')

    vqvae_ckpt_loaded = torch.load(vqvae_ckpt, map_location=device)
    vqvae, vq_hp = load_model(
        ckpt=vqvae_ckpt_loaded,
        device=device, freeze_model=True, cond_types=tuple()
    )
    del vqvae_ckpt_loaded

    latent_shape = vqvae.latent_shape
    quantized_shape = vqvae.quantized_shape
    if is_root:
        print('latent shape', latent_shape)
        print('quantized shape', quantized_shape)
        print('total latents', np.prod(latent_shape))

    """ Config cond_types"""

    if gpt_ckpt is not None:
        cond_hp = gpt_ckpt['cond_hp']
    else:
        cond_hp = dict(
            n_cond_frames=args.n_cond_frames,
            class_cond=args.class_cond,
            cond_init_configs=dict(
                type='enc_attn',
                model='resnet_v1',
                resnet_dim=576,
                resnet_depth=34,
                resnet_output_shape=(1, 16, 16),
                width_multiplier=1,
            ),
        )

    def load_prior(layer_ckpt):
        """ Check consistency """
        layer_cond_types, _ = config_cond_types(
            cond_hp=layer_ckpt['cond_hp'], dset=dset)
        # freeze all previous priors, not the current one
        layer_prior, layer_hp = load_model(
            ckpt=layer_ckpt, device=device, freeze_model=False,
            cond_types=layer_cond_types)
        layer_codebook = vqvae.codebook
        return layer_prior, layer_hp, layer_codebook

    def inputs_fn(batch):
        with torch.no_grad():
            videos = batch['video'].to(device, non_blocking=True)  # (b, c, t, h, w)

            cond = []
            if cond_hp['n_cond_frames'] > 0:
                cond_frames = videos[:, :, :cond_hp['n_cond_frames']]
                cond.append(cond_frames)
            if cond_hp['class_cond']:
                cond.append(batch['label'].to(device, non_blocking=True))

            quantized, encodings = vqvae.encode(x=videos, no_flatten=True)

            # latent_shape = (t, h, w, l)
            quantized = shift_dim(quantized, 1, -1)  # (b, d, t, h, w, l) -> (b, t, h, w, l, d)  # channel first -> last
            encodings = encodings.long()

            cond = tuple(cond)
            return dict(encodings=encodings, quantized=quantized, cond=cond,
                        decode_step=None, decode_idx=None)

    cond_types, cond_hp = config_cond_types(
        cond_hp=cond_hp, dset=dset
    )

    if is_root:
        print('cond_types', [(c.name, c.type, c.out_size) for c in cond_types])

    """ Load GPT snapshot, if any """
    if gpt_ckpt is not None:
        prior, hp, codebook = load_prior(layer_ckpt=gpt_ckpt)

        best_loss = gpt_ckpt['best_loss']

        optimizer = optim.Adam(prior.parameters(), lr=args.lr)
        optimizer.load_state_dict(gpt_ckpt['optimizer'])
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.total_iters)
        scheduler.load_state_dict(gpt_ckpt['scheduler'])
        scaler = GradScaler()
        scaler.load_state_dict(gpt_ckpt['scaler'])
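        # restoring the scaler state (current scale and growth tracker) keeps
        # dynamic loss scaling consistent with the run being resumed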

        epoch_start = gpt_ckpt['epoch']
        iteration_start = gpt_ckpt['iteration'] + 1

        del gpt_ckpt
    else:
        # TODO: use (self_gen_n_embd*num_self_gen_in_use,) i.e. concat, or use below i.e. sum up y_gen?
        prior, hp = config_model(
            configs_str=args.cfg,
            shape=latent_shape,
            in_features=vq_hp['embedding_dim'],
            n_vocab=vq_hp['codes_per_book'],
            cond_types=cond_types,
        )
        prior = prior.to(device)
        codebook = vqvae.codebook

        optimizer = optim.Adam(prior.parameters(), lr=args.lr)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.total_iters)
        scaler = GradScaler()
        best_loss = float('inf')

        epoch_start = 0
        iteration_start = 1
    # find_unused_parameters needs to be False for gradient checkpointing to work
    prior = DistributedDataParallel(prior, device_ids=[rank], find_unused_parameters=False,
                                    broadcast_buffers=False)

    if is_root:
        for cond_net in prior.cond_nets:
            print('cond_net size with grad', sum(p.numel() for p in cond_net.parameters() if p.requires_grad))
            print('cond_net size', sum(p.numel() for p in cond_net.parameters()))

    if is_root:
        if args.amp:
            print('Training with AMP')

    # to be saved to model checkpoints
    default_ckpt_dict = {
        'dset_configs': dset_configs,
        'cond_hp': cond_hp,
        'hp': hp,
        'vqvae_ckpt': vqvae_ckpt,
    }

    def get_ckpt_dict(**ckpt_dict):
        return {**ckpt_dict, **default_ckpt_dict}

    if is_root:
        total_parameters = sum([np.prod(p.shape) for p in prior.parameters() if p.requires_grad])
        print('model size: prior params count with grads = {}'.format(total_parameters))

    train_loader = InfDataLoader(train_loader, epoch_start)

    # training and validation, all in latent space
    train_for = functools.partial(
        train,
        train_loader=train_loader,
        inputs_fn=inputs_fn,
        prior=prior,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        writer=writer,
        is_root=is_root,
        size=size,
        device=device,
    )
    validate_for = functools.partial(
        validate,
        test_loader=test_loader,
        inputs_fn=inputs_fn,
        prior=prior,
        writer=writer,
        is_root=is_root,
        size=size,
        device=device,
    )

    # end to end sampling in pixel space
    sample_fn = functools.partial(
        sample,
        cond_hp=cond_hp,
        vae=vqvae,
        prior=prior,
        codebook=codebook,
        device=device,
        temperature=args.temperature,
        rank=rank,
        size=size,
    )  # takes in n_samples, batch, returns samples of size min(n_samples, batch_size * size (roughly, not verified))
    # tensor (n, c, t, h, w) in [0, 1]

    save_samples_for = functools.partial(
        save_samples,
        sample_fn=sample_fn,
        loader=test_loader,
        writer=writer,
        is_root=is_root,
        size=size,
    )

    iteration = iteration_start
    log_mem_usage, log_time_usage = True, True
    time_start = time.time()

    while iteration <= args.total_iters:
        train_loss, iteration = train_for(iteration=iteration)  # average gen_loss

        if iteration % args.test_every == 0:
            test_loss = validate_for(iteration=iteration)
            if is_root:
                writer.add_scalar('test/gen_loss_gap', test_loss - train_loss, iteration * args.batch_size)
            is_best = test_loss < best_loss
            best_loss = min(test_loss, best_loss)

            ckpt_dict = get_ckpt_dict(
                epoch=train_loader.epoch,
                iteration=iteration,
                n_obs=iteration * args.batch_size,
                state_dict=prior.module.state_dict(),
                optimizer=optimizer.state_dict(),
                scheduler=scheduler.state_dict(),
                scaler=scaler.state_dict(),
                best_loss=best_loss,
            )
            save_checkpoint(ckpt_dict, is_best=is_best, is_root=is_root,
                            output_dir=args.output_dir)

        if iteration % args.generate_every == 0 and save_samples_for:
            save_samples_for(iteration=iteration)

        iteration += 1

    if is_root:
        print(f'Final iteration: {iteration}, best loss: {best_loss}')
        print(f'Logs saved under {args.output_dir}')
        writer.close()
예제 #23
0
try:
    G.load_state_dict(torch.load('./saved_models/AEI_G_latest.pth',
                                 map_location=torch.device('cpu')),
                      strict=False)
    D.load_state_dict(torch.load('./saved_models/AEI_D_latest.pth',
                                 map_location=torch.device('cpu')),
                      strict=False)
    opt_G.load_state_dict(
        torch.load('./saved_models/AEI_optG_latest.pth',
                   map_location=torch.device('cpu')))
    opt_D.load_state_dict(
        torch.load('./saved_models/AEI_optD_latest.pth',
                   map_location=torch.device('cpu')))
    scaler.load_state_dict(
        torch.load('./saved_models/AEI_scaler_latest.pth',
                   map_location=torch.device('cpu')))
except Exception as e:
    print(e)
try:
    with open('./saved_models/AEI_niter.pkl', 'rb') as f:
        min_iter = pickle.load(f)
except Exception as e:
    print(e)
writer = SummaryWriter('runs/FaceShifterAEInet', purge_step=min_iter)

TrainFaceSources = [
    '/home/olivier/Images/FaceShifter/celeba-256/',
    '/home/olivier/Images/FaceShifter/Perso/',
    '/home/olivier/Images/FaceShifter/VGGFaceTrain/',
    '/home/olivier/Images/FaceShifter/FFHQ/',
예제 #24
0
def main(cfg: DictConfig) -> None:
    if cfg.trainer.print_torch_setup is True:
        print_torch_setup()

    if cfg.trainer.seed is not None:
        random.seed(cfg.trainer.seed)
        torch.manual_seed(cfg.trainer.seed)
        torch.backends.cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    assert torch.cuda.is_available(), 'This code requires a GPU to train'
    torch.backends.cudnn.benchmark = True
    assert cfg.trainer.output_dir, 'You need to specify an output directory'

    mkdir(cfg.trainer.output_dir)
    experiment_name = time.strftime("%Y%m%d-%H%M%S")
    print(f'The current experiment will be tracked as {experiment_name}')
    output_dir = os.path.join(cfg.trainer.output_dir, experiment_name)
    print(f'Results will be saved in {output_dir}')
    writer = SummaryWriter(output_dir)

    # this is just a workaround for now
    # hparams logging to a file and as text into tensorboard
    # it is certainly not perfect... :/
    hparams = flatten_dict(OmegaConf.to_container(cfg, resolve=True))
    hparams_as_str = [
        str(k) + ' >>> ' + str(v) + '\n' for k, v in hparams.items()
    ]
    # TODO: this seems to not work properly!
    # writer.add_hparams(hparams, metric_dict={'acc': 1}, run_name=experiment_name)
    with open(os.path.join(output_dir, 'hparams.txt'), 'w',
              encoding='utf-8') as hparams_file:
        for line in hparams_as_str:
            hparams_file.write(line)
    writer.add_text('hparams', '\r\n'.join(hparams_as_str), global_step=0)

    device = torch.device(cfg.trainer.device)
    assert device.type == 'cuda', 'Only GPU based training is supported'

    dataset = instantiate(cfg.dataset.train)

    assert cfg.dataset.val_split is not None, 'Handling a separate validation set is not implemented as of now!'
    train_size = int((1 - cfg.dataset.val_split) * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size])

    train_sampler_weights = dataset.make_weights_for_dataset_sampling(
        train_dataset)
    sampler = WeightedRandomSampler(
        train_sampler_weights,
        num_samples=cfg.dataset.train_samples_per_epoch,
        replacement=True)
    train_collate_fn = dataset.get_collate_fn(
        mode='train', channels_last=cfg.trainer.channels_last)
    train_dataloader = instantiate(cfg.dataloader.train,
                                   dataset=train_dataset,
                                   collate_fn=train_collate_fn,
                                   sampler=sampler)

    val_collate_fn = dataset.get_collate_fn(
        mode='val', channels_last=cfg.trainer.channels_last)
    val_dataloader = instantiate(cfg.dataloader.val,
                                 dataset=val_dataset,
                                 collate_fn=val_collate_fn)

    # this handler moves a batch to the GPU as uint8, casts it to a float after transferring it
    # and normalizes the images
    to_device_handler = ToDeviceFunction(device=device,
                                         mean=cfg.dataset.mean,
                                         std=cfg.dataset.std)

    # the prefetch loader prefetches the next batch onto the GPU which makes up a couple
    # of percent in the training loop
    train_dataloader = PrefetchLoader(loader=train_dataloader,
                                      to_device_handler=to_device_handler)

    # val_dataloader = PrefetchLoader(loader=val_dataloader,
    #                                 to_device_handler=to_device_handler)

    model = instantiate(cfg.models.model, device=device).to(device)

    if cfg.trainer.channels_last is True:
        model = model.to(memory_format=torch.channels_last)

    if cfg.trainer.anomaly_detection is True:
        torch.autograd.set_detect_anomaly(mode=True)

    params_to_optimize = [{
        "params": [p for p in model.parameters() if p.requires_grad]
    }]

    optimizer = instantiate(cfg.optimizer, params_to_optimize)

    scaler = GradScaler(enabled=cfg.trainer.amp)

    if cfg.trainer.resume is not None:
        if os.path.isfile(cfg.trainer.resume):
            print("Trying to load checkpoint '{}'".format(cfg.trainer.resume))

            if cfg.trainer.from_u2net_checkpoint is True:
                checkpoint = torch.load(cfg.trainer.resume,
                                        map_location=device)
                model.load_state_dict(checkpoint)
            else:
                checkpoint = torch.load(cfg.trainer.resume,
                                        map_location=device)
                model.load_state_dict(checkpoint['model'])

                if cfg.trainer.weights_only is False:
                    cfg.trainer.start_epoch = checkpoint['epoch']
                    optimizer.load_state_dict(checkpoint['optimizer'])
                    scaler.load_state_dict(checkpoint['scaler'])

            print(
                f'Loaded checkpoint {cfg.trainer.resume}. Resuming training at epoch {cfg.trainer.start_epoch}'
            )
        else:
            warnings.warn(f'Checkpoint {cfg.trainer.resume} not found!')

    print("Start training...")
    start_time = time.time()

    if cfg.trainer.dry_run is True:
        print("Doing dry run, running val on train dataset...")
        # validate_one_epoch(writer, model, train_dataloader, device, 0, cfg.trainer.print_freq)
        return

    for epoch in range(cfg.trainer.start_epoch, cfg.trainer.epochs):
        train_one_epoch(writer, device, model, optimizer, scaler,
                        train_dataloader, epoch, cfg)
        # validate_one_epoch(writer, model, val_dataloader, epoch, cfg)

        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'epoch': epoch,
            'cfg': cfg
        }
        save_on_master(checkpoint,
                       os.path.join(output_dir, 'model_{}.pth'.format(epoch)))
        save_on_master(checkpoint, os.path.join(output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
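
# train_one_epoch is referenced above but its body is not part of this example. A hedged
# sketch (the function name and argument list are placeholders, not the example's actual
# API) of how such a loop typically pairs autocast(enabled=...) with a GradScaler that was
# created with enabled=cfg.trainer.amp:
from torch.cuda.amp import autocast


def train_one_epoch_sketch(model, optimizer, scaler, dataloader, device, amp_enabled=True):
    """Hypothetical sketch only; the real train_one_epoch may differ."""
    model.train()
    for images, targets in dataloader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        # with enabled=False both autocast and the scaler become no-ops,
        # so this degenerates to a plain fp32 training step
        with autocast(enabled=amp_enabled):
            loss = model(images, targets)  # assumes the model returns its loss
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()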
예제 #25
0
class DynamicIterBasedRunner(IterBasedRunner):
    """Dynamic Iterbased Runner.

    In this Dynamic Iterbased Runner, we will pass the ``reducer`` to the
    ``train_step`` so that the models can be trained with dynamic architecture.
    More details and clarification can be found in this [tutorial](docs/tutorials/ddp_train_gans.md).  # noqa

    Args:
        is_dynamic_ddp (bool, optional): Whether to adopt the dynamic ddp.
            Defaults to False.
        pass_training_status (bool, optional): Whether to pass the training
            status. Defaults to False.
        fp16_loss_scaler (dict | None, optional): Config for fp16 GradScaler
            from ``torch.cuda.amp``. Defaults to None.
        use_apex_amp (bool, optional): Whether to use apex.amp to start mixed
            precision training. Defaults to False.
    """
    def __init__(self,
                 *args,
                 is_dynamic_ddp=False,
                 pass_training_status=False,
                 fp16_loss_scaler=None,
                 use_apex_amp=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        if is_module_wrapper(self.model):
            _model = self.model.module
        else:
            _model = self.model

        self.is_dynamic_ddp = is_dynamic_ddp
        self.pass_training_status = pass_training_status

        # add a flag for checking if `self.optimizer` comes from `_model`
        self.optimizer_from_model = False
        # add support for optimizer is None.
        # sanity check for whether `_model` contains self-defined optimizer
        if hasattr(_model, 'optimizer'):
            assert self.optimizer is None, (
                'Runner and model cannot contain optimizer at the same time.')
            self.optimizer_from_model = True
            self.optimizer = _model.optimizer

        # add fp16 grad scaler, using pytorch official GradScaler
        self.with_fp16_grad_scaler = False
        if fp16_loss_scaler is not None:
            self.loss_scaler = GradScaler(**fp16_loss_scaler)
            self.with_fp16_grad_scaler = True
            mmcv.print_log('Use FP16 grad scaler in Training', 'mmgen')

        # flag to use amp in apex (NVIDIA)
        self.use_apex_amp = use_apex_amp

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            if hasattr(hook, fn_name):
                getattr(hook, fn_name)(self)

    def train(self, data_loader, **kwargs):
        if is_module_wrapper(self.model):
            _model = self.model.module
        else:
            _model = self.model
        self.model.train()
        self.mode = 'train'
        # check if self.optimizer from model and track it
        if self.optimizer_from_model:
            self.optimizer = _model.optimizer

        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        self.call_hook('before_fetch_train_data')
        data_batch = next(self.data_loader)
        self.call_hook('before_train_iter')

        # prepare input args for train_step
        # running status
        if self.pass_training_status:
            running_status = dict(iteration=self.iter, epoch=self.epoch)
            kwargs['running_status'] = running_status
        # ddp reducer for tracking dynamic computational graph
        if self.is_dynamic_ddp:
            kwargs.update(dict(ddp_reducer=self.model.reducer))

        if self.with_fp16_grad_scaler:
            kwargs.update(dict(loss_scaler=self.loss_scaler))
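            # the model's train_step is expected to call loss_scaler.scale(loss).backward()
            # and loss_scaler.step(optimizer); the runner only calls update() afterwards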

        if self.use_apex_amp:
            kwargs.update(dict(use_apex_amp=True))

        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)

        # the loss scaler should be updated after ``train_step``
        if self.with_fp16_grad_scaler:
            self.loss_scaler.update()

        # further check for the cases where the optimizer is built in
        # `train_step`.
        if self.optimizer is None:
            if hasattr(_model, 'optimizer'):
                self.optimizer_from_model = True
                self.optimizer = _model.optimizer

        # check if self.optimizer from model and track it
        if self.optimizer_from_model:
            self.optimizer = _model.optimizer
        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, iters) to specify the
                running order and iterations. E.g, [('train', 10000),
                ('val', 1000)] means running 10000 iterations for training and
                1000 iterations for validation, iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_iters is not None:
            warnings.warn(
                'setting max_iters in run is deprecated, '
                'please set max_iters in runner_config', DeprecationWarning)
            self._max_iters = max_iters
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        iter_loaders = [IterLoader(x, self) for x in data_loaders]

        self.call_hook('before_epoch')

        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               resume_loss_scaler=True,
               map_location='default'):
        """Resume model from checkpoint.

        Args:
            checkpoint (str): Checkpoint to resume from.
            resume_optimizer (bool, optional): Whether resume the optimizer(s)
                if the checkpoint file includes optimizer(s). Default to True.
            resume_loss_scaler (bool, optional): Whether to resume the loss
                scaler (GradScaler) from ``torch.cuda.amp``. Defaults to True.
            map_location (str, optional): Same as :func:`torch.load`.
                Default to 'default'.
        """
        if map_location == 'default':
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(checkpoint,
                                              map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        if 'loss_scaler' in checkpoint and resume_loss_scaler:
            self.loss_scaler.load_state_dict(checkpoint['loss_scaler'])

        if self.use_apex_amp:
            from apex import amp
            amp.load_state_dict(checkpoint['amp'])

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint.
                Defaults to None.
            save_optimizer (bool, optional): Whether save optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether create symlink to the
                latest checkpoint file. Defaults to True.
        """
        if meta is None:
            meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
        elif isinstance(meta, dict):
            meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        _loss_scaler = self.loss_scaler if self.with_fp16_grad_scaler else None
        save_checkpoint(self.model,
                        filepath,
                        optimizer=optimizer,
                        loss_scaler=_loss_scaler,
                        save_apex_amp=self.use_apex_amp,
                        meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return

        if isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of Lr updater.
            # Since this is not applicable for `CosineAnnealingLrUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook)
예제 #26
0
class BaseTrainer:
    def __init__(self, dist, rank, config, resume, only_validation, model,
                 loss_function, optimizer):
        self.color_tool = colorful
        self.color_tool.use_style("solarized")

        model = DistributedDataParallel(model.to(rank), device_ids=[rank])
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        # DistributedDataParallel (DDP)
        self.rank = rank
        self.dist = dist

        # Automatic mixed precision (AMP)
        self.use_amp = config["meta"]["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        # Acoustics
        self.acoustic_config = config["acoustics"]

        # Supported STFT
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]

        self.torch_stft = partial(stft,
                                  n_fft=n_fft,
                                  hop_length=hop_length,
                                  win_length=win_length)
        self.torch_istft = partial(istft,
                                   n_fft=n_fft,
                                   hop_length=hop_length,
                                   win_length=win_length)
        self.librosa_stft = partial(librosa.stft,
                                    n_fft=n_fft,
                                    hop_length=hop_length,
                                    win_length=win_length)
        self.librosa_istft = partial(librosa.istft,
                                     hop_length=hop_length,
                                     win_length=win_length)

        # Trainer.train in the config
        self.train_config = config["trainer"]["train"]
        self.epochs = self.train_config["epochs"]
        self.save_checkpoint_interval = self.train_config[
            "save_checkpoint_interval"]
        self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"]
        assert self.save_checkpoint_interval >= 1, "Check the 'save_checkpoint_interval' parameter in the config. It should be at least 1."

        # Trainer.validation in the config
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config[
            "validation_interval"]
        self.save_max_metric_score = self.validation_config[
            "save_max_metric_score"]
        assert self.validation_interval >= 1, "Check the 'validation_interval' parameter in the config. It should be at least 1."

        # Trainer.visualization in the config
        self.visualization_config = config["trainer"]["visualization"]

        # In the 'train.py' file, if the 'resume' item is 'True', we will update the following args:
        self.start_epoch = 1
        self.best_score = -np.inf if self.save_max_metric_score else np.inf
        self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute(
        ) / config["meta"]["experiment_name"]
        self.checkpoints_dir = self.save_dir / "checkpoints"
        self.logs_dir = self.save_dir / "logs"

        if resume:
            self._resume_checkpoint()

        # Debug validation, which skips training
        self.only_validation = only_validation

        if config["meta"]["preloaded_model_path"]:
            self._preload_model(Path(config["meta"]["preloaded_model_path"]))

        if self.rank == 0:
            prepare_empty_dir([self.checkpoints_dir, self.logs_dir],
                              resume=resume)

            self.writer = SummaryWriter(self.logs_dir.as_posix(),
                                        max_queue=5,
                                        flush_secs=30)
            self.writer.add_text(
                tag="Configuration",
                text_string=f"<pre>  \n{toml.dumps(config)}  \n</pre>",
                global_step=1)

            print(self.color_tool.cyan("The configurations are as follows: "))
            print(self.color_tool.cyan("=" * 40))
            print(self.color_tool.cyan(toml.dumps(config)[:-1]))  # except "\n"
            print(self.color_tool.cyan("=" * 40))

            with open(
                (self.save_dir /
                 f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(),
                    "w") as handle:
                toml.dump(config, handle)

            self._print_networks([self.model])

    def _preload_model(self, model_path):
        """
        Preload model parameters (in "*.tar" format) at the start of experiment.

        Args:
            model_path (Path): The file path of the *.tar file
        """
        model_path = model_path.expanduser().absolute()
        assert model_path.exists(
        ), f"The file {model_path.as_posix()} does not exist. Please check the path."

        model_checkpoint = torch.load(model_path.as_posix(),
                                      map_location="cpu")
        self.model.load_state_dict(model_checkpoint["model"], strict=False)
        self.model.to(self.rank)

        if self.rank == 0:
            print(
                f"Model preloaded successfully from {model_path.as_posix()}.")

    def _resume_checkpoint(self):
        """
        Resume the experiment from the latest checkpoint.
        """
        latest_model_path = self.checkpoints_dir.expanduser().absolute(
        ) / "latest_model.tar"
        assert latest_model_path.exists(
        ), f"{latest_model_path} does not exist, can not load latest checkpoint."

        # Make sure all processes (GPUs) do not start loading before the saving is finished.
        # see https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work
        self.dist.barrier()

        # Load the checkpoint on the CPU and move the model to the device afterwards.
        # This may be slightly slower than using map_location="cuda:<...>".
        # https://stackoverflow.com/questions/61642619/pytorch-distributed-data-parallel-confusion
        checkpoint = torch.load(latest_model_path.as_posix(),
                                map_location="cpu")

        self.start_epoch = checkpoint["epoch"] + 1
        self.best_score = checkpoint["best_score"]
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scaler.load_state_dict(checkpoint["scaler"])

        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            self.model.module.load_state_dict(checkpoint["model"])
        else:
            self.model.load_state_dict(checkpoint["model"])

        # self.model.to(self.rank)

        if self.rank == 0:
            print(
                f"Model checkpoint loaded. Training will begin at {self.start_epoch} epoch."
            )

    def _save_checkpoint(self, epoch, is_best_epoch=False):
        """
        Save checkpoint to "<save_dir>/<config name>/checkpoints" directory, which consists of:
            - epoch
            - best metric score in historical epochs
            - optimizer parameters
            - model parameters

        Args:
            is_best_epoch (bool): If the model achieved the best metric score in the current epoch (is_best_epoch=True),
                                the checkpoint is also saved as "<save_dir>/checkpoints/best_model.tar".
        """
        print(f"\t Saving {epoch} epoch model checkpoint...")

        state_dict = {
            "epoch": epoch,
            "best_score": self.best_score,
            "optimizer": self.optimizer.state_dict(),
            "scaler": self.scaler.state_dict()
        }

        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            state_dict["model"] = self.model.module.state_dict()
        else:
            state_dict["model"] = self.model.state_dict()

        # Saved in "latest_model.tar"
        # Contains all checkpoint information, including the optimizer parameters, the model parameters, etc.
        # New checkpoint will overwrite the older one.
        torch.save(state_dict,
                   (self.checkpoints_dir / "latest_model.tar").as_posix())

        # "model_{epoch_number}.pth"
        # Contains only the model parameters.
        torch.save(state_dict["model"],
                   (self.checkpoints_dir /
                    f"model_{str(epoch).zfill(4)}.pth").as_posix())

        # If the model achieved the best metric score in the current epoch (is_best_epoch=True),
        # the checkpoint is also saved as "best_model.tar".
        # A newer best-scoring checkpoint overwrites the older one.
        if is_best_epoch:
            print(
                self.color_tool.red(
                    f"\t Found a best score in the {epoch} epoch, saving..."))
            torch.save(state_dict,
                       (self.checkpoints_dir / "best_model.tar").as_posix())

    def _is_best_epoch(self, score, save_max_metric_score=True):
        """
        Check whether the current model achieved the best metric score so far.
        """
        if save_max_metric_score and score >= self.best_score:
            self.best_score = score
            return True
        elif not save_max_metric_score and score <= self.best_score:
            self.best_score = score
            return True
        else:
            return False

    @staticmethod
    def _print_networks(models: list):
        print(
            f"This project contains {len(models)} model(s); the number of parameters per model is: "
        )

        params_of_all_networks = 0
        for idx, model in enumerate(models, start=1):
            params_of_network = 0
            for param in model.parameters():
                params_of_network += param.numel()

            print(f"\tNetwork {idx}: {params_of_network / 1e6} million.")
            params_of_all_networks += params_of_network

        print(
            f"The total number of parameters in the project is {params_of_all_networks / 1e6} million."
        )

    def _set_models_to_train_mode(self):
        self.model.train()

    def _set_models_to_eval_mode(self):
        self.model.eval()

    def spec_audio_visualization(self,
                                 noisy,
                                 enhanced,
                                 clean,
                                 name,
                                 epoch,
                                 mark=""):
        self.writer.add_audio(f"{mark}_Speech/{name}_Noisy",
                              noisy,
                              epoch,
                              sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Enhanced",
                              enhanced,
                              epoch,
                              sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Clean",
                              clean,
                              epoch,
                              sample_rate=16000)

        # Visualize the spectrogram of noisy speech, clean speech, and enhanced speech
        noisy_mag, _ = librosa.magphase(
            self.librosa_stft(noisy, n_fft=320, hop_length=160,
                              win_length=320))
        enhanced_mag, _ = librosa.magphase(
            self.librosa_stft(enhanced,
                              n_fft=320,
                              hop_length=160,
                              win_length=320))
        clean_mag, _ = librosa.magphase(
            self.librosa_stft(clean, n_fft=320, hop_length=160,
                              win_length=320))
        fig, axes = plt.subplots(3, 1, figsize=(6, 6))
        for k, mag in enumerate([noisy_mag, enhanced_mag, clean_mag]):
            axes[k].set_title(f"mean: {np.mean(mag):.3f}, "
                              f"std: {np.std(mag):.3f}, "
                              f"max: {np.max(mag):.3f}, "
                              f"min: {np.min(mag):.3f}")
            librosa.display.specshow(librosa.amplitude_to_db(mag),
                                     cmap="magma",
                                     y_axis="linear",
                                     ax=axes[k],
                                     sr=16000)
        plt.tight_layout()
        self.writer.add_figure(f"{mark}_Spectrogram/{name}", fig, epoch)

    def metrics_visualization(self,
                              noisy_list,
                              clean_list,
                              enhanced_list,
                              metrics_list,
                              epoch,
                              num_workers=10,
                              mark=""):
        """
        Compute metrics on the validation dataset in parallel.

        Notes:
            1. You can register other metrics, but the STOI and WB_PESQ metrics must be present. These two metrics
             are used to decide whether the current epoch is a "best epoch."
            2. To use a new metric, register it in the "util.metrics" file.
        """
        assert "STOI" in metrics_list and "WB_PESQ" in metrics_list, "Both 'STOI' and 'WB_PESQ' must be present."

        # Check that every requested metric is registered in the "util.metrics" file.
        for i in metrics_list:
            assert i in metrics.REGISTERED_METRICS.keys(
            ), f"{i} is not registered, please check the 'util.metrics' file."

        stoi_mean = 0.0
        wb_pesq_mean = 0.0
        for metric_name in metrics_list:
            score_on_noisy = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est)
                for ref, est in zip(clean_list, noisy_list))
            score_on_enhanced = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est)
                for ref, est in zip(clean_list, enhanced_list))

            # Add the mean value of the metric to tensorboard
            mean_score_on_noisy = np.mean(score_on_noisy)
            mean_score_on_enhanced = np.mean(score_on_enhanced)
            self.writer.add_scalars(f"{mark}_Validation/{metric_name}", {
                "Noisy": mean_score_on_noisy,
                "Enhanced": mean_score_on_enhanced
            }, epoch)

            if metric_name == "STOI":
                stoi_mean = mean_score_on_enhanced

            if metric_name == "WB_PESQ":
                wb_pesq_mean = transform_pesq_range(mean_score_on_enhanced)
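                # transform_pesq_range is assumed to map the raw WB-PESQ score onto a
                # 0-1 range so that it can be averaged with STOI (already in [0, 1]).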

        return (stoi_mean + wb_pesq_mean) / 2

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            if self.rank == 0:
                print(
                    self.color_tool.yellow(
                        f"{'=' * 15} {epoch} epoch {'=' * 15}"))
                print("[0 seconds] Begin training...")

            # [debug validation] Run validation only, and only on the first GPU (process):
            # inference + metric calculation + checkpoint saving
            if self.only_validation and self.rank == 0:
                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(
                        metric_score,
                        save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

                # Skip the following regular training, saving checkpoints, and validation
                continue

            # Regular training
            timer = ExecutionTime()
            self._set_models_to_train_mode()
            self._train_epoch(epoch)

            # Regular checkpoint saving
            if self.rank == 0 and self.save_checkpoint_interval != 0 and (
                    epoch % self.save_checkpoint_interval == 0):
                self._save_checkpoint(epoch)

            # Regular validation
            if self.rank == 0 and (epoch % self.validation_interval == 0):
                print(
                    f"[{timer.duration()} seconds] Training has finished, validation is in progress..."
                )

                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(
                        metric_score,
                        save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

            print(f"[{timer.duration()} seconds] This epoch is finished.")

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _validation_epoch(self, epoch):
        raise NotImplementedError
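Since _train_epoch and _validation_epoch are left abstract above, the following is a minimal sketch (not taken from the original project) of how a subclass could implement the AMP training step with the scaler that the base class saves and restores. The class name MyTrainer, the base class alias BaseTrainer, and the train_dataloader/loss_function attributes are illustrative assumptions.

import torch
from torch.cuda.amp import autocast


class MyTrainer(BaseTrainer):  # hypothetical subclass of the trainer above
    def _train_epoch(self, epoch):
        for noisy, clean in self.train_dataloader:  # assumed dataloader attribute
            noisy = noisy.to(self.rank)
            clean = clean.to(self.rank)

            self.optimizer.zero_grad()

            # Forward pass and loss under autocast (mixed precision).
            with autocast():
                enhanced = self.model(noisy)
                loss = self.loss_function(enhanced, clean)  # assumed loss attribute

            # Scale the loss, backpropagate, and step through the GradScaler so
            # that steps with inf/NaN gradients are skipped automatically.
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()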
Example #27
0
def train(model: Model,
          state: dict,
          train_data_path: str,
          train_rgb_json: str,
          val_data_path: str,
          val_rgb_json: str,
          transform_file: str,
          growing_parameters: dict,
          lr: float,
          iterations: int,
          val_iterations: int,
          verbose: bool,
          train_segment_masks_path: str = '',
          val_segment_masks_path: str = '',
          lambda_ccl=0.0,
          loss_type='L2',
          ccl_version='linear',
          alpha=5,
          gamma=.5,
          regularization_l2: float = 0.,
          warmup=5000,
          milestones=[],
          optimizer_name: str = 'sgd',
          print_every: int = 250,
          debug=False):
    model.train()
    torch.backends.cudnn.benchmark = True

    if debug:
        print_every = 10

    sparse_growing_parameters = load_growing_parameters(growing_parameters)
    filled_growing_parameters = fill_growing_parameters(
        sparse_growing_parameters, iterations)

    assert os.path.isfile(transform_file)
    sys.path.insert(0, os.path.dirname(transform_file))
    transforms = __import__(
        os.path.splitext(os.path.basename(transform_file))[0])

    model_dir = os.path.dirname(state['path'])

    writer = SummaryWriter(log_dir=os.path.join(model_dir, 'logs'))

    if loss_type == 'L2':
        criterion = L2Loss(weighted=False)
    elif loss_type == 'L2W':
        criterion = L2Loss(weighted=True, alpha=alpha, gamma=gamma)
    elif loss_type == 'L1':
        criterion = L1Loss(weighted=False)
    elif loss_type == 'L1W':
        criterion = L1Loss(weighted=True, alpha=alpha, gamma=gamma)
    elif loss_type == 'L2+CCL':
        criterion = L2CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version)
    elif loss_type == 'L2W+CCL':
        criterion = L2CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             weighted=True,
                             alpha=alpha,
                             gamma=gamma)
    elif loss_type == 'L2+CCL-gt':
        criterion = L2CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             ccl_target='gt',
                             weighted=False)
    elif loss_type == 'L2W+CCL-gt':
        criterion = L2CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             ccl_target='gt',
                             weighted=True,
                             alpha=alpha,
                             gamma=gamma)
    elif loss_type == 'L1+CCL':
        criterion = L1CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version)
    elif loss_type == 'L1W+CCL':
        criterion = L1CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             weighted=True,
                             alpha=alpha,
                             gamma=gamma)
    elif loss_type == 'L1+CCL-gt':
        criterion = L1CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             ccl_target='gt',
                             weighted=False)
    elif loss_type == 'L1W+CCL-gt':
        criterion = L1CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             ccl_target='gt',
                             weighted=True,
                             alpha=alpha,
                             gamma=gamma)
    else:
        raise NotImplementedError()

    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()

    if optimizer_name == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               weight_decay=regularization_l2)
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=lr,
                              weight_decay=regularization_l2,
                              momentum=0.9)
    else:
        raise NotImplementedError(f'Optimizer {optimizer_name} not available')
    if 'optimizer' in state:
        print('loading optimizer...')
        optimizer.load_state_dict(state['optimizer'])

    scaler = GradScaler(enabled=True)
    if 'scaler' in state:
        print('loading scaler...')
        scaler.load_state_dict(state['scaler'])

    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return 0.1**len([m for m in milestones if m <= train_iter])

    scheduler = LambdaLR(optimizer, schedule)
    if 'scheduler' in state:
        print('loading scheduler...')
        scheduler.load_state_dict(state['scheduler'])
    iteration = state.get('iteration', 0)

    if iteration >= iterations:
        print('Training already done.')
        return

    if train_segment_masks_path or val_segment_masks_path:
        trainset = ImagenetColorSegmentData(train_data_path,
                                            train_segment_masks_path,
                                            rgb_json=train_rgb_json,
                                            transform=None,
                                            transform_l=to_tensor_l,
                                            transform_ab=to_tensor_ab)
        testset = ImagenetColorSegmentData(
            val_data_path,
            val_segment_masks_path,
            rgb_json=val_rgb_json,
            transform=transforms.get_val_transform(1024),
            transform_l=to_tensor_l,
            transform_ab=to_tensor_ab)
    else:
        trainset = ImagenetData(train_data_path,
                                rgb_json=train_rgb_json,
                                transform=None,
                                transform_l=to_tensor_l,
                                transform_ab=to_tensor_ab)
        testset = ImagenetData(val_data_path,
                               rgb_json=val_rgb_json,
                               transform=transforms.get_val_transform(1024),
                               transform_l=to_tensor_l,
                               transform_ab=to_tensor_ab)

    trainset_infer = ImagenetData(train_data_path,
                                  rgb_json=train_rgb_json,
                                  transform=transforms.get_val_transform(1024),
                                  transform_l=to_tensor_l,
                                  transform_ab=to_tensor_ab,
                                  training=False)
    testset_infer = ImagenetData(val_data_path,
                                 rgb_json=val_rgb_json,
                                 transform=transforms.get_val_transform(1024),
                                 transform_l=to_tensor_l,
                                 transform_ab=to_tensor_ab,
                                 training=False)

    sampler = SavableShuffleSampler(trainset, shuffle=not debug)
    if 'sampler' in state:
        print('loading sampler...')
        sampler.load_state_dict(state['sampler'])

    if len(sampler) > len(trainset):
        sampler = SavableShuffleSampler(trainset, shuffle=not debug)
        print('recreating the sampler, the trainset has changed...')

    print(f'        Loss: {loss_type}')
    print(criterion)
    print(
        f'   Optimizer: {optimizer.__class__.__name__} (LR:{optimizer.param_groups[0]["lr"]:.6f})'
    )
    print(f'   Iteration: {iteration}/{iterations}')
    print(f'      Warmup: {warmup}')
    print(f'  Milestones: {milestones}')
    print(f'     Growing: {sparse_growing_parameters}')
    print(f'   Traindata: {len(trainset)} images')
    print(f'    Testdata: {len(testset)} images')
    print(f' Sampler idx: {sampler.index}')
    print(f'Current step: {scheduler._step_count}')

    batch_size, input_size = filled_growing_parameters[iteration]
    trainset.transform = transforms.get_transform(input_size[0])
    trainloader = get_trainloader(trainset, batch_size, sampler)

    running_psnr, img_per_sec = 0.0, 0.0
    running_loss, avg_running_loss = defaultdict(float), defaultdict(float)
    tic = time.time()
    changed_batch_size = True
    psnr = PSNR()
    pbar = tqdm(total=iterations, initial=iteration)

    if iteration == 0:
        for name, param in model.named_parameters():
            writer.add_histogram(name, param, global_step=iteration)

    while iteration < iterations:
        loss_str = ' - '.join(
            [f'{key}: {val:.5f} ' for key, val in avg_running_loss.items()])
        pbar.set_description(
            f'[Ep: {sampler.epoch} | B: {batch_size} | Im: {input_size[0]}x{input_size[1]}]  loss: {loss_str} - {img_per_sec:.2f} img/s'
        )
        for data in trainloader:
            if iteration in sparse_growing_parameters and not changed_batch_size:
                # change batch size and input size
                batch_size, input_size = sparse_growing_parameters[iteration]
                trainset.transform = transforms.get_transform(input_size[0])
                # recreate the loader; otherwise the new transform is not propagated to the worker processes
                trainloader = get_trainloader(trainset, batch_size, sampler)
                changed_batch_size = True
                break
            else:
                changed_batch_size = False

            if torch.cuda.is_available():
                data = tuple([el.cuda(non_blocking=True) for el in data])

            # get data
            if len(data) == 4:
                inputs, labels, segment_masks, _ = data
            else:
                inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            with autocast():
                outputs = model(inputs)
                crit_labels = [labels, segment_masks
                               ] if train_segment_masks_path else [labels]
                loss, loss_dict = criterion(outputs, *crit_labels)
                _psnr = psnr(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()

            del outputs
            del inputs
            del labels
            del data

            # print statistics
            for k, v in loss_dict.items():
                running_loss[k] += v.item()
            running_psnr += _psnr.item()

            iteration += 1

            if iteration % print_every == 0 or iteration == iterations:
                img_per_sec = print_every * batch_size / (time.time() - tic)

                for k, v in running_loss.items():
                    avg_running_loss[k] = running_loss[k] / print_every
                    writer.add_scalar(f'train/{k}',
                                      avg_running_loss[k],
                                      global_step=iteration)
                avg_running_psnr = running_psnr / print_every

                writer.add_scalar('train/PSNR',
                                  avg_running_psnr,
                                  global_step=iteration)

                writer.add_scalar('Performance/Images per second',
                                  img_per_sec,
                                  global_step=iteration)
                writer.add_scalar('Learning rate',
                                  optimizer.param_groups[0]['lr'],
                                  global_step=iteration)
                if loss_type in ['L1+CCL', 'L2+CCL']:
                    writer.add_scalar('Parameters/lambda CCL',
                                      lambda_ccl,
                                      global_step=iteration)
                loss_str = ' - '.join([
                    f'{key}: {val:.5} '
                    for key, val in avg_running_loss.items()
                ])
                pbar.set_description(
                    f'[Ep: {sampler.epoch} | B: {batch_size} | Im: {input_size[0]}x{input_size[1]}] loss: {loss_str} - {img_per_sec:.2f} img/s'
                )

                running_loss = defaultdict(float)
                running_psnr = 0.0
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'scaler': scaler.state_dict(),
                    'sampler': sampler.state_dict()
                })

                model.save(state, iteration)
                delete_older_then_n(state['path'], 10)

                tic = time.time()
            if iteration == iterations or iteration % val_iterations == 0:
                # run validation
                torch.backends.cudnn.benchmark = False
                model = model.eval()
                test_loader = DataLoader(testset,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=8,
                                         pin_memory=True,
                                         prefetch_factor=1)
                with torch.no_grad():
                    metric_results = get_validation_metrics(
                        test_loader, model, criterion, ccl_version=ccl_version)
                for k, v in metric_results.items():
                    writer.add_scalar(f'validation/{k}',
                                      v,
                                      global_step=iteration)

                # images from validation
                predicted_images = infer(
                    model=model,
                    dataset=testset_infer,
                    target_path=os.path.join(model_dir,
                                             f'predictions-{iteration}'),
                    batch_size=1,
                    img_limit=20,
                    transform=transforms.get_val_transform(1024),
                    debug=True,
                    tensorboard=True)
                for i, img in enumerate(predicted_images):
                    writer.add_image(f'example-{i}',
                                     img,
                                     global_step=iteration,
                                     dataformats='HWC')

                # images from training
                predicted_images = infer(
                    model=model,
                    dataset=trainset_infer,
                    target_path=os.path.join(
                        model_dir, f'predictions-training-{iteration}'),
                    batch_size=1,
                    img_limit=20,
                    transform=transforms.get_val_transform(1024),
                    debug=True,
                    tensorboard=True)
                for i, img in enumerate(predicted_images):
                    writer.add_image(f'example-train-{i}',
                                     img,
                                     global_step=iteration,
                                     dataformats='HWC')

                for name, param in model.named_parameters():
                    writer.add_histogram(name, param, global_step=iteration)
                model = model.train()
                torch.backends.cudnn.benchmark = True
                tic = time.time()
            pbar.update(1)
            if iteration == iterations:
                break

    pbar.close()
    writer.close()
    print('Finished Training')
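The learning-rate policy above is worth spelling out: LambdaLR multiplies the base learning rate by whatever schedule(iteration) returns, so the rate ramps linearly from 10% to 100% of lr over the warmup phase and then drops by a factor of 10 after each milestone. A small self-contained sketch (with made-up warmup/milestone values) that prints the multipliers:

import torch
from torch.optim.lr_scheduler import LambdaLR

warmup = 5000                 # matches the default above
milestones = [20000, 40000]   # illustrative milestones


def schedule(train_iter):
    if warmup and train_iter <= warmup:
        return 0.9 * train_iter / warmup + 0.1
    return 0.1 ** len([m for m in milestones if m <= train_iter])


param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = LambdaLR(optimizer, schedule)  # base lr is multiplied by schedule(iteration)

# 0 -> 0.10, 2500 -> 0.55, 5000 -> 1.00, 20000 -> 0.10, 40000 -> 0.01
for it in (0, 2500, 5000, 20000, 40000):
    print(it, schedule(it))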
Example #28
0
class Trainer(object):
    def __init__(
        self,
        diffusion_model,
        folder,
        *,
        ema_decay = 0.995,
        image_size = 128,
        train_batch_size = 32,
        train_lr = 2e-5,
        train_num_steps = 100000,
        gradient_accumulate_every = 2,
        amp = False,
        step_start_ema = 2000,
        update_ema_every = 10,
        save_and_sample_every = 1000,
        results_folder = './results'
    ):
        super().__init__()
        self.model = diffusion_model
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)
        self.update_ema_every = update_ema_every

        self.step_start_ema = step_start_ema
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.image_size = diffusion_model.image_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.train_num_steps = train_num_steps

        self.ds = Dataset(folder, image_size)
        self.dl = cycle(data.DataLoader(self.ds, batch_size = train_batch_size, shuffle=True, pin_memory=True))
        self.opt = Adam(diffusion_model.parameters(), lr=train_lr)

        self.step = 0

        self.amp = amp
        self.scaler = GradScaler(enabled = amp)

        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok = True)

        self.reset_parameters()

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    def save(self, milestone):
        data = {
            'step': self.step,
            'model': self.model.state_dict(),
            'ema': self.ema_model.state_dict(),
            'scaler': self.scaler.state_dict()
        }
        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    def load(self, milestone):
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

        self.step = data['step']
        self.model.load_state_dict(data['model'])
        self.ema_model.load_state_dict(data['ema'])
        self.scaler.load_state_dict(data['scaler'])

    def train(self):
        while self.step < self.train_num_steps:
            for i in range(self.gradient_accumulate_every):
                data = next(self.dl).cuda()

                with autocast(enabled = self.amp):
                    loss = self.model(data)
                    self.scaler.scale(loss / self.gradient_accumulate_every).backward()

                print(f'{self.step}: {loss.item()}')

            self.scaler.step(self.opt)
            self.scaler.update()
            self.opt.zero_grad()

            if self.step % self.update_ema_every == 0:
                self.step_ema()

            if self.step != 0 and self.step % self.save_and_sample_every == 0:
                milestone = self.step // self.save_and_sample_every
                batches = num_to_groups(36, self.batch_size)
                all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
                all_images = torch.cat(all_images_list, dim=0)
                all_images = (all_images + 1) * 0.5
                utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = 6)
                self.save(milestone)

            self.step += 1

        print('training completed')
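One detail worth flagging in this example: scaler.scale(loss / ...).backward() is called inside the autocast block. Autocast is only intended to wrap the forward pass and the loss computation; the PyTorch AMP recipe recommends running the backward pass outside of it. A minimal, self-contained sketch of that ordering (the linear model, optimizer, and random data are stand-ins):

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

device = "cuda"
model = nn.Linear(8, 1).to(device)                 # stand-in model
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
data = torch.randn(4, 8, device=device)
scaler = GradScaler()

with autocast():                                   # autocast wraps only forward + loss
    loss = model(data).mean()

# Backward pass and optimizer step run outside of autocast.
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
opt.zero_grad()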
Example #29
0
class Trainer():
    def __init__(self,
                 name='default',
                 results_dir='results',
                 models_dir='models',
                 base_dir='./',
                 optimizer="adam",
                 latent_dim=256,
                 image_size=128,
                 fmap_max=512,
                 transparent=False,
                 greyscale=False,
                 batch_size=4,
                 gp_weight=10,
                 gradient_accumulate_every=1,
                 attn_res_layers=[],
                 disc_output_size=5,
                 antialias=False,
                 lr=2e-4,
                 lr_mlp=1.,
                 ttur_mult=1.,
                 save_every=1000,
                 evaluate_every=1000,
                 trunc_psi=0.6,
                 aug_prob=None,
                 aug_types=['translation', 'cutout'],
                 dataset_aug_prob=0.,
                 calculate_fid_every=None,
                 is_ddp=False,
                 rank=0,
                 world_size=1,
                 log=False,
                 amp=False,
                 *args,
                 **kwargs):
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.config_path = self.models_dir / name / '.config.json'

        assert is_power_of_two(
            image_size
        ), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        assert all(
            map(is_power_of_two, attn_res_layers)
        ), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'

        self.optimizer = optimizer
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.fmap_max = fmap_max
        self.transparent = transparent
        self.greyscale = greyscale

        assert (int(self.transparent) + int(self.greyscale)
                ) < 2, 'you can only set either transparency or greyscale'

        self.aug_prob = aug_prob
        self.aug_types = aug_types

        self.lr = lr
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.generator_top_k_gamma = 0.99
        self.generator_top_k_frac = 0.5

        self.attn_res_layers = attn_res_layers
        self.disc_output_size = disc_output_size
        self.antialias = antialias

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every

        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        self.syncbatchnorm = is_ddp

        self.amp = amp
        self.G_scaler = GradScaler(enabled=self.amp)
        self.D_scaler = GradScaler(enabled=self.amp)

    @property
    def image_extension(self):
        return 'jpg' if not self.transparent else 'png'

    @property
    def checkpoint_num(self):
        return floor(self.steps // self.save_every)

    def init_GAN(self):
        args, kwargs = self.GAN_params

        # set some global variables before instantiating GAN

        global norm_class
        global Blur

        norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
        Blur = nn.Identity if not self.antialias else Blur

        # handle bugs when
        # switching from multi-gpu back to single gpu

        if self.syncbatchnorm and not self.is_ddp:
            import torch.distributed as dist
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = '12355'
            dist.init_process_group('nccl', rank=0, world_size=1)

        # instantiate GAN

        self.GAN = LightweightGAN(optimizer=self.optimizer,
                                  lr=self.lr,
                                  latent_dim=self.latent_dim,
                                  attn_res_layers=self.attn_res_layers,
                                  image_size=self.image_size,
                                  ttur_mult=self.ttur_mult,
                                  fmap_max=self.fmap_max,
                                  disc_output_size=self.disc_output_size,
                                  transparent=self.transparent,
                                  greyscale=self.greyscale,
                                  rank=self.rank,
                                  *args,
                                  **kwargs)

        if self.is_ddp:
            ddp_kwargs = {
                'device_ids': [self.rank],
                'output_device': self.rank,
                'find_unused_parameters': True
            }

            self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
            self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
            self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)

    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        config = self.config(
        ) if not self.config_path.exists() else json.loads(
            self.config_path.read_text())
        self.image_size = config['image_size']
        self.transparent = config['transparent']
        self.syncbatchnorm = config['syncbatchnorm']
        self.disc_output_size = config['disc_output_size']
        self.greyscale = config.pop('greyscale', False)
        self.attn_res_layers = config.pop('attn_res_layers', [])
        self.optimizer = config.pop('optimizer', 'adam')
        self.fmap_max = config.pop('fmap_max', 512)
        del self.GAN
        self.init_GAN()

    def config(self):
        return {
            'image_size': self.image_size,
            'transparent': self.transparent,
            'greyscale': self.greyscale,
            'syncbatchnorm': self.syncbatchnorm,
            'disc_output_size': self.disc_output_size,
            'optimizer': self.optimizer,
            'attn_res_layers': self.attn_res_layers
        }

    def set_data_src(self, folder):
        self.dataset = ImageDataset(folder,
                                    self.image_size,
                                    transparent=self.transparent,
                                    greyscale=self.greyscale,
                                    aug_prob=self.dataset_aug_prob)
        sampler = DistributedSampler(self.dataset,
                                     rank=self.rank,
                                     num_replicas=self.world_size,
                                     shuffle=True) if self.is_ddp else None
        dataloader = DataLoader(
            self.dataset,
            num_workers=math.ceil(NUM_CORES / self.world_size),
            batch_size=math.ceil(self.batch_size / self.world_size),
            sampler=sampler,
            shuffle=not self.is_ddp,
            drop_last=True,
            pin_memory=True)
        self.loader = cycle(dataloader)

        # auto set augmentation prob for user if dataset is detected to be low
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(
                f'autosetting augmentation probability to {round(self.aug_prob * 100)}%'
            )
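            # e.g. with 10,000 images: min(0.5, (1e5 - 1e4) * 3e-6) = 0.27 -> ~27%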

    def train(self):
        assert exists(
            self.loader
        ), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
        device = torch.device(f'cuda:{self.rank}')

        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.zeros([], device=device)
        total_gen_loss = torch.zeros([], device=device)

        batch_size = math.ceil(self.batch_size / self.world_size)

        image_size = self.GAN.image_size
        latent_dim = self.GAN.latent_dim

        aug_prob = default(self.aug_prob, 0)
        aug_types = self.aug_types
        aug_kwargs = {'prob': aug_prob, 'types': aug_types}

        G = self.GAN.G if not self.is_ddp else self.G_ddp
        D = self.GAN.D if not self.is_ddp else self.D_ddp
        D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp

        apply_gradient_penalty = self.steps % 4 == 0

        # amp related contexts and functions

        amp_context = autocast if self.amp else null_context

        # train discriminator
        self.GAN.D_opt.zero_grad()
        for i in gradient_accumulate_contexts(self.gradient_accumulate_every,
                                              self.is_ddp,
                                              ddps=[D_aug, G]):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
            image_batch = next(self.loader).cuda(self.rank)
            image_batch.requires_grad_()

            with amp_context():
                with torch.no_grad():
                    generated_images = G(latents)

                fake_output, fake_output_32x32, _ = D_aug(generated_images,
                                                          detach=True,
                                                          **aug_kwargs)

                real_output, real_output_32x32, real_aux_loss = D_aug(
                    image_batch, calc_aux_loss=True, **aug_kwargs)

                real_output_loss = real_output
                fake_output_loss = fake_output

                divergence = hinge_loss(real_output_loss, fake_output_loss)
                divergence_32x32 = hinge_loss(real_output_32x32,
                                              fake_output_32x32)
                disc_loss = divergence + divergence_32x32

                aux_loss = real_aux_loss
                disc_loss = disc_loss + aux_loss

            if apply_gradient_penalty:
                outputs = [real_output, real_output_32x32]
                outputs = list(map(self.D_scaler.scale,
                                   outputs)) if self.amp else outputs

                scaled_gradients = torch_grad(
                    outputs=outputs,
                    inputs=image_batch,
                    grad_outputs=list(
                        map(
                            lambda t: torch.ones(t.size(),
                                                 device=image_batch.device),
                            outputs)),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]

                inv_scale = (1. /
                             self.D_scaler.get_scale()) if self.amp else 1.
                gradients = scaled_gradients * inv_scale

                with amp_context():
                    gradients = gradients.reshape(batch_size, -1)
                    gp = self.gp_weight * (
                        (gradients.norm(2, dim=1) - 1)**2).mean()

                    if not torch.isnan(gp):
                        disc_loss = disc_loss + gp
                        self.last_gp_loss = gp.clone().detach().item()

            with amp_context():
                disc_loss = disc_loss / self.gradient_accumulate_every

            disc_loss.register_hook(raise_if_nan)
            self.D_scaler.scale(disc_loss).backward()
            total_disc_loss += divergence

        self.last_recon_loss = aux_loss.item()
        self.d_loss = float(total_disc_loss.item() /
                            self.gradient_accumulate_every)
        self.D_scaler.step(self.GAN.D_opt)
        self.D_scaler.update()

        # train generator

        self.GAN.G_opt.zero_grad()

        for i in gradient_accumulate_contexts(self.gradient_accumulate_every,
                                              self.is_ddp,
                                              ddps=[G, D_aug]):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

            with amp_context():
                generated_images = G(latents)
                fake_output, fake_output_32x32, _ = D_aug(
                    generated_images, **aug_kwargs)
                fake_output_loss = fake_output.mean(
                    dim=1) + fake_output_32x32.mean(dim=1)

                epochs = (self.steps * batch_size *
                          self.gradient_accumulate_every) / len(self.dataset)
                k_frac = max(self.generator_top_k_gamma**epochs,
                             self.generator_top_k_frac)
                k = math.ceil(batch_size * k_frac)

                if k != batch_size:
                    fake_output_loss, _ = fake_output_loss.topk(k=k,
                                                                largest=False)

                loss = fake_output_loss.mean()
                gen_loss = loss

                gen_loss = gen_loss / self.gradient_accumulate_every

            gen_loss.register_hook(raise_if_nan)
            self.G_scaler.scale(gen_loss).backward()
            total_gen_loss += loss

        self.g_loss = float(total_gen_loss.item() /
                            self.gradient_accumulate_every)
        self.G_scaler.step(self.GAN.G_opt)
        self.G_scaler.update()

        # calculate moving averages

        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors

        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(
                f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}'
            )
            self.load(self.checkpoint_num)
            raise NanException

        del total_disc_loss
        del total_gen_loss

        # periodically save results

        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            if self.steps % self.evaluate_every == 0 or (
                    self.steps % 100 == 0 and self.steps < 20000):
                self.evaluate(floor(self.steps / self.evaluate_every))

            if exists(
                    self.calculate_fid_every
            ) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
                num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                with open(
                        str(self.results_dir / self.name / f'fid_scores.txt'),
                        'a') as f:
                    f.write(f'{self.steps},{fid}\n')

        self.steps += 1

    @torch.no_grad()
    def evaluate(self, num=0, num_image_tiles=8, trunc=1.0):
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # latents and noise

        latents = torch.randn((num_rows**2, latent_dim)).cuda(self.rank)

        # regular

        generated_images = self.generate_truncated(self.GAN.G, latents)
        torchvision.utils.save_image(generated_images,
                                     str(self.results_dir / self.name /
                                         f'{str(num)}.{ext}'),
                                     nrow=num_rows)

        # moving averages

        generated_images = self.generate_truncated(self.GAN.GE, latents)
        torchvision.utils.save_image(generated_images,
                                     str(self.results_dir / self.name /
                                         f'{str(num)}-ema.{ext}'),
                                     nrow=num_rows)

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        from pytorch_fid import fid_score
        torch.cuda.empty_cache()

        real_path = str(self.results_dir / self.name / 'fid_real') + '/'
        fake_path = str(self.results_dir / self.name / 'fid_fake') + '/'

        # remove any existing files used for fid calculation and recreate directories
        rmtree(real_path, ignore_errors=True)
        rmtree(fake_path, ignore_errors=True)
        os.makedirs(real_path)
        os.makedirs(fake_path)

        for batch_num in tqdm(range(num_batches),
                              desc='calculating FID - saving reals'):
            real_batch = next(self.loader)
            for k in range(real_batch.size(0)):
                torchvision.utils.save_image(
                    real_batch[k, :, :, :], real_path +
                    '{}.png'.format(k + batch_num * self.batch_size))

        # generate a bunch of fake images in results / name / fid_fake
        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        for batch_num in tqdm(range(num_batches),
                              desc='calculating FID - saving generated'):
            # latents and noise
            latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)

            # moving averages
            generated_images = self.generate_truncated(self.GAN.GE, latents)

            for j in range(generated_images.size(0)):
                torchvision.utils.save_image(
                    generated_images[j, :, :, :],
                    str(
                        Path(fake_path) /
                        f'{str(j + batch_num * self.batch_size)}-ema.{ext}'))

        return fid_score.calculate_fid_given_paths([real_path, fake_path], 256,
                                                   True, 2048)

    @torch.no_grad()
    def generate_truncated(self, G, style, trunc_psi=0.75, num_image_tiles=8):
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0., 1.)

    @torch.no_grad()
    def generate_interpolation(self,
                               num=0,
                               num_image_tiles=8,
                               trunc=1.0,
                               num_steps=100,
                               save_frames=False):
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # latents and noise

        latents_low = torch.randn(num_rows**2, latent_dim).cuda(self.rank)
        latents_high = torch.randn(num_rows**2, latent_dim).cuda(self.rank)

        ratios = torch.linspace(0., 8., num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_truncated(self.GAN.GE,
                                                       interp_latents)
            images_grid = torchvision.utils.make_grid(generated_images,
                                                      nrow=num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())

            if self.transparent:
                background = Image.new('RGBA', pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)

            frames.append(pil_image)

        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'),
                       save_all=True,
                       append_images=frames[1:],
                       duration=80,
                       loop=0,
                       optimize=True)

        if save_frames:
            folder_path = (self.results_dir / self.name / f'{str(num)}')
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f'{str(ind)}.{ext}'))

    def print_log(self):
        data = [('G', self.g_loss), ('D', self.d_loss),
                ('GP', self.last_gp_loss), ('SS', self.last_recon_loss),
                ('FID', self.last_fid)]

        data = [d for d in data if exists(d[1])]
        log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
        print(log)

    def model_name(self, num):
        return str(self.models_dir / self.name / f'model_{num}.pt')

    def init_folders(self):
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        rmtree(str(self.config_path), True)
        self.init_folders()

    def save(self, num):
        save_data = {
            'GAN': self.GAN.state_dict(),
            'version': __version__,
            'G_scaler': self.G_scaler.state_dict(),
            'D_scaler': self.D_scaler.state_dict()
        }

        torch.save(save_data, self.model_name(num))
        self.write_config()

    def load(self, num=-1):
        self.load_config()

        name = num
        if num == -1:
            file_paths = [
                p for p in Path(self.models_dir / self.name).glob('model_*.pt')
            ]
            saved_nums = sorted(
                map(lambda x: int(x.stem.split('_')[1]), file_paths))
            if len(saved_nums) == 0:
                return
            name = saved_nums[-1]
            print(f'continuing from previous epoch - {name}')

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        if 'version' in load_data and self.is_main:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data['GAN'])
        except Exception as e:
            print(
                'unable to load the saved model. Please try downgrading the package to the version specified by the saved model'
            )
            raise e

        if 'G_scaler' in load_data:
            self.G_scaler.load_state_dict(load_data['G_scaler'])
        if 'D_scaler' in load_data:
            self.D_scaler.load_state_dict(load_data['D_scaler'])
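The gradient-penalty branch in the discriminator step above leans on an easy-to-miss AMP detail: because the outputs were scaled with D_scaler.scale(), the gradients returned by torch_grad are scaled too, so they are multiplied by 1 / D_scaler.get_scale() before the penalty is computed. A stripped-down sketch of the same idea on a toy critic (the names here are illustrative, not from the repository):

import torch
from torch.autograd import grad as torch_grad
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
critic = torch.nn.Linear(8, 1).to("cuda")                       # toy critic
x = torch.randn(4, 8, device="cuda", requires_grad=True)

with autocast():
    out = critic(x).sum()

# Gradients of the *scaled* output w.r.t. the input are scaled as well ...
scaled_grads = torch_grad(outputs=scaler.scale(out), inputs=x, create_graph=True)[0]

# ... so divide by the current scale factor to recover the true gradients.
grads = scaled_grads * (1.0 / scaler.get_scale())

gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()                   # WGAN-GP-style penalty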
def load_checkpoint(
    path: str, device: torch.device
) -> (
    TEDD1104,
    str,
    torch.optim.Optimizer,
    torch.optim.lr_scheduler.ReduceLROnPlateau,
    float,
    int,
    int,
    float,
    int,
    bool,
    Optional[GradScaler],
):

    """
    Restore a checkpoint.

    Input:
     - path: path of the checkpoint to restore
     - device: device on which the restored model will be placed

    Output:
     - model: restored TEDD1104 model
     - optimizer_name: name of the optimizer used for training: SGD or Adam
     - optimizer: optimizer used for training
     - scheduler: learning-rate scheduler used for training
     - running_loss: running loss stored in the checkpoint (0.0 for legacy checkpoints)
     - total_batches: total number of batches seen so far (0 for legacy checkpoints)
     - total_training_examples: total number of training examples seen so far (0 for legacy checkpoints)
     - acc_dev: accuracy of the model on the development set
     - epoch: number of epochs used to train the model
     - fp16: True if the model uses fp16, else False
     - scaler: if the model uses FP16, the scaler used for training; otherwise None
    """

    checkpoint = torch.load(path)
    dict_hyperparams = checkpoint["hyper_params"]
    model_weights = checkpoint["model"]
    optimizer_name = checkpoint["optimizer_name"]
    optimizer_state = checkpoint["optimizer"]
    acc_dev = checkpoint["acc_dev"]
    epoch = checkpoint["epoch"]
    scaler_state = checkpoint["scaler"]
    fp16 = dict_hyperparams["fp16"]

    model: TEDD1104 = TEDD1104(
        resnet=dict_hyperparams["resnet"],
        pretrained_resnet=dict_hyperparams["pretrained_resnet"],
        sequence_size=dict_hyperparams["sequence_size"],
        embedded_size=dict_hyperparams["embedded_size"],
        hidden_size=dict_hyperparams["hidden_size"],
        num_layers_lstm=dict_hyperparams["num_layers_lstm"],
        bidirectional_lstm=dict_hyperparams["bidirectional_lstm"],
        layers_out=dict_hyperparams["layers_out"],
        dropout_cnn=dict_hyperparams["dropout_cnn"],
        dropout_cnn_out=dict_hyperparams["dropout_cnn_out"],
        dropout_lstm=dict_hyperparams["dropout_lstm"],
        dropout_lstm_out=dict_hyperparams["dropout_lstm_out"],
    ).to(device=device)

    if optimizer_name == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif optimizer_name == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    else:
        raise ValueError(
            f"The optimizer you are trying to load is unknown: "
            f"Optimizer name {optimizer_name}. Available optimizers: SGD, Adam"
        )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True)

    model.load_state_dict(model_weights)
    optimizer.load_state_dict(optimizer_state)
    try:
        scheduler_state = checkpoint["scheduler"]
        scheduler.load_state_dict(scheduler_state)
    except KeyError:
        logging.warning(f"Legacy checkpoint, a new scheduler will be created")

    try:
        running_loss = checkpoint["running_loss"]
    except KeyError:
        logging.warning(
            "Legacy checkpoint, running loss will be initialized with 0.0 value"
        )
        running_loss = 0.0

    try:
        total_training_examples = checkpoint["total_training_examples"]
    except KeyError:
        logging.warning(
            "Legacy checkpoint, total training examples will be initialized with 0 value"
        )
        total_training_examples = 0

    try:
        total_batches = checkpoint["total_batches"]
    except KeyError:
        logging.warning(
            "Legacy checkpoint, total batches will be initialized with 0 value"
        )
        total_batches = 0

    scaler: Optional[GradScaler]
    if fp16:
        scaler = GradScaler()
        scaler.load_state_dict(scaler_state)
    else:
        scaler = None

    return (
        model,
        optimizer_name,
        optimizer,
        scheduler,
        running_loss,
        total_batches,
        total_training_examples,
        acc_dev,
        epoch,
        fp16,
        scaler,
    )
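For completeness, a hedged sketch of how a caller could unpack this function and decide whether to use the restored scaler; the checkpoint path is illustrative and the training loop itself is only hinted at in comments:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

(
    model, optimizer_name, optimizer, scheduler,
    running_loss, total_batches, total_training_examples,
    acc_dev, epoch, fp16, scaler,
) = load_checkpoint("checkpoints/model.pt", device)  # illustrative path

print(f"Resuming after epoch {epoch} (dev acc {acc_dev}), optimizer={optimizer_name}, fp16={fp16}")

# The resumed training loop would then use the scaler only when fp16 is enabled, e.g.:
#   with autocast(enabled=fp16):
#       loss = ...                                   # forward pass + loss (model specific)
#   (scaler.scale(loss) if fp16 else loss).backward()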