import os
from itertools import chain

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import transforms

# Augmentations from albumentations; note IAAAdditiveGaussianNoise and the
# normalize-capable ToTensor only exist in older releases.
from albumentations import (Compose, GaussNoise, HorizontalFlip,
                            IAAAdditiveGaussianNoise, OneOf, Resize)
from albumentations.pytorch import ToTensor

# Project-local names used below (their import paths are not shown in this
# file): ImageFolderWithAgeGroup, ImageFolderWithAges, age_cutoffs, Recorder,
# RAdam, MultiStageScheduler, LinearLR, set_trainable,
# count_model_parameters.


class Trainer:
    def __init__(self,
                 model,
                 dataset,
                 ctx=-1,
                 batch_size=128,
                 optimizer='sgd',
                 lambdas=(0.1, 0.1),  # tuple avoids a mutable default argument
                 print_freq=32):
        self.model = model
        self.dataset = dataset
        self.batch_size = batch_size
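        # Two optimizers for the DAL-style min-max game: `optbb` descends the
        # joint loss for the backbone, RFM, margin FC and age classifier,
        # while `optDAL` updates only the DAL regressor, which is trained to
        # *maximize* the canonical correlation term (see update()).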
        self.optbb = optim.SGD(chain(self.model.age_classifier.parameters(),
                                     self.model.RFM.parameters(),
                                     self.model.margin_fc.parameters(),
                                     self.model.backbone.parameters()),
                               lr=0.01,
                               momentum=0.9)
        self.optDAL = optim.SGD(self.model.DAL.parameters(),
                                lr=0.01,
                                momentum=0.9)
        self.lambdas = lambdas
        self.print_freq = print_freq
        self.id_recorder = Recorder()
        self.age_recorder = Recorder()
        self.trainingDAL = False
        if ctx < 0:
            self.ctx = torch.device('cpu')
        else:
            self.ctx = torch.device(f'cuda:{ctx}')

    def train(self, epochs, start_epoch, save_path=None):
        self.train_ds = ImageFolderWithAgeGroup(
            self.dataset['pat'], self.dataset['pos'], age_cutoffs,
            self.dataset['train_root'],
            transform=transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        self.train_ld = DataLoader(self.train_ds,
                                   shuffle=True,
                                   batch_size=self.batch_size)
        if self.dataset['val_root'] is not None:
            self.val_ds = ImageFolderWithAgeGroup(
                self.dataset['pat'], self.dataset['pos'], age_cutoffs,
                self.dataset['val_root'],
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))
            self.val_ld = DataLoader(self.val_ds,
                                     shuffle=False,  # no need to shuffle validation data
                                     batch_size=self.batch_size)
        self.model = self.model.to(self.ctx)
        for epoch in range(epochs):
            print(f'---- epoch {epoch} ----')
            self.update()
            if self.dataset['val_root'] is not None:
                acc = self.validate()
            else:
                acc = -1.
            if save_path is not None:
                torch.save(
                    self.model.state_dict(),
                    os.path.join(save_path,
                                 f'{start_epoch+epoch}_{acc:.4f}.state'))

    def update(self):
        print('    -- Training --')
        self.model.train()
        self.id_recorder.reset()
        self.age_recorder.reset()
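        # Alternate adversarial training on a 70-iteration cycle: iterations
        # 0-19 update the DAL regressor (maximizing canonical correlation),
        # iterations 20-69 update the backbone/RFM/heads (minimizing it).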
        for i, (xs, ys, agegrps) in enumerate(self.train_ld):
            if i % 70 == 0:  # canonical maximization procedure
                self.set_train_mode(False)
            elif i % 70 == 20:  # RFM optimization procedure
                self.set_train_mode(True)
            xs, ys, agegrps = xs.to(self.ctx), ys.to(self.ctx), agegrps.to(self.ctx)
            idLoss, id_acc, ageLoss, age_acc, cc = self.model(xs, ys, agegrps)
            #print(f'        ---\n{idLoss}\n{id_acc}\n{ageLoss}\n{age_acc}\n{cc}')
            total_loss = (idLoss + ageLoss * self.lambdas[0]
                          + cc * self.lambdas[1])
            self.id_recorder.gulp(len(agegrps), idLoss.item(), id_acc.item())
            self.age_recorder.gulp(len(agegrps), ageLoss.item(),
                                   age_acc.item())
            if i % self.print_freq == 0:
                print(
                    f'        iter: {i} {i%70} total loss: {total_loss.item():.4f} ({idLoss.item():.4f}, {id_acc.item():.4f}, {ageLoss.item():.4f}, {age_acc.item():.4f}, {cc.item():.8f})'
                )
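            # DAL step: backprop the joint loss, then flip the DAL gradients
            # so its optimizer performs gradient *ascent* on the correlation
            # term; only DAL parameters have requires_grad=True here.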
            if self.trainingDAL:
                self.optDAL.zero_grad()
                total_loss.backward()
                Trainer.flip_grads(self.model.DAL)
                self.optDAL.step()
            else:
                self.optbb.zero_grad()
                total_loss.backward()
                self.optbb.step()
        # show average training meta after epoch
        print(f'        {self.id_recorder.excrete().result_as_string()}')
        print(f'        {self.age_recorder.excrete().result_as_string()}')

    def validate(self):
        print('    -- Validating --')
        self.model.eval()
        self.id_recorder.reset()
        self.age_recorder.reset()
        for i, (xs, ys, agegrps) in enumerate(self.val_ld):
            xs, ys, agegrps = xs.to(self.ctx), ys.to(self.ctx), agegrps.to(self.ctx)
            with torch.no_grad():
                idLoss, id_acc, ageLoss, age_acc, cc = self.model(
                    xs, ys, agegrps)
                total_loss = (idLoss + ageLoss * self.lambdas[0]
                              + cc * self.lambdas[1])
                self.id_recorder.gulp(len(agegrps), idLoss.item(),
                                      id_acc.item())
                self.age_recorder.gulp(len(agegrps), ageLoss.item(),
                                       age_acc.item())
        # show average validation meta after epoch
        print(f'        {self.id_recorder.excrete().result_as_string()}')
        print(f'        {self.age_recorder.excrete().result_as_string()}')
        return self.id_recorder.acc

    def set_train_mode(self, state):
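        # state=True  -> train backbone/RFM/heads, freeze DAL;
        # state=False -> train DAL only (self.trainingDAL becomes True).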
        self.trainingDAL = not state
        Trainer.set_grads(self.model.RFM, state)
        Trainer.set_grads(self.model.backbone, state)
        Trainer.set_grads(self.model.margin_fc, state)
        Trainer.set_grads(self.model.age_classifier, state)
        Trainer.set_grads(self.model.DAL, not state)

    @staticmethod
    def set_grads(mod, state):
        for para in mod.parameters():
            para.requires_grad = state

    @staticmethod
    def flip_grads(mod):
        for para in mod.parameters():
            if para.requires_grad:
                para.grad = -para.grad
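

# ---------------------------------------------------------------------------
# Revised Trainer: because it reuses the name, it shadows the class above when
# this module is imported. It adds RAdam with discriminative learning rates,
# mixed-precision training (autocast + GradScaler), gradient accumulation,
# linear-warmup + cosine-annealing schedules, and optional partial
# fine-tuning of the backbone.
# ---------------------------------------------------------------------------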
class Trainer:
    def __init__(
        self, model, dataset, ctx=-1, batch_size=128, optimizer='sgd',
        grad_accu=1, lambdas=(0.05, 0.1), print_freq=32, train_head_only=True
    ):
        self.model = model
        self.dataset = dataset
        self.batch_size = batch_size
        self.finetune_layers = (
            # self.model.backbone.repeat_3[-1:],
            self.model.backbone.last_bn,
            self.model.backbone.last_linear, self.model.backbone.block8
        )
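        # Discriminative learning rates: freshly initialized heads train at
        # 5e-4; the unfrozen backbone tail (finetune_layers) gets a 10x
        # smaller 5e-5 when train_head_only is False.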
        first_group = [
            {
                "params": chain(
                    self.model.age_classifier.parameters(),
                    self.model.RFM.parameters(),
                    self.model.margin_fc.parameters(),
                ),
                "lr": 5e-4
            }
        ]
        if not train_head_only:
            # first_group[0]["lr"] = 1e-4
            first_group.append(
                {
                    "params": chain(
                        *(x.parameters() for x in self.finetune_layers)
                    ),
                    "lr": 5e-5
                }
            )
        self.optbb = RAdam(first_group)
        self.optDAL = RAdam(self.model.DAL.parameters(), lr=5e-4)
        self.lambdas = lambdas
        self.print_freq = print_freq
        self.id_recorder = Recorder()
        self.age_recorder = Recorder()
        self.trainingDAL = False
        if ctx < 0:
            self.ctx = torch.device('cpu')
        else:
            self.ctx = torch.device(f'cuda:{ctx}')
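        # One AMP loss scaler per optimizer: the DAL and backbone steps
        # alternate on different backward graphs, so each scaler tracks its
        # own inf/NaN history and scale factor.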
        self.scaler1 = GradScaler()
        self.scaler2 = GradScaler()
        self.grad_accu = grad_accu
        self.train_head_only = train_head_only

    def train(self, epochs, start_epoch, save_path=None):
        self.train_ds = ImageFolderWithAges(
            self.dataset['pat'], self.dataset['pos'],
            transforms=Compose(
                [
                    HorizontalFlip(p=0.5),
                    OneOf([
                        IAAAdditiveGaussianNoise(),
                        GaussNoise(),
                    ], p=0.25),
                    Resize(200, 200, cv2.INTER_AREA),
                    ToTensor(normalize=dict(
                        mean=[0.5, 0.5, 0.5], std=[0.50196, 0.50196, 0.50196])
                    )
                ]
            ),
            root=self.dataset['train_root'],
        )
        self.train_ld = DataLoader(
            self.train_ds, shuffle=True, batch_size=self.batch_size, num_workers=2, drop_last=True, pin_memory=True
        )
        print("# Batches:", len(self.train_ld))
        if self.dataset['val_root'] is not None:
            self.val_ds = ImageFolderWithAges(
                self.dataset['pat'], self.dataset['pos'],
                root=self.dataset['val_root'],
                transforms=Compose([
                    Resize(200, 200, cv2.INTER_AREA),
                    ToTensor(normalize=dict(
                        mean=[0.5, 0.5, 0.5], std=[0.50196, 0.50196, 0.50196])
                    )
                ])
            )
            self.val_ld = DataLoader(self.val_ds, shuffle=False, batch_size=self.batch_size,
                                     pin_memory=True, num_workers=1)
        self.model = self.model.to(self.ctx)
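        # Learning-rate schedule: linear warmup over the first 5% of steps,
        # then cosine annealing down to eta_min=1e-6 for the remaining 95%.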
        total_steps = len(self.train_ld) * epochs
        lr_durations = [
            int(total_steps*0.05),
            int(np.ceil(total_steps*0.95))
        ]
        break_points = [0] + list(np.cumsum(lr_durations))[:-1]
        self.schedulers = [
            MultiStageScheduler(
                [
                    LinearLR(self.optbb, 0.01, lr_durations[0]),
                    CosineAnnealingLR(self.optbb, lr_durations[1], eta_min=1e-6)
                ],
                start_at_epochs=break_points
            ),
            MultiStageScheduler(
                [
                    LinearLR(self.optDAL, 0.01, lr_durations[0]),
                    CosineAnnealingLR(self.optDAL, lr_durations[1], eta_min=1e-6)
                ],
                start_at_epochs=break_points
            )
        ]
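        # Freeze the backbone. In head-only mode, also stop BatchNorm layers
        # from updating their running statistics, so the frozen features the
        # heads see stay consistent with the pretrained weights.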
        if self.train_head_only:
            set_trainable(self.model.backbone, False)
            for module in self.model.backbone.modules():
                if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                    module.track_running_stats = False
        else:
            set_trainable(self.model.backbone, False)
            # for module in self.model.backbone.modules():
            #     if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
            #         module.track_running_stats = False
            for module in self.finetune_layers:
                set_trainable(module, True)
                # for submodule in chain([module], module.modules()):
                #     if isinstance(submodule, (nn.BatchNorm2d, nn.BatchNorm1d)):
                #         submodule.track_running_stats = True
        count_model_parameters(self.model)
        # print(self.optbb.param_groups[-1]["lr"])
        # print(self.optDAL.param_groups[-1]["lr"])
        for epoch in range(epochs):
            print(f'---- epoch {epoch} ----')
            self.update()
            if self.dataset['val_root'] is not None:
                acc = self.validate()
            else:
                acc = -1.
            if save_path is not None:
                torch.save(self.model.state_dict(), os.path.join(save_path, f'{start_epoch+epoch}_{acc:.4f}.pth'))

    def update(self):
        print('    -- Training --')
        self.model.train()
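        # Keep the frozen backbone in eval mode (fixed BN stats / dropout);
        # re-enable train mode only for the fine-tuned tail layers.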
        self.model.backbone.eval()
        if not self.train_head_only:
            for module in self.finetune_layers:
                module.train()
                # for submodule in chain([module], module.modules()):
                #     if isinstance(submodule, (nn.BatchNorm2d, nn.BatchNorm1d)):
                #         submodule.eval()
        self.id_recorder.reset()
        self.age_recorder.reset()
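        # Same alternating scheme as before, now on an 80-iteration cycle:
        # iterations 0-27 update DAL, 28-79 update the backbone/RFM/heads.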
        for i, (xs, ys, agegrps) in enumerate(self.train_ld):
            if i % 80 == 0:  # canonical maximization procedure
                self.set_train_mode(False)
            elif i % 80 == 28:  # RFM optimization procedure
                self.set_train_mode(True)
            xs, ys, agegrps = xs.to(self.ctx), ys.to(self.ctx), agegrps.to(self.ctx)
            with autocast():
                # single mixed-precision forward pass
                idLoss, id_acc, ageLoss, age_acc, cc = self.model(
                    xs, ys, agegrps=agegrps)
            #print(f'        ---\n{idLoss}\n{id_acc}\n{ageLoss}\n{age_acc}\n{cc}')
            total_loss = idLoss + ageLoss*self.lambdas[0] + cc*self.lambdas[1]
            total_loss /= self.grad_accu
            self.id_recorder.gulp(len(agegrps), idLoss.item(), id_acc.item())
            self.age_recorder.gulp(len(agegrps), ageLoss.item(), age_acc.item())
            if i % self.print_freq == 0:
                print(
                    f'        iter: {i} {i%80} total loss: {total_loss.item():.4f} '
                    f'({idLoss.item():.4f}, {id_acc.item():.4f}, {ageLoss.item():.4f}, '
                    f'{age_acc.item():.4f}, {cc.item():.8f}) '
                    f'{self.optbb.param_groups[-1]["lr"]:.6f}')
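            # Adversarial step under AMP: instead of flipping gradients after
            # backward (as the class above does), backprop the negated cc
            # term directly; optimizers step every `grad_accu` mini-batches.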
            if self.trainingDAL:
                self.scaler1.scale(-cc * self.lambdas[1] / self.grad_accu).backward()
                # total_loss.backward()
                # Trainer.flip_grads(self.model.DAL)
                if (i + 1) % self.grad_accu == 0:
                    # self.optDAL.step()
                    self.scaler1.step(self.optDAL)
                    self.scaler1.update()
                    self.optDAL.zero_grad()
            else:
                self.scaler2.scale(total_loss).backward()
                # total_loss.backward()
                # self.optbb.step()
                if (i + 1) % self.grad_accu == 0:
                    self.scaler2.step(self.optbb)
                    self.scaler2.update()
                    self.optbb.zero_grad()
            for scheduler in self.schedulers:
                scheduler.step()
        # show average training meta after epoch
        print(f'        {self.id_recorder.excrete().result_as_string()}')
        print(f'        {self.age_recorder.excrete().result_as_string()}')

    def validate(self):
        print('    -- Validating --')
        self.model.eval()
        self.id_recorder.reset()
        self.age_recorder.reset()
        for i, (xs, ys, agegrps) in enumerate(self.val_ld):
            xs, ys, agegrps = xs.to(self.ctx), ys.to(self.ctx), agegrps.to(self.ctx)
            with torch.no_grad():
                with autocast():
                    idLoss, id_acc, ageLoss, age_acc, cc = self.model(xs, ys, agegrps)
                # total_loss = idLoss + ageLoss*self.lambdas[0] + cc*self.lambdas[1]
                self.id_recorder.gulp(len(agegrps), idLoss.item(), id_acc.item())
                self.age_recorder.gulp(len(agegrps), ageLoss.item(), age_acc.item())
        # show average validation meta after epoch
        print(f'        {self.id_recorder.excrete().result_as_string()}')
        print(f'        {self.age_recorder.excrete().result_as_string()}')
        return self.id_recorder.acc

    def set_train_mode(self, state):
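        # With the adversarial update now backpropagating -cc directly (see
        # update()), the per-module requires_grad toggling below is disabled;
        # only the flag that selects which optimizer steps remains.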
        self.trainingDAL = not state
    #     Trainer.set_grads(self.model.RFM, state)
    #     # Trainer.set_grads(self.model.backbone, state)
    #     Trainer.set_grads(self.model.margin_fc, state)
    #     Trainer.set_grads(self.model.age_classifier, state)
    #     Trainer.set_grads(self.model.DAL, not state)

    @staticmethod
    def set_grads(mod, state):
        for para in mod.parameters():
            para.requires_grad = state

    @staticmethod
    def flip_grads(mod):
        for para in mod.parameters():
            if para.requires_grad:
                para.grad = -para.grad
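

# Usage sketch (illustrative only; `build_model` and the exact `dataset` keys
# are assumptions inferred from how Trainer reads them above):
#
#     model = build_model()                       # hypothetical model factory
#     dataset = {
#         'pat': ..., 'pos': ...,                 # filename pattern / match groups
#         'train_root': 'data/train',
#         'val_root': 'data/val',                 # or None to skip validation
#     }
#     trainer = Trainer(model, dataset, ctx=0, batch_size=64,
#                       grad_accu=2, train_head_only=False)
#     trainer.train(epochs=10, start_epoch=0, save_path='checkpoints')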