    def loop(self):
        """
        Main loop for training and testing; saves checkpoints, statistics and plots after each epoch.
        """

        while self.epoch < self.args.epochs:
            log('[Training] %s' % self.scheduler.report())

            # Note that we test first, to also get the error of the untrained model.
            testing = elapsed(functools.partial(self.test))
            training = elapsed(functools.partial(self.train))
            log('[Training] %gs training, %gs testing' % (training, testing))

            if self.args.early_stopping:
                validation = elapsed(functools.partial(self.validate))
                log('[Training] %gs validation' % validation)

            # Save model checkpoint after each epoch.
            utils.remove(self.args.state_file + '.%d' % (self.epoch - 1))
            State.checkpoint(self.model, self.scheduler.optimizer, self.epoch,
                             self.args.state_file + '.%d' % self.epoch)
            log('[Training] %d: checkpoint' % self.epoch)
            torch.cuda.empty_cache()  # free cached GPU memory between epochs (possibly unnecessary)

            # Save statistics and plots.
            if self.args.training_file:
                utils.write_hdf5(self.args.training_file,
                                 self.train_statistics)
                log('[Training] %d: wrote %s' %
                    (self.epoch, self.args.training_file))
            if self.args.testing_file:
                utils.write_hdf5(self.args.testing_file, self.test_statistics)
                log('[Training] %d: wrote %s' %
                    (self.epoch, self.args.testing_file))

            if utils.display():
                self.plot()
            self.epoch += 1  # advance to the next epoch

        # Final testing.
        testing = elapsed(functools.partial(self.test))
        log('[Training] %gs testing' % testing)

        # Save final model checkpoint.
        utils.remove(self.args.state_file + '.%d' % (self.epoch - 1))
        State.checkpoint(self.model, self.scheduler.optimizer, self.epoch,
                         self.args.state_file)
        log('[Training] %d: checkpoint' % self.epoch)

        self.results = {
            'training_statistics': self.train_statistics,
            'testing_statistics': self.test_statistics,
        }
        if self.args.results_file:
            utils.write_pickle(self.args.results_file, self.results)
            log('[Training] wrote %s' % self.args.results_file)

    def validate(self):
        """
        Validate for early stopping.
        """

        self.model.eval()
        log('[Training] %d set classifier to eval' % self.epoch)
        assert self.model.training is False

        loss = 0
        error = 0
        num_batches = int(
            math.ceil(self.val_images.shape[0] / self.args.batch_size))

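        # Iterate over the validation set in fixed-size batches; indices for the
        # last, possibly partial batch are clipped into range by numpy.take with
        # mode='clip', so the last sample may be repeated in that batch.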
        for b in range(num_batches):
            perm = numpy.take(range(self.val_images.shape[0]),
                              range(b * self.args.batch_size,
                                    (b + 1) * self.args.batch_size),
                              mode='clip')
            batch_images = common.torch.as_variable(self.val_images[perm],
                                                    self.args.use_gpu)
            batch_classes = common.torch.as_variable(self.val_codes[perm],
                                                     self.args.use_gpu)
            batch_images = batch_images.permute(0, 3, 1, 2)

            output_classes = self.model(batch_images)

            e = self.loss(batch_classes, output_classes)
            loss += e.item()
            e = self.error(batch_classes, output_classes)
            error += e.item()

        loss /= num_batches
        error /= num_batches
        log('[Training] %d: val %g (%g)' % (self.epoch, loss, error))

        if self.val_error is None or error < self.val_error:
            self.val_error = error
            State.checkpoint(self.model, self.scheduler.optimizer, self.epoch,
                             self.args.state_file + '.es')
            log('[Training] %d: early stopping checkpoint' % self.epoch)


def run(args):
    """
    Train a model, normally or with adversarial patches, and save checkpoints.
    """

    # Get the data.
    train_data = CleanDataset(paths.train_images_file(
        args.dataset), paths.train_labels_file(args.dataset))
    test_data = CleanDataset(paths.test_images_file(
        args.dataset), paths.test_labels_file(args.dataset))

    trainset = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    testset = DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
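
    # CleanDataset is assumed to provide (image, label) pairs read from the
    # dataset files above; the training loader reshuffles every epoch, the
    # test loader keeps a fixed order.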

    # Create or load saved model
    if args.saved_model_file:
        state = State.load(paths.experiment_file(
            args.models_dir, args.saved_model_file))
        model = state.model
        if args.cuda:
            model.cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                    momentum=args.momentum, weight_decay=args.weight_decay)
        scheduler = common.train.get_exponential_scheduler(
            optimizer, batches_per_epoch=len(trainset), gamma=args.lr_decay)
        optimizer.load_state_dict(state.optimizer)
        if args.cuda:
            # Move optimizer state loaded from disk back onto the GPU.
            for st in optimizer.state.values():
                for k, v in st.items():
                    if torch.is_tensor(v):
                        st[k] = v.cuda()
        scheduler.load_state_dict(state.scheduler)
        initial_epoch = state.epoch  # training resumes at initial_epoch + 1
    else:
        model = models.ResNet(args.n_classes, [3, 32, 32], channels=12,
                              blocks=[3, 3, 3], clamp=True)
        if args.cuda:
            model.cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                    momentum=args.momentum, weight_decay=args.weight_decay)
        scheduler = common.train.get_exponential_scheduler(
            optimizer, batches_per_epoch=len(trainset), gamma=args.lr_decay)
        initial_epoch = -1  # start training from epoch 0

    # Logging
    if args.use_tensorboard:
        from torch.utils.tensorboard import SummaryWriter
    else:
        from common.summary import SummaryWriter
    writer = SummaryWriter(paths.log_dir(args.log_dir), max_queue=100)

    # Augmentation parameters
    augmentation_crop = True
    augmentation_contrast = True
    augmentation_add = False
    augmentation_saturation = False
    augmentation_value = False
    augmentation_flip = args.use_flip
    augmentation = common.imgaug.get_augmentation(
        noise=False, crop=augmentation_crop, flip=augmentation_flip,
        contrast=augmentation_contrast, add=augmentation_add,
        saturation=augmentation_saturation, value=augmentation_value)

    # Create attack objects
    img_dims = (32, 32)
    if args.location == 'random':
        mask_gen = MaskGenerator(img_dims, tuple(args.mask_dims),
                                 exclude_list=np.array([args.exclude_box]))
    else:
        mask_gen = MaskGenerator(img_dims, tuple(args.mask_dims),
                                 include_list=np.array([args.mask_pos + args.mask_dims]))
    attack = AdversarialPatch(mask_gen, args.epsilon, args.iterations,
                              args.optimize_location, args.opt_type, args.stride, args.signed_grad)
    attack.norm = LInfNorm()
    objective = UntargetedF0Objective()
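
    # The attack perturbs a rectangular patch whose location is either sampled
    # randomly (excluding args.exclude_box) or fixed at args.mask_pos; the
    # perturbation is bounded in the L-infinity norm and the objective is
    # untargeted misclassification.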

    if args.mode == 'adversarial':
        trainer = common.train.AdversarialTraining(
            model, trainset, testset, optimizer, scheduler, attack, objective,
            fraction=args.adv_frac, augmentation=augmentation, writer=writer,
            cuda=args.cuda)
    elif args.mode == 'normal':
        trainer = common.train.NormalTraining(
            model, trainset, testset, optimizer, scheduler,
            augmentation=augmentation, writer=writer, cuda=args.cuda)
    else:
        raise ValueError('unknown training mode: %s' % args.mode)

    trainer.summary_gradients = False

    # Train model
    for e in range(initial_epoch + 1, args.epochs):
        trainer.step(e)
        writer.flush()

        # Save model snapshot
        if (e + 1) % args.snapshot_frequency == 0:
            State.checkpoint(paths.experiment_file(
                args.models_dir, args.model_prefix + '_' + str(e + 1)), model, optimizer, scheduler, e)

    # Save final model
    State.checkpoint(paths.experiment_file(
        args.models_dir, args.model_prefix + '_complete_' + str(e + 1)), model, optimizer, scheduler, args.epochs)
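
# A minimal sketch of how run() might be invoked from the command line. The
# flag names below are assumptions that mirror a subset of the args.* fields
# used above; they are not necessarily the repository's actual interface.
#
# if __name__ == '__main__':
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--dataset', default='cifar10')
#     parser.add_argument('--mode', choices=['normal', 'adversarial'], default='normal')
#     parser.add_argument('--epochs', type=int, default=100)
#     parser.add_argument('--batch_size', type=int, default=128)
#     parser.add_argument('--lr', type=float, default=0.1)
#     parser.add_argument('--cuda', action='store_true')
#     # ... remaining fields (lr_decay, momentum, weight_decay, n_classes,
#     # mask_dims, epsilon, iterations, snapshot_frequency, ...) omitted.
#     run(parser.parse_args())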

    def loop(self):
        """
        Main loop for training and testing; saves encoder, decoder and classifier
        checkpoints and statistics after each epoch.
        """

        auto_encoder_params = {
            'lr': self.args.base_lr,
            'lr_decay': self.args.base_lr_decay,
            'lr_min': 1e-9,
            'weight_decay': self.args.weight_decay
        }

        classifier_params = {
            'lr': self.args.base_lr,
            'lr_decay': self.args.base_lr_decay,
            'lr_min': 1e-9,
            'weight_decay': self.args.weight_decay
        }

        e = 0
        if os.path.exists(self.args.encoder_file) and os.path.exists(
                self.args.decoder_file) and os.path.exists(
                    self.args.classifier_file):
            state = State.load(self.args.encoder_file)
            log('[Training] loaded %s' % self.args.encoder_file)
            self.encoder.load_state_dict(state.model)
            log('[Training] loaded encoder')

            if self.args.use_gpu and not cuda.is_cuda(self.encoder):
                self.encoder = self.encoder.cuda()

            optimizer = torch.optim.Adam(list(self.encoder.parameters()),
                                         auto_encoder_params['lr'])
            optimizer.load_state_dict(state.optimizer)
            self.encoder_scheduler = ADAMScheduler(optimizer,
                                                   **auto_encoder_params)

            state = State.load(self.args.decoder_file)
            log('[Training] loaded %s' % self.args.decoder_file)
            self.decoder.load_state_dict(state.model)
            log('[Training] loaded decoder')

            if self.args.use_gpu and not cuda.is_cuda(self.decoder):
                self.decoder = self.decoder.cuda()

            optimizer = torch.optim.Adam(list(self.decoder.parameters()),
                                         auto_encoder_params['lr'])
            optimizer.load_state_dict(state.optimizer)
            self.decoder_scheduler = ADAMScheduler(optimizer,
                                                   **auto_encoder_params)

            state = State.load(self.args.classifier_file)
            log('[Training] loaded %s' % self.args.classifier_file)
            self.classifier.load_state_dict(state.model)
            log('[Training] loaded classifier')

            if self.args.use_gpu and not cuda.is_cuda(self.classifier):
                self.classifier = self.classifier.cuda()

            optimizer = torch.optim.Adam(list(self.classifier.parameters()),
                                         classifier_params['lr'])
            optimizer.load_state_dict(state.optimizer)
            self.classifier_scheduler = ADAMScheduler(optimizer,
                                                      **classifier_params)

            # Resume from the epoch after the last saved checkpoint.
            e = state.epoch + 1
            self.encoder_scheduler.update(e)
            self.decoder_scheduler.update(e)
            self.classifier_scheduler.update(e)
        else:
            if self.args.use_gpu and not cuda.is_cuda(self.encoder):
                self.encoder = self.encoder.cuda()
            if self.args.use_gpu and not cuda.is_cuda(self.decoder):
                self.decoder = self.decoder.cuda()
            if self.args.use_gpu and not cuda.is_cuda(self.classifier):
                self.classifier = self.classifier.cuda()

            # Training from scratch: create schedulers and initialize their optimizers.
            self.encoder_scheduler = ADAMScheduler(
                list(self.encoder.parameters()), **auto_encoder_params)
            self.encoder_scheduler.initialize()

            self.decoder_scheduler = ADAMScheduler(
                list(self.decoder.parameters()), **auto_encoder_params)
            self.decoder_scheduler.initialize()

            self.classifier_scheduler = ADAMScheduler(
                list(self.classifier.parameters()), **classifier_params)
            self.classifier_scheduler.initialize()

        log('[Training] encoder needs %gMiB' %
            (cuda.estimate_size(self.encoder) / (1024 * 1024)))

        while e < self.args.epochs:
            log('[Training] %s' % self.encoder_scheduler.report())
            log('[Training] %s' % self.decoder_scheduler.report())
            log('[Training] %s' % self.classifier_scheduler.report())

            testing = elapsed(functools.partial(self.test, e))
            training = elapsed(functools.partial(self.train, e))
            log('[Training] %gs training, %gs testing' % (training, testing))

            #utils.remove(self.args.encoder_file + '.%d' % (e - 1))
            #utils.remove(self.args.decoder_file + '.%d' % (e - 1))
            #utils.remove(self.args.classifier_file + '.%d' % (e - 1))
            State.checkpoint(self.encoder, self.encoder_scheduler.optimizer, e,
                             self.args.encoder_file + '.%d' % e)
            State.checkpoint(self.decoder, self.decoder_scheduler.optimizer, e,
                             self.args.decoder_file + '.%d' % e)
            State.checkpoint(self.classifier,
                             self.classifier_scheduler.optimizer, e,
                             self.args.classifier_file + '.%d' % e)

            log('[Training] %d: checkpoint' % e)
            torch.cuda.empty_cache()  # free cached GPU memory between epochs (possibly unnecessary)

            # Save statistics and plots.
            if self.args.training_file:
                utils.write_hdf5(self.args.training_file,
                                 self.train_statistics)
                log('[Training] %d: wrote %s' % (e, self.args.training_file))
            if self.args.testing_file:
                utils.write_hdf5(self.args.testing_file, self.test_statistics)
                log('[Training] %d: wrote %s' % (e, self.args.testing_file))

            #if utils.display():
            #    self.plot()

            e += 1  # advance to the next epoch

        testing = elapsed(functools.partial(self.test, e))
        log('[Training] %gs testing' % testing)

        #utils.remove(self.args.encoder_file + '.%d' % (e - 1))
        #utils.remove(self.args.decoder_file + '.%d' % (e - 1))
        #utils.remove(self.args.classifier_file + '.%d' % (e - 1))
        State.checkpoint(self.encoder, self.encoder_scheduler.optimizer, e,
                         self.args.encoder_file)
        State.checkpoint(self.decoder, self.decoder_scheduler.optimizer, e,
                         self.args.decoder_file)
        State.checkpoint(self.classifier, self.classifier_scheduler.optimizer,
                         e, self.args.classifier_file)

        self.results = {
            'training_statistics': self.train_statistics,
            'testing_statistics': self.test_statistics,
        }
        if self.args.results_file:
            utils.write_pickle(self.args.results_file, self.results)
            log('[Training] wrote %s' % self.args.results_file)