def loop(self): """ Main loop for training and testing, saving ... """ while self.epoch < self.args.epochs: log('[Training] %s' % self.scheduler.report()) # Note that we test first, to also get the error of the untrained model. testing = elapsed(functools.partial(self.test)) training = elapsed(functools.partial(self.train)) log('[Training] %gs training, %gs testing' % (training, testing)) if self.args.early_stopping: validation = elapsed(functools.partial(self.validate)) log('[Training] %gs validation' % validation) # Save model checkpoint after each epoch. utils.remove(self.args.state_file + '.%d' % (self.epoch - 1)) State.checkpoint(self.model, self.scheduler.optimizer, self.epoch, self.args.state_file + '.%d' % self.epoch) log('[Training] %d: checkpoint' % self.epoch) torch.cuda.empty_cache() # necessary? # Save statistics and plots. if self.args.training_file: utils.write_hdf5(self.args.training_file, self.train_statistics) log('[Training] %d: wrote %s' % (self.epoch, self.args.training_file)) if self.args.testing_file: utils.write_hdf5(self.args.testing_file, self.test_statistics) log('[Training] %d: wrote %s' % (self.epoch, self.args.testing_file)) if utils.display(): self.plot() self.epoch += 1 # ! # Final testing. testing = elapsed(functools.partial(self.test)) log('[Training] %gs testing' % (testing)) # Save model checkpoint after each epoch. utils.remove(self.args.state_file + '.%d' % (self.epoch - 1)) State.checkpoint(self.model, self.scheduler.optimizer, self.epoch, self.args.state_file) log('[Training] %d: checkpoint' % self.epoch) self.results = { 'training_statistics': self.train_statistics, 'testing_statistics': self.test_statistics, } if self.args.results_file: utils.write_pickle(self.args.results_file, self.results) log('[Training] wrote %s' % self.args.results_file)
def validate(self): """ Validate for early stopping. """ self.model.eval() log('[Training] %d set classifier to eval' % self.epoch) assert self.model.training is False loss = 0 error = 0 num_batches = int( math.ceil(self.val_images.shape[0] / self.args.batch_size)) for b in range(num_batches): perm = numpy.take(range(self.val_images.shape[0]), range(b * self.args.batch_size, (b + 1) * self.args.batch_size), mode='clip') batch_images = common.torch.as_variable(self.val_images[perm], self.args.use_gpu) batch_classes = common.torch.as_variable(self.val_codes[perm], self.args.use_gpu) batch_images = batch_images.permute(0, 3, 1, 2) output_classes = self.model(batch_images) e = self.loss(batch_classes, output_classes) loss += e.item() e = self.error(batch_classes, output_classes) error += e.item() loss /= num_batches error /= num_batches log('[Training] %d: val %g (%g)' % (self.epoch, loss, error)) if self.val_error is None or error < self.val_error: self.val_error = error State.checkpoint(self.model, self.scheduler.optimizer, self.epoch, self.args.state_file + '.es') log('[Training] %d: early stopping checkoint' % self.epoch)
def run(args):
    # Get the data.
    train_data = CleanDataset(paths.train_images_file(args.dataset), paths.train_labels_file(args.dataset))
    test_data = CleanDataset(paths.test_images_file(args.dataset), paths.test_labels_file(args.dataset))
    trainset = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    testset = DataLoader(test_data, batch_size=args.batch_size, shuffle=False)

    # Create or load saved model.
    if args.saved_model_file:
        state = State.load(paths.experiment_file(args.models_dir, args.saved_model_file))
        model = state.model
        if args.cuda:
            model.cuda()

        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        scheduler = common.train.get_exponential_scheduler(optimizer, batches_per_epoch=len(trainset), gamma=args.lr_decay)
        optimizer.load_state_dict(state.optimizer)

        # Move the restored optimizer state to the GPU if needed.
        if args.cuda:
            for st in optimizer.state.values():
                for k, v in st.items():
                    if torch.is_tensor(v):
                        st[k] = v.cuda()

        scheduler.load_state_dict(state.scheduler)
        initial_epoch = state.epoch
    else:
        model = models.ResNet(args.n_classes, [3, 32, 32], channels=12, blocks=[3, 3, 3], clamp=True)
        if args.cuda:
            model.cuda()

        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        scheduler = common.train.get_exponential_scheduler(optimizer, batches_per_epoch=len(trainset), gamma=args.lr_decay)
        initial_epoch = -1

    # Logging.
    if args.use_tensorboard:
        from torch.utils.tensorboard import SummaryWriter
    else:
        from common.summary import SummaryWriter
    writer = SummaryWriter(paths.log_dir(args.log_dir), max_queue=100)

    # Augmentation parameters.
    augmentation_crop = True
    augmentation_contrast = True
    augmentation_add = False
    augmentation_saturation = False
    augmentation_value = False
    augmentation_flip = args.use_flip
    augmentation = common.imgaug.get_augmentation(noise=False, crop=augmentation_crop, flip=augmentation_flip, contrast=augmentation_contrast, add=augmentation_add, saturation=augmentation_saturation, value=augmentation_value)

    # Create attack objects.
    img_dims = (32, 32)
    if args.location == 'random':
        mask_gen = MaskGenerator(img_dims, tuple(args.mask_dims), exclude_list=np.array([args.exclude_box]))
    else:
        mask_gen = MaskGenerator(img_dims, tuple(args.mask_dims), include_list=np.array([args.mask_pos + args.mask_dims]))
    attack = AdversarialPatch(mask_gen, args.epsilon, args.iterations, args.optimize_location, args.opt_type, args.stride, args.signed_grad)
    attack.norm = LInfNorm()
    objective = UntargetedF0Objective()

    if args.mode == 'adversarial':
        trainer = common.train.AdversarialTraining(model, trainset, testset, optimizer, scheduler, attack, objective, fraction=args.adv_frac, augmentation=augmentation, writer=writer, cuda=args.cuda)
    elif args.mode == 'normal':
        trainer = common.train.NormalTraining(model, trainset, testset, optimizer, scheduler, augmentation=augmentation, writer=writer, cuda=args.cuda)
    trainer.summary_gradients = False

    # Train model.
    for e in range(initial_epoch + 1, args.epochs):
        trainer.step(e)
        writer.flush()

        # Save model snapshot.
        if (e + 1) % args.snapshot_frequency == 0:
            State.checkpoint(paths.experiment_file(args.models_dir, args.model_prefix + '_' + str(e + 1)), model, optimizer, scheduler, e)

    # Save final model.
    State.checkpoint(paths.experiment_file(args.models_dir, args.model_prefix + '_complete_' + str(e + 1)), model, optimizer, scheduler, args.epochs)
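
# `common.train.get_exponential_scheduler` belongs to the surrounding codebase and is
# not shown here. Assuming it decays the learning rate by roughly `gamma` per epoch
# while being stepped once per batch, a sketch built on
# torch.optim.lr_scheduler.LambdaLR could look like this (name and behaviour are
# assumptions, not the actual library API):

import torch


def get_exponential_scheduler(optimizer, batches_per_epoch, gamma=0.97):
    """
    Per-batch scheduler scaling the base learning rate by
    gamma ** (step / batches_per_epoch), i.e. approximately gamma per epoch.
    """
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: gamma ** (step / batches_per_epoch))


# Example: with batches_per_epoch=390 and gamma=0.97, one epoch of per-batch
# scheduler.step() calls scales the learning rate by roughly 0.97.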
def loop(self): """ Main loop for training and testing, saving ... """ auto_encoder_params = { 'lr': self.args.base_lr, 'lr_decay': self.args.base_lr_decay, 'lr_min': 0.000000001, 'weight_decay': self.args.weight_decay } classifier_params = { 'lr': self.args.base_lr, 'lr_decay': self.args.base_lr_decay, 'lr_min': 0.000000001, 'weight_decay': self.args.weight_decay } e = 0 if os.path.exists(self.args.encoder_file) and os.path.exists( self.args.decoder_file) and os.path.exists( self.args.classifier_file): state = State.load(self.args.encoder_file) log('[Training] loaded %s' % self.args.encoder_file) self.encoder.load_state_dict(state.model) log('[Training] loaded encoder') if self.args.use_gpu and not cuda.is_cuda(self.encoder): self.encoder = self.encoder.cuda() optimizer = torch.optim.Adam(list(self.encoder.parameters()), auto_encoder_params['lr']) optimizer.load_state_dict(state.optimizer) self.encoder_scheduler = ADAMScheduler(optimizer, **auto_encoder_params) state = State.load(self.args.decoder_file) log('[Training] loaded %s' % self.args.decoder_file) self.decoder.load_state_dict(state.model) log('[Training] loaded decoder') if self.args.use_gpu and not cuda.is_cuda(self.decoder): self.decoder = self.decoder.cuda() optimizer = torch.optim.Adam(list(self.decoder.parameters()), auto_encoder_params['lr']) optimizer.load_state_dict(state.optimizer) self.decoder_scheduler = ADAMScheduler(optimizer, **auto_encoder_params) state = State.load(self.args.classifier_file) log('[Training] loaded %s' % self.args.classifier_file) self.classifier.load_state_dict(state.model) log('[Training] loaded decoder') if self.args.use_gpu and not cuda.is_cuda(self.classifier): self.classifier = self.classifier.cuda() optimizer = torch.optim.Adam(list(self.classifier.parameters()), classifier_params['lr']) optimizer.load_state_dict(state.optimizer) self.classifier_scheduler = ADAMScheduler(optimizer, **classifier_params) e = state.epoch + 1 self.encoder_scheduler.update(e) self.decoder_scheduler.udpate(e) self.classifier_scheduler.update(e) else: if self.args.use_gpu and not cuda.is_cuda(self.encoder): self.encoder = self.encoder.cuda() if self.args.use_gpu and not cuda.is_cuda(self.decoder): self.decoder = self.decoder.cuda() if self.args.use_gpu and not cuda.is_cuda(self.classifier): self.classifier = self.classifier.cuda() self.encoder_scheduler = ADAMScheduler( list(self.encoder.parameters()), **auto_encoder_params) self.encoder_scheduler.initialize() # ! self.decoder_scheduler = ADAMScheduler( list(self.decoder.parameters()), **auto_encoder_params) self.decoder_scheduler.initialize() # ! self.classifier_scheduler = ADAMScheduler( list(self.classifier.parameters()), **classifier_params) self.classifier_scheduler.initialize() # ! 
    log('[Training] model needs %gMiB' % (cuda.estimate_size(self.encoder) / (1024 * 1024)))

    while e < self.args.epochs:
        log('[Training] %s' % self.encoder_scheduler.report())
        log('[Training] %s' % self.decoder_scheduler.report())
        log('[Training] %s' % self.classifier_scheduler.report())

        testing = elapsed(functools.partial(self.test, e))
        training = elapsed(functools.partial(self.train, e))
        log('[Training] %gs training, %gs testing' % (training, testing))

        #utils.remove(self.args.encoder_file + '.%d' % (e - 1))
        #utils.remove(self.args.decoder_file + '.%d' % (e - 1))
        #utils.remove(self.args.classifier_file + '.%d' % (e - 1))
        State.checkpoint(self.encoder, self.encoder_scheduler.optimizer, e, self.args.encoder_file + '.%d' % e)
        State.checkpoint(self.decoder, self.decoder_scheduler.optimizer, e, self.args.decoder_file + '.%d' % e)
        State.checkpoint(self.classifier, self.classifier_scheduler.optimizer, e, self.args.classifier_file + '.%d' % e)
        log('[Training] %d: checkpoint' % e)
        torch.cuda.empty_cache()  # necessary?

        # Save statistics and plots.
        if self.args.training_file:
            utils.write_hdf5(self.args.training_file, self.train_statistics)
            log('[Training] %d: wrote %s' % (e, self.args.training_file))
        if self.args.testing_file:
            utils.write_hdf5(self.args.testing_file, self.test_statistics)
            log('[Training] %d: wrote %s' % (e, self.args.testing_file))
        #if utils.display():
        #    self.plot()

        e += 1  # !

    # Final testing and checkpoints.
    testing = elapsed(functools.partial(self.test, e))
    log('[Training] %gs testing' % (testing))

    #utils.remove(self.args.encoder_file + '.%d' % (e - 1))
    #utils.remove(self.args.decoder_file + '.%d' % (e - 1))
    #utils.remove(self.args.classifier_file + '.%d' % (e - 1))
    State.checkpoint(self.encoder, self.encoder_scheduler.optimizer, e, self.args.encoder_file)
    State.checkpoint(self.decoder, self.decoder_scheduler.optimizer, e, self.args.decoder_file)
    State.checkpoint(self.classifier, self.classifier_scheduler.optimizer, e, self.args.classifier_file)

    self.results = {
        'training_statistics': self.train_statistics,
        'testing_statistics': self.test_statistics,
    }
    if self.args.results_file:
        utils.write_pickle(self.args.results_file, self.results)
        log('[Training] wrote %s' % self.args.results_file)
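
# `State.checkpoint` and `State.load` come from the surrounding codebase; this listing
# only shows how they are called here (model, optimizer, epoch, filename). A minimal
# sketch of such a helper under that assumption, using plain torch.save/torch.load
# (the actual implementation may differ, e.g. other call signatures appear elsewhere):

import torch


class State:
    def __init__(self, model, optimizer, epoch):
        self.model = model
        self.optimizer = optimizer
        self.epoch = epoch

    @staticmethod
    def checkpoint(model, optimizer, epoch, filename):
        """Save model and optimizer state dicts together with the epoch."""
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }, filename)

    @staticmethod
    def load(filename):
        """Load a checkpoint; `model` and `optimizer` hold the raw state dicts."""
        data = torch.load(filename, map_location='cpu')
        return State(data['model'], data['optimizer'], data['epoch'])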