class Trainer:
    """ Train and Validation with single GPU """
    def __init__(self, train_loader, val_loader, args):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.args = args
        self.model = get_model(args)
        self.epochs = args.epochs
        self.total_step = len(train_loader) * args.epochs
        self.step = 0
        self.epoch = 0
        self.start_epoch = 1
        self.lr = args.learning_rate
        self.best_acc = 0

        # Log
        self.log_path = (
            PROJECT_ROOT / Path(SAVE_DIR) /
            Path(datetime.now().strftime("%Y%m%d%H%M%S") + "-")
        ).as_posix()
        self.log_path = Path(self.get_dirname(self.log_path, args))
        self.log_path.mkdir(parents=True, exist_ok=True)
        self.logger = Logger("train", self.log_path, args.verbose)
        self.logger.log("Checkpoint files will be saved in {}".format(self.log_path))

        self.logger.add_level('STEP', 21, 'green')
        self.logger.add_level('EPOCH', 22, 'cyan')
        self.logger.add_level('EVAL', 23, 'yellow')

        self.criterion = nn.CrossEntropyLoss()
        if self.args.cuda:
            self.criterion = self.criterion.cuda()
        if args.half:
            # Optional fp16: cast both the model and the loss to half precision
            self.model.half()
            self.criterion.half()

        params = self.model.parameters()
        self.optimizer = get_optimizer(args.optimizer, params, args)

    def train(self):
        self.eval()  # baseline evaluation before any training
        for self.epoch in range(self.start_epoch, self.args.epochs + 1):
            # Step decay: multiply lr by `factor` at 1/2 and 3/4 of training
            self.adjust_learning_rate(
                [int(self.args.epochs / 2), int(self.args.epochs * 3 / 4)],
                factor=0.1)
            self.train_epoch()
            self.eval()

        self.logger.writer.export_scalars_to_json(
            self.log_path.as_posix() + "/scalars-{}-{}-{}.json".format(
                self.args.model,
                self.args.seed,
                self.args.activation
            )
        )
        self.logger.writer.close()

    def train_epoch(self):
        self.model.train()
        eval_metrics = EvaluationMetrics(['Loss', 'Acc', 'Time'])

        for i, (images, labels) in enumerate(self.train_loader):
            st = time.time()
            self.step += 1
            images = torch.autograd.Variable(images)
            labels = torch.autograd.Variable(labels)
            if self.args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half:
                images = images.half()

            outputs, loss = self.compute_loss(images, labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            outputs = outputs.float()
            loss = loss.float()
            elapsed_time = time.time() - st

            _, preds = torch.max(outputs, 1)
            accuracy = (labels == preds.squeeze()).float().mean()

            batch_size = labels.size(0)
            eval_metrics.update('Loss', float(loss), batch_size)
            eval_metrics.update('Acc', float(accuracy), batch_size)
            eval_metrics.update('Time', elapsed_time, batch_size)

            if self.step % self.args.log_step == 0:
                self.logger.scalar_summary(eval_metrics.val, self.step, 'STEP')

        # Histograms of parameters and their gradients
        for tag, p in self.model.named_parameters():
            tag = tag.split(".")
            tag = "Train_{}".format(tag[0]) + "/" + "/".join(tag[1:])
            try:
                self.logger.writer.add_histogram(tag, p.clone().cpu().data.numpy(), self.step)
                self.logger.writer.add_histogram(tag + '/grad', p.grad.clone().cpu().data.numpy(), self.step)
            except Exception as e:
                # p.grad is None for parameters that received no gradient
                print("Check whether variable {} is unused: {}".format(tag, e))

        self.logger.scalar_summary(eval_metrics.avg, self.step, 'EPOCH')

    def eval(self):
        self.model.eval()
        eval_metrics = EvaluationMetrics(['Loss', 'Acc', 'Time'])

        for i, (images, labels) in enumerate(self.val_loader):
            st = time.time()
            # volatile=True skips graph construction for inference-only passes
            images = torch.autograd.Variable(images, volatile=True)
            labels = torch.autograd.Variable(labels, volatile=True)
            if self.args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half:
                images = images.half()

            outputs, loss = self.compute_loss(images, labels)

            outputs = outputs.float()
            loss = loss.float()
            elapsed_time = time.time() - st

            _, preds = torch.max(outputs, 1)
            accuracy = (labels == preds.squeeze()).float().mean()

            batch_size = labels.size(0)
            eval_metrics.update('Loss', float(loss), batch_size)
            eval_metrics.update('Acc', float(accuracy), batch_size)
            eval_metrics.update('Time', elapsed_time, batch_size)

        # Save the best model and prune older checkpoints
        if eval_metrics.avg['Acc'] > self.best_acc:
            self.best_acc = eval_metrics.avg['Acc']
            self.save()
            self.logger.log("Saving best model: epoch={}".format(self.epoch))
            self.maybe_delete_old_pth(log_path=self.log_path.as_posix(),
                                      max_to_keep=1)

        self.logger.scalar_summary(eval_metrics.avg, self.step, 'EVAL')

    def get_dirname(self, path, args):
        path += "{}-".format(getattr(args, 'dataset'))
        path += "{}-".format(getattr(args, 'seed'))
        path += "{}".format(getattr(args, 'model'))
        return path
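
    # e.g. a hypothetical run with --dataset cifar10 --seed 42 --model resnet
    # yields ".../<timestamp>-cifar10-42-resnet".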

    def save(self, filename=None):
        if filename is None:
            filename = os.path.join(self.log_path, 'model-{}.pth'.format(self.epoch))
        torch.save({
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': self.epoch,
            'best_acc': self.best_acc,
            'args': self.args
        }, filename)

    def load(self, filename=None):
        if filename is None:
            filename = self.log_path
        S = (torch.load(filename) if self.args.cuda else
             torch.load(filename, map_location=lambda storage, location: storage))
        self.model.load_state_dict(S['model'])
        self.optimizer.load_state_dict(S['optimizer'])
        self.epoch = S['epoch']
        self.start_epoch = self.epoch + 1  # resume from the next epoch
        self.best_acc = S['best_acc']
        self.args = S['args']
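
    # Resume sketch (assumes a checkpoint written by save() above; the filename
    # is hypothetical):
    #
    #   trainer.load("model-7.pth")
    #   trainer.train()  # continues from start_epoch = saved epoch + 1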

    def maybe_delete_old_pth(self, log_path, max_to_keep):
        """Keep at most `max_to_keep` checkpoints.
        Model filenames must end with xxx-xxx-[epoch].pth
        """
        # (filename, epoch) pairs, with the epoch parsed from the filename
        pths = [(f, int(f[:-4].split("-")[-1])) for f in os.listdir(log_path) if f.endswith('.pth')]
        if len(pths) > max_to_keep:
            sorted_pths = sorted(pths, key=lambda tup: tup[1])
            for i in range(len(pths) - max_to_keep):
                os.remove(os.path.join(log_path, sorted_pths[i][0]))
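
    # e.g. with max_to_keep=1, saving model-7.pth and later model-9.pth leaves
    # only model-9.pth on disk (checkpoints are ordered by their epoch suffix).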

    def show_current_model(self):
        print("\n".join("{}: {}".format(k, v) for k, v in sorted(vars(self.args).items())))

        # Count trainable parameters
        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        total_params = np.sum([np.prod(p.size()) for p in model_parameters])

        print('%s\n\n'%(type(self.model)))
        print('%s\n\n'%(inspect.getsource(self.model.__init__)))
        print('%s\n\n'%(inspect.getsource(self.model.forward)))

        # Each table row below is padded to a total width of 95 characters
        print("*"*40 + "%10s" % self.args.model + "*"*45)
        print("*"*40 + "PARAM INFO" + "*"*45)
        print("-"*95)
        print("| %40s | %25s | %20s |" % ("Param Name", "Shape", "Number of Params"))
        print("-"*95)
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print("| %40s | %25s | %20d |" % (name, list(param.size()), np.prod(param.size())))
        print("-"*95)
        print("Total Params: %d" % (total_params))
        print("*"*95)

    def adjust_learning_rate(self, milestones, factor=0.1):
        if self.epoch in milestones:
            self.lr *= factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
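
    # e.g. with epochs=100 and learning_rate=0.1, train() passes the milestones
    # [50, 75], so lr drops to 0.01 at epoch 50 and to 0.001 at epoch 75.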

    def compute_loss(self, images, labels):
        outputs = self.model(images)
        loss = self.criterion(outputs, labels)
        return outputs, loss
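

# Example usage (a minimal sketch; assumes an `args` namespace exposing the
# attributes read above, e.g. epochs, learning_rate, optimizer, cuda, half,
# log_step, verbose):
#
#   trainer = Trainer(train_loader, val_loader, args)
#   trainer.show_current_model()
#   trainer.train()  # baseline eval, then train_epoch()/eval() each epoch
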
class Defender(Trainer):
    """ Perform various adversarial attacks and defense on a pretrained model
    Scheme generates Tensor, not Variable
    """
    def __init__(self, val_loader, args, **kwargs):
        self.val_loader = val_loader
        self.args = args
        self.model = get_model(args)
        self.step = 0
        self.cuda = self.args.cuda

        self.log_path = (
            PROJECT_ROOT / Path("experiments") /
            Path(datetime.now().strftime("%Y%m%d%H%M%S") + "-")).as_posix()
        self.log_path = Path(self.get_dirname(self.log_path, args))
        self.log_path.mkdir(parents=True, exist_ok=True)
        self.logger = Logger("defense", self.log_path, args.verbose)
        self.logger.log("Checkpoint files will be saved in {}".format(
            self.log_path))

        self.logger.add_level("ATTACK", 21, 'yellow')
        self.logger.add_level("DEFENSE", 22, 'cyan')
        self.logger.add_level("TEST", 23, 'white')
        self.logger.add_level("DIST", 11, 'white')

        self.kwargs = kwargs
        if args.domain_restrict:
            self.artifact = get_artifact(self.model, val_loader, args)
            self.kwargs['artifact'] = self.artifact

    def defend(self):
        self.model.eval()
        defense_scheme = getattr(defenses, self.args.defense)(
            self.model, self.args, **self.kwargs)

        # For transfer attacks, craft adversarial examples on a source model
        # that differs from the defended target model
        source = self.model
        if (self.args.source is not None
                and self.args.ckpt_name != self.args.ckpt_src):
            target = self.args.ckpt_name
            self.args.model = self.args.source
            self.args.ckpt_name = self.args.ckpt_src
            source = get_model(self.args)
            self.logger.log("Transfer attack from {} -> {}".format(
                self.args.ckpt_src, target))
        attack_scheme = getattr(attacks, self.args.attack)(
            source, self.args, **self.kwargs)

        eval_metrics = EvaluationMetrics(
            ['Test/Acc', 'Test/Top5', 'Test/Time'])
        eval_def_metrics = EvaluationMetrics(
            ['Def-Test/Acc', 'Def-Test/Top5', 'Def-Test/Time'])
        attack_metrics = EvaluationMetrics(
            ['Attack/Acc', 'Attack/Top5', 'Attack/Time'])
        defense_metrics = EvaluationMetrics(
            ['Defense/Acc', 'Defense/Top5', 'Defense/Time'])
        dist_metrics = EvaluationMetrics(['L0', 'L1', 'L2', 'Li'])

        for i, (images, labels) in enumerate(self.val_loader):
            self.step += 1
            if self.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half:
                images = images.half()

            # Inference
            st = time.time()
            outputs = self.model(self.to_var(images, self.cuda, True))
            outputs = outputs.float()
            _, preds = torch.topk(outputs, 5)

            acc = (labels == preds.data[:, 0]).float().mean()
            top5 = torch.sum(
                (labels.unsqueeze(1).repeat(1, 5) == preds.data).float(),
                dim=1).mean()
            eval_metrics.update('Test/Acc', float(acc), labels.size(0))
            eval_metrics.update('Test/Top5', float(top5), labels.size(0))
            eval_metrics.update('Test/Time', time.time() - st, labels.size(0))

            # Attacker
            st = time.time()
            adv_images, adv_labels = attack_scheme.generate(images, labels)
            if isinstance(adv_images, Variable):
                adv_images = adv_images.data
            attack_metrics.update('Attack/Time',
                                  time.time() - st, labels.size(0))

            # Lp distances between adversarial and original images,
            # computed in unnormalized pixel space
            diff = torch.abs(
                denormalize(adv_images, self.args.dataset) -
                denormalize(images, self.args.dataset))
            # L0: number of pixels whose channel-summed change exceeds 1e-3
            L0 = torch.sum(
                (torch.sum(diff, dim=1) > 1e-3).float().view(labels.size(0), -1),
                dim=1).mean()
            diff = diff.view(labels.size(0), -1)
            L1 = torch.norm(diff, p=1, dim=1).mean()
            L2 = torch.norm(diff, p=2, dim=1).mean()
            Li = torch.max(diff, dim=1)[0].mean()  # L-infinity norm
            dist_metrics.update('L0', float(L0), labels.size(0))
            dist_metrics.update('L1', float(L1), labels.size(0))
            dist_metrics.update('L2', float(L2), labels.size(0))
            dist_metrics.update('Li', float(Li), labels.size(0))
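            # e.g. a perturbation of 0.5 in one channel of a single pixel gives
            # L0 = 1 (pixels changed) and L1 = L2 = Li = 0.5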

            # Defender
            st = time.time()
            def_images, def_labels = defense_scheme.generate(
                adv_images, adv_labels)
            # FIXME - Variable in, Variable out for all methods
            if isinstance(def_images, Variable):
                def_images = def_images.data
            defense_metrics.update('Defense/Time',
                                   time.time() - st, labels.size(0))
            self.calc_stats('Attack', adv_images, images, adv_labels, labels,
                            attack_metrics)
            self.calc_stats('Defense', def_images, images, def_labels, labels,
                            defense_metrics)

            # Defense-Inference for shift of original image
            st = time.time()
            def_images_org, _ = defense_scheme.generate(images, labels)
            # FIXME - Variable in, Variable out for all methods
            if isinstance(def_images_org, Variable):
                def_images_org = def_images_org.data
            outputs = self.model(self.to_var(def_images_org, self.cuda, True))
            outputs = outputs.float()
            _, preds = torch.topk(outputs, 5)

            acc = (labels == preds.data[:, 0]).float().mean()
            top5 = torch.sum(
                (labels.unsqueeze(1).repeat(1, 5) == preds.data).float(),
                dim=1).mean()
            eval_def_metrics.update('Def-Test/Acc', float(acc), labels.size(0))
            eval_def_metrics.update('Def-Test/Top5', float(top5),
                                    labels.size(0))
            eval_def_metrics.update('Def-Test/Time',
                                    time.time() - st, labels.size(0))

            if (self.step % self.args.log_step == 0
                    or self.step == len(self.val_loader)):
                self.logger.scalar_summary(eval_metrics.avg, self.step, 'TEST')
                self.logger.scalar_summary(eval_def_metrics.avg, self.step,
                                           'TEST')
                self.logger.scalar_summary(attack_metrics.avg, self.step,
                                           'ATTACK')
                self.logger.scalar_summary(defense_metrics.avg, self.step,
                                           'DEFENSE')
                self.logger.scalar_summary(dist_metrics.avg, self.step, 'DIST')

                # Defense rate: the fraction of accuracy lost to the attack
                # that the defense recovers,
                #   1 - (clean - defended) / (clean - attacked),
                # defined as 1 when the attack did not reduce accuracy at all
                defense_rate = (eval_metrics.avg['Test/Acc'] -
                                defense_metrics.avg['Defense/Acc'])
                attack_drop = (eval_metrics.avg['Test/Acc'] -
                               attack_metrics.avg['Attack/Acc'])
                if attack_drop:
                    defense_rate /= attack_drop
                else:
                    defense_rate = 0
                defense_rate = 1 - defense_rate

                defense_top5 = (eval_metrics.avg['Test/Top5'] -
                                defense_metrics.avg['Defense/Top5'])
                attack_drop_top5 = (eval_metrics.avg['Test/Top5'] -
                                    attack_metrics.avg['Attack/Top5'])
                if attack_drop_top5:
                    defense_top5 /= attack_drop_top5
                else:
                    defense_top5 = 0
                defense_top5 = 1 - defense_top5
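                # e.g. clean acc 0.9, attacked 0.1, defended 0.7 gives a
                # defense rate of 1 - (0.9 - 0.7) / (0.9 - 0.1) = 0.75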

                self.logger.log(
                    "Defense Rate Top1: {:5.3f} | Defense Rate Top5: {:5.3f}"
                    .format(defense_rate, defense_top5), 'DEFENSE')

            if self.step % self.args.img_log_step == 0:
                image_dict = {
                    'Original': to_np(denormalize(images, self.args.dataset))[0],
                    'Attacked': to_np(denormalize(adv_images, self.args.dataset))[0],
                    'Defended': to_np(denormalize(def_images, self.args.dataset))[0],
                    'Perturbation': to_np(denormalize(images - adv_images,
                                                      self.args.dataset))[0]
                }
                self.logger.image_summary(image_dict, self.step)

    def calc_stats(self, method, gen_images, images, gen_labels, labels,
                   metrics):
        """gen_images: generated by the attacker or defender.
        Currently only top-1/top-5 accuracy is calculated.
        """
        if not isinstance(gen_images, Variable):
            gen_images = self.to_var(gen_images.clone(), self.cuda, True)
        gen_outputs = self.model(gen_images)
        gen_outputs = gen_outputs.float()
        _, gen_preds = torch.topk(F.softmax(gen_outputs, dim=1), 5)

        if isinstance(gen_preds, Variable):
            gen_preds = gen_preds.data
        gen_acc = (labels == gen_preds[:, 0]).float().mean()
        gen_top5 = torch.sum(
            (labels.unsqueeze(1).repeat(1, 5) == gen_preds).float(),
            dim=1).mean()

        metrics.update('{}/Acc'.format(method), float(gen_acc), labels.size(0))
        metrics.update('{}/Top5'.format(method), float(gen_top5),
                       labels.size(0))

    def to_var(self, x, cuda, volatile=False):
        """Wrap a Tensor in a Variable, moving it to the GPU when `cuda` is set.
        The explicit flag keeps CPU-only inference working.
        """
        if cuda:
            x = x.cuda()
        return torch.autograd.Variable(x, volatile=volatile)
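

# Example usage (a minimal sketch; assumes args.attack / args.defense name
# classes in the `attacks` / `defenses` modules, and that get_model() restores
# the pretrained weights referenced by args.ckpt_name):
#
#   defender = Defender(val_loader, args)
#   defender.defend()  # attack, defend, and log metrics for every batch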