def __init__(self, trainer, res_dir='results', **kwargs):
    self.trainer = trainer
    self.start_datetime = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    self.res_dir = Path(res_dir) / self.start_datetime
    self.res_dir.mkdir(parents=True)
    metric_loss = Average()
    metric_loss.attach(self.trainer, 'loss')
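# Usage sketch (not from the original class): once Average is attached under
# the name 'loss', the epoch mean is read back from `state.metrics['loss']`.
# The engine and data below are toy stand-ins.
from ignite.engine import Engine
from ignite.metrics import Average

demo = Engine(lambda engine, batch: float(batch))
Average().attach(demo, 'loss')
state = demo.run([1.0, 2.0, 3.0], max_epochs=1)
print(state.metrics['loss'])  # 2.0, the mean of the per-iteration outputs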
def attach_pbar_and_metrics(trainer, evaluator):
    loss_metric = Average(output_transform=lambda output: output["loss"])
    accuracy_metric = Accuracy(
        output_transform=lambda output: (output["logit"], output["label"]))
    pbar = ProgressBar()
    loss_metric.attach(trainer, "loss")
    accuracy_metric.attach(trainer, "accuracy")
    accuracy_metric.attach(evaluator, "accuracy")
    pbar.attach(trainer)
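# Minimal sketch of the output contract the transforms above assume: the
# update function must return a dict with "loss", "logit" and "label" keys.
# The model-free toy update below is purely illustrative.
import torch
from ignite.engine import Engine

def toy_update(engine, batch):
    x, y = batch
    logit = torch.zeros(y.shape[0], 2)  # stand-in for a forward pass
    return {"loss": 0.0, "logit": logit, "label": y}

trainer, evaluator = Engine(toy_update), Engine(toy_update)
attach_pbar_and_metrics(trainer, evaluator)
batch = (torch.zeros(4, 3), torch.zeros(4, dtype=torch.long))
trainer.run([batch], max_epochs=1)
print(trainer.state.metrics["accuracy"])  # 1.0: argmax of zeros matches label 0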
def create_classification_trainer(
        model, optimizer, loss_fn, device=None, non_blocking=False,
        prepare_batch=_prepare_batch,
        output_transform=lambda x, y, y_pred, loss: loss.item()):  # noqa
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU
            and GPU, the copy may occur asynchronously with respect to the
            host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`,
            `device`, `non_blocking` and outputs tuple of tensors
            `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x',
            'y', 'y_pred', 'loss' and returns value to be assigned to
            engine's state.output after each iteration. Default is returning
            `loss.item()`.

    Note:
        `engine.state.output` for this engine is defined by the
        `output_transform` parameter and is the loss of the processed batch
        by default.

    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    engine = Engine(_update)
    metric_loss = Average()
    metric_loss.attach(engine, 'loss')

    return engine
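# Hedged usage sketch for the factory above, on a toy model and random data;
# nothing here comes from the source project.
import torch
from torch import nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
trainer = create_classification_trainer(model, optimizer, nn.CrossEntropyLoss())
data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(5)]
state = trainer.run(data, max_epochs=2)
print(state.metrics['loss'])  # epoch-average loss attached by the factory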
def main(ctx, config_file, dataset_root, res_root_dir, debug, device,
         num_workers, **kwargs):
    with open(config_file) as stream:
        config = yaml.safe_load(stream)

    train_transforms = get_transforms(config['train_augment'])
    val_transforms = get_transforms(config['val_augment'])
    train_loader, val_loader = get_loaders(train_transforms=train_transforms,
                                           val_transforms=val_transforms,
                                           dataset_root=dataset_root,
                                           num_workers=num_workers,
                                           **config['dataset'])
    label_names = get_labels(train_loader)
    net, criterion = get_model(n_class=len(label_names), **config['model'])
    optimizer = get_optimizer(net, **config['optimizer'])

    trainer = create_supervised_trainer(net, optimizer, criterion, device,
                                        prepare_batch=prepare_batch)
    metric_loss = Average()
    metric_loss.attach(trainer, 'loss')

    metrics = get_metrics(label_names, config['evaluate'])
    metric_names = list(metrics.keys())
    evaluator = create_supervised_evaluator(net, metrics, device,
                                            prepare_batch=prepare_batch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        evaluator.run(val_loader)

    res_dir = Path(res_root_dir) / config['dataset']['dataset_name']
    train_extend = TrainExtension(trainer, evaluator, res_dir)
    train_extend.print_metrics(metric_names)
    train_extend.set_progressbar()
    train_extend.schedule_lr(optimizer, **config['lr_schedule'])
    if not debug:
        train_extend.copy_configs(config_file)
        train_extend.set_tensorboard(metric_names)
        train_extend.save_model(net, **config['model_checkpoint'])
        train_extend.show_config_on_tensorboard(config)

    trainer.run(train_loader, max_epochs=config['epochs'])
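# Hedged sketch of the YAML shape main() above expects: the top-level keys
# are exactly those the function reads; the placeholder values are
# assumptions, not taken from the source project.
import yaml

EXAMPLE_CONFIG = yaml.safe_load("""
train_augment: []
val_augment: []
dataset:
  dataset_name: example
model: {}
optimizer: {}
evaluate: {}
lr_schedule: {}
model_checkpoint: {}
epochs: 10
""")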
def create_supervised_trainer2(model, optimizer, loss_fn, replay_buffer,
                               device=None, non_blocking=False,
                               prepare_batch=_prepare_batch):
    """Uses the ignite package; if ignite is not used, this function must be
    modified accordingly."""
    from ignite.engine.engine import Engine
    from ignite.metrics import Average

    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        LogSumExpf = lambda x: LogSumExp(model(x))
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        x = x.detach()
        y_pred = model(x)
        loss_elf = loss_fn(y_pred, y)
        x_sample = Sample(LogSumExpf, x.shape[0], x.shape[1], replay_buffer,
                          device)
        replay_buffer.add(x_sample.cpu())
        loss_gen = -(LogSumExpf(x) - LogSumExpf(x_sample)).mean()
        loss = loss_elf + loss_gen
        loss.backward()
        optimizer.step()
        return {
            'loss': loss.item(),
            'loss_elf': loss_elf.item(),
            'loss_gen': loss_gen.item()
        }

    engine = Engine(_update)
    metric_loss = Average(output_transform=lambda output: output['loss'])
    metric_loss_elf = Average(
        output_transform=lambda output: output['loss_elf'])
    metric_loss_gen = Average(
        output_transform=lambda output: output['loss_gen'])
    metric_loss.attach(engine, "loss")
    metric_loss_elf.attach(engine, "loss_elf")
    metric_loss_gen.attach(engine, "loss_gen")
    return engine
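# LogSumExp and Sample are defined elsewhere in the project. A common choice
# in JEM-style energy-based training (an assumption here, not confirmed by
# this snippet) is the log-sum-exp of the classifier logits as the
# unnormalized log-density of x:
import torch

def LogSumExp(logits):
    # log sum_y exp f(x)[y], reduced over the class dimension
    return torch.logsumexp(logits, dim=1)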
def main(hparams):
    results_dir = get_results_directory(hparams.output_dir)
    writer = SummaryWriter(log_dir=str(results_dir))

    ds = get_dataset(hparams.dataset, root=hparams.data_root)
    input_size, num_classes, train_dataset, test_dataset = ds

    hparams.seed = set_seed(hparams.seed)

    if hparams.n_inducing_points is None:
        hparams.n_inducing_points = num_classes

    print(f"Training with {hparams}")
    hparams.save(results_dir / "hparams.json")

    if hparams.ard:
        # Hardcoded to WRN output size
        ard = 640
    else:
        ard = None

    feature_extractor = WideResNet(
        spectral_normalization=hparams.spectral_normalization,
        dropout_rate=hparams.dropout_rate,
        coeff=hparams.coeff,
        n_power_iterations=hparams.n_power_iterations,
        batchnorm_momentum=hparams.batchnorm_momentum,
    )

    initial_inducing_points, initial_lengthscale = initial_values_for_GP(
        train_dataset, feature_extractor, hparams.n_inducing_points
    )

    gp = GP(
        num_outputs=num_classes,
        initial_lengthscale=initial_lengthscale,
        initial_inducing_points=initial_inducing_points,
        separate_inducing_points=hparams.separate_inducing_points,
        kernel=hparams.kernel,
        ard=ard,
        lengthscale_prior=hparams.lengthscale_prior,
    )

    model = DKL_GP(feature_extractor, gp)
    model = model.cuda()

    likelihood = SoftmaxLikelihood(num_classes=num_classes, mixing_weights=False)
    likelihood = likelihood.cuda()

    elbo_fn = VariationalELBO(likelihood, gp, num_data=len(train_dataset))

    parameters = [
        {"params": feature_extractor.parameters(), "lr": hparams.learning_rate},
        {"params": gp.parameters(), "lr": hparams.learning_rate},
        {"params": likelihood.parameters(), "lr": hparams.learning_rate},
    ]

    optimizer = torch.optim.SGD(
        parameters, momentum=0.9, weight_decay=hparams.weight_decay
    )

    milestones = [60, 120, 160]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=0.2
    )

    def step(engine, batch):
        model.train()
        likelihood.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        y_pred = model(x)
        elbo = -elbo_fn(y_pred, y)

        elbo.backward()
        optimizer.step()

        return elbo.item()

    def eval_step(engine, batch):
        model.eval()
        likelihood.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        with torch.no_grad():
            y_pred = model(x)

        return y_pred, y

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "elbo")

    def output_transform(output):
        y_pred, y = output

        # Sample softmax values independently for classification at test time
        y_pred = y_pred.to_data_independent_dist()

        # The mean here is over likelihood samples
        y_pred = likelihood(y_pred).probs.mean(0)

        return y_pred, y

    metric = Accuracy(output_transform=output_transform)
    metric.attach(evaluator, "accuracy")

    metric = Loss(lambda y_pred, y: -elbo_fn(y_pred, y))
    metric.attach(evaluator, "elbo")

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hparams.batch_size,
        shuffle=True,
        drop_last=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=512, shuffle=False, **kwargs
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        elbo = metrics["elbo"]

        print(f"Train - Epoch: {trainer.state.epoch} ELBO: {elbo:.2f}")
        writer.add_scalar("Likelihood/train", elbo, trainer.state.epoch)

        if hparams.spectral_normalization:
            for name, layer in model.feature_extractor.named_modules():
                if isinstance(layer, torch.nn.Conv2d):
                    writer.add_scalar(
                        f"sigma/{name}", layer.weight_sigma, trainer.state.epoch
                    )

        if not hparams.ard:
            # Otherwise it's too much to submit to tensorboard
            length_scales = (
                model.gp.covar_module.base_kernel.lengthscale.squeeze())
            for i in range(length_scales.shape[0]):
                writer.add_scalar(
                    f"length_scale/{i}", length_scales[i], trainer.state.epoch
                )

        if trainer.state.epoch > 150 and trainer.state.epoch % 5 == 0:
            _, auroc, aupr = get_ood_metrics(
                hparams.dataset, "SVHN", model, likelihood, hparams.data_root
            )
            print(f"OoD Metrics - AUROC: {auroc}, AUPR: {aupr}")
            writer.add_scalar("OoD/auroc", auroc, trainer.state.epoch)
            writer.add_scalar("OoD/auprc", aupr, trainer.state.epoch)

        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        elbo = metrics["elbo"]

        print(
            f"Test - Epoch: {trainer.state.epoch} "
            f"Acc: {acc:.4f} "
            f"ELBO: {elbo:.2f}"
        )

        writer.add_scalar("Likelihood/test", elbo, trainer.state.epoch)
        writer.add_scalar("Accuracy/test", acc, trainer.state.epoch)

        scheduler.step()

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=200)

    # Done training - time to evaluate
    results = {}

    evaluator.run(train_loader)
    train_acc = evaluator.state.metrics["accuracy"]
    train_elbo = evaluator.state.metrics["elbo"]
    results["train_accuracy"] = train_acc
    results["train_elbo"] = train_elbo

    evaluator.run(test_loader)
    test_acc = evaluator.state.metrics["accuracy"]
    test_elbo = evaluator.state.metrics["elbo"]
    results["test_accuracy"] = test_acc
    results["test_elbo"] = test_elbo

    _, auroc, aupr = get_ood_metrics(
        hparams.dataset, "SVHN", model, likelihood, hparams.data_root
    )
    results["auroc_ood_svhn"] = auroc
    results["aupr_ood_svhn"] = aupr

    print(f"Test - Accuracy {results['test_accuracy']:.4f}")

    results_json = json.dumps(results, indent=4, sort_keys=True)
    (results_dir / "results.json").write_text(results_json)

    torch.save(model.state_dict(), results_dir / "model.pt")
    torch.save(likelihood.state_dict(), results_dir / "likelihood.pt")

    writer.close()
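# Aside on the schedule used above: MultiStepLR multiplies the learning rate
# by gamma=0.2 at epochs 60, 120 and 160. A self-contained check on a dummy
# parameter (toy optimizer, not the training run above):
import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[60, 120, 160],
                                             gamma=0.2)
for epoch in range(200):
    opt.step()
    sched.step()
print(opt.param_groups[0]["lr"])  # 0.1 * 0.2**3 ≈ 0.0008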
def get_engine():
    engine = Engine(sum_data)
    average = Average()
    average.attach(engine, "average")
    return engine
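# `sum_data` is not defined in the snippet; a minimal assumption consistent
# with attaching Average is a process function returning one number per batch:
def sum_data(engine, batch):
    return sum(batch)

engine = get_engine()
state = engine.run([[1, 2], [3, 4]], max_epochs=1)
print(state.metrics["average"])  # (3 + 7) / 2 = 5.0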
def train(model, train_loader, eval_loaders, optimizer, loss_fn,
          n_it_max, patience, split_names, select_metric='Val accuracy_0',
          select_mode='max', viz=None, device='cpu', lr_scheduler=None,
          name=None, log_steps=None, log_epoch=False, _run=None,
          prepare_batch=_prepare_batch, single_pass=False, n_ep_max=None):
    # print(model)

    if not log_steps and not log_epoch:
        logger.warning('/!\\ No logging during training /!\\')

    if log_steps is None:
        log_steps = []

    epoch_steps = len(train_loader)
    if log_epoch:
        log_steps.append(epoch_steps)

    if single_pass:
        max_epoch = 1
    elif n_ep_max is None:
        assert n_it_max is not None
        max_epoch = int(n_it_max / epoch_steps) + 1
    else:
        assert n_it_max is None
        max_epoch = n_ep_max

    all_metrics = defaultdict(dict)
    trainer = create_supervised_trainer(model, optimizer, loss_fn,
                                        device=device,
                                        prepare_batch=prepare_batch)

    if hasattr(model, 'new_epoch_hook'):
        trainer.add_event_handler(Events.EPOCH_STARTED, model.new_epoch_hook)
    if hasattr(model, 'new_iter_hook'):
        trainer.add_event_handler(Events.ITERATION_STARTED,
                                  model.new_iter_hook)

    trainer._logger.setLevel(logging.WARNING)

    # trainer output is in the format (x, y, y_pred, loss, optionals)
    train_loss = RunningAverage(output_transform=lambda out: out[3].item(),
                                epoch_bound=True)
    train_loss.attach(trainer, 'Trainer loss')

    if hasattr(model, 's'):
        met = Average(output_transform=lambda _: float('nan')
                      if model.s is None else model.s)
        met.attach(trainer, 'cur_s')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed,
                                  'cur_s')

    if hasattr(model, 'arch_sampler') and model.arch_sampler.distrib_dim > 0:
        met = Average(output_transform=lambda _: float('nan')
                      if model.cur_split is None else model.cur_split)
        met.attach(trainer, 'Trainer split')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed,
                                  'Trainer split')
        # trainer.add_event_handler(Events.EPOCH_STARTED, met.started)

        all_ent = Average(
            output_transform=lambda out: out[-1]['arch_entropy_avg'].item())
        all_ent.attach(trainer, 'Trainer all entropy')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  all_ent.completed, 'Trainer all entropy')

        train_ent = Average(
            output_transform=lambda out: out[-1]['arch_entropy_sample'].item())
        train_ent.attach(trainer, 'Trainer sampling entropy')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  train_ent.completed,
                                  'Trainer sampling entropy')

        trainer.add_event_handler(
            Events.EPOCH_COMPLETED,
            lambda engine: model.check_arch_freezing(
                ent=train_ent.compute(),
                epoch=engine.state.iteration / (epoch_steps * max_epoch)))

        def log_always(engine, name):
            val = engine.state.output[-1][name]
            all_metrics[name][engine.state.iteration /
                              epoch_steps] = val.mean().item()

        def log_always_dict(engine, name):
            for node, val in engine.state.output[-1][name].items():
                all_metrics['node {} {}'.format(
                    node, name)][engine.state.iteration /
                                 epoch_steps] = val.mean().item()

        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict, name='arch_grads')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict, name='arch_probas')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict, name='node_grads')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always, name='task all_loss')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always, name='arch all_loss')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always, name='entropy all_loss')

    if n_it_max is not None:
        StopAfterIterations([n_it_max]).attach(trainer)

    # epoch_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name,
    #                          persist=True, disable=not (_run or viz))
    # epoch_pbar.attach(trainer, metric_names=['Train loss'])
    #
    # training_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name,
    #                             persist=True, disable=not (_run or viz))
    # training_pbar.attach(trainer, event_name=Events.EPOCH_COMPLETED,
    #                      closing_event_name=Events.COMPLETED)

    total_time = Timer(average=False)
    eval_time = Timer(average=False)
    eval_time.pause()
    data_time = Timer(average=False)
    forward_time = Timer(average=False)
    forward_time.attach(trainer, start=Events.EPOCH_STARTED,
                        pause=Events.ITERATION_COMPLETED,
                        resume=Events.ITERATION_STARTED,
                        step=Events.ITERATION_COMPLETED)
    epoch_time = Timer(average=False)
    epoch_time.attach(trainer, start=Events.EPOCH_STARTED,
                      pause=Events.EPOCH_COMPLETED,
                      resume=Events.EPOCH_STARTED,
                      step=Events.EPOCH_COMPLETED)

    def get_loss(y_pred, y):
        l = loss_fn(y_pred, y)
        if not torch.is_tensor(l):
            l, *l_details = l
        return l.mean()

    def get_member(x, n=0):
        if isinstance(x, (list, tuple)):
            return x[n]
        return x

    eval_metrics = {'loss': Loss(get_loss)}

    for i in range(model.n_out):
        out_trans = get_attr_transform(i)

        # bind out_trans at definition time; a plain closure would see only
        # the last loop value (late-binding bug in the original)
        def extract_ys(out, out_trans=out_trans):
            x, y, y_pred, loss, _ = out
            return out_trans((y_pred, y))

        train_acc = Accuracy(extract_ys)
        train_acc.attach(trainer, 'Trainer accuracy_{}'.format(i))
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  train_acc.completed,
                                  'Trainer accuracy_{}'.format(i))
        eval_metrics['accuracy_{}'.format(i)] = \
            Accuracy(output_transform=out_trans)
        # if isinstance(model, SSNWrapper):
        #     model.arch_sampler.entropy().mean()

    evaluator = create_supervised_evaluator(model, metrics=eval_metrics,
                                            device=device,
                                            prepare_batch=prepare_batch)

    last_iteration = 0
    patience_counter = 0

    best = {
        'value': float('inf') if select_mode == 'min' else -1,
        'iter': -1,
        'state_dict': None
    }

    def is_better(new, old):
        if select_mode == 'min':
            return new < old
        else:
            return new > old

    def log_results(evaluator, data_loader, iteration, split_name):
        evaluator.run(data_loader)
        metrics = evaluator.state.metrics

        log_metrics = {}

        for metric_name, metric_val in metrics.items():
            log_name = '{} {}'.format(split_name, metric_name)
            if viz:
                first = iteration == 0 and split_name == split_names[0]
                viz.line(
                    [metric_val], X=[iteration], win=metric_name,
                    name=log_name,
                    update=None if first else 'append',
                    opts={
                        'title': metric_name,
                        'showlegend': True,
                        'width': 500,
                        'xlabel': 'iterations'
                    })
                viz.line(
                    [metric_val], X=[iteration / epoch_steps],
                    win='{}epoch'.format(metric_name),
                    name=log_name,
                    update=None if first else 'append',
                    opts={
                        'title': metric_name,
                        'showlegend': True,
                        'width': 500,
                        'xlabel': 'epoch'
                    })
            if _run:
                _run.log_scalar(log_name, metric_val, iteration)
            log_metrics[log_name] = metric_val
            all_metrics[log_name][iteration] = metric_val

        return log_metrics

    if lr_scheduler is not None:
        @trainer.on(Events.EPOCH_COMPLETED)
        def step(_):
            lr_scheduler.step()
            # logger.warning('current lr {:.5e}'.format(
            #     optimizer.param_groups[0]['lr']))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_event(trainer):
        iteration = trainer.state.iteration if trainer.state else 0
        nonlocal last_iteration, patience_counter, best

        if not log_steps or not \
                (iteration in log_steps or iteration % log_steps[-1] == 0):
            return

        epoch_time.pause()
        eval_time.resume()

        all_metrics['training_epoch'][iteration] = iteration / epoch_steps
        all_metrics['training_iteration'][iteration] = iteration
        if hasattr(model, 'arch_sampler'):
            all_metrics['training_archs'][iteration] = \
                model.arch_sampler().squeeze().detach()
        # if hasattr(model, 'distrib_gen'):
        #     entropy = model.distrib_gen.entropy()
        #     all_metrics['entropy'][iteration] = entropy.mean().item()
        # if trainer.state and len(trainer.state.metrics) > 1:
        #     raise ValueError(trainer.state.metrics)

        all_metrics['data time'][iteration] = data_time.value()
        all_metrics['data time_ps'][iteration] = data_time.value() / max(
            data_time.step_count, 1.)
        all_metrics['forward time'][iteration] = forward_time.value()
        all_metrics['forward time_ps'][iteration] = \
            forward_time.value() / max(forward_time.step_count, 1.)
        all_metrics['epoch time'][iteration] = epoch_time.value()
        all_metrics['epoch time_ps'][iteration] = epoch_time.value() / max(
            epoch_time.step_count, 1.)

        if trainer.state:
            # logger.warning(trainer.state.metrics)
            for metric, value in trainer.state.metrics.items():
                all_metrics[metric][iteration] = value
                if viz:
                    viz.line(
                        [value], X=[iteration],
                        win=metric.split()[-1],
                        name=metric,
                        update=None if iteration == 0 else 'append',
                        opts={
                            'title': metric,
                            'showlegend': True,
                            'width': 500,
                            'xlabel': 'iterations'
                        })

        iter_this_step = iteration - last_iteration
        for d_loader, name in zip(eval_loaders, split_names):
            if name == 'Train':
                if iteration == 0:
                    all_metrics['Trainer loss'][iteration] = float('nan')
                    all_metrics['Trainer accuracy_0'][iteration] = float('nan')
                    if hasattr(model, 'arch_sampler'):
                        all_metrics['Trainer all entropy'][iteration] = \
                            float('nan')
                        all_metrics['Trainer sampling entropy'][iteration] = \
                            float('nan')
                        # if hasattr(model, 'cur_split'):
                        all_metrics['Trainer split'][iteration] = float('nan')
                continue
            split_metrics = log_results(evaluator, d_loader, iteration, name)
            if select_metric not in split_metrics:
                continue
            if is_better(split_metrics[select_metric], best['value']):
                best['value'] = split_metrics[select_metric]
                best['iter'] = iteration
                best['state_dict'] = copy.deepcopy(model.state_dict())
                if patience > 0:
                    patience_counter = 0
            elif patience > 0:
                patience_counter += iter_this_step
                if patience_counter >= patience:
                    logger.info('#####')
                    logger.info('# Early stopping Run')
                    logger.info('#####')
                    trainer.terminate()

        last_iteration = iteration
        eval_time.pause()
        eval_time.step()
        all_metrics['eval time'][iteration] = eval_time.value()
        all_metrics['eval time_ps'][iteration] = \
            eval_time.value() / eval_time.step_count
        all_metrics['total time'][iteration] = total_time.value()
        epoch_time.resume()

    log_event(trainer)

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def log_epoch(trainer):
    #     iteration = trainer.state.iteration if trainer.state else 0
    #     epoch = iteration / epoch_steps
    #     fw_t = forward_time.value()
    #     fw_t_ps = fw_t / forward_time.step_count
    #     d_t = data_time.value()
    #     d_t_ps = d_t / data_time.step_count
    #     e_t = epoch_time.value()
    #     e_t_ps = e_t / epoch_time.step_count
    #     ev_t = eval_time.value()
    #     ev_t_ps = ev_t / eval_time.step_count
    #     logger.warning('<{}> Epoch {}/{} finished (Forward: {:.3f}s({:.3f}), '
    #                    'data: {:.3f}s({:.3f}), epoch: {:.3f}s({:.3f}),'
    #                    ' Eval: {:.3f}s({:.3f}), Total: '
    #                    '{:.3f}s)'.format(type(model).__name__, epoch,
    #                                      max_epoch, fw_t, fw_t_ps, d_t,
    #                                      d_t_ps, e_t, e_t_ps, ev_t, ev_t_ps,
    #                                      total_time.value()))

    data_time.attach(trainer, start=Events.STARTED,
                     pause=Events.ITERATION_STARTED,
                     resume=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_STARTED)

    if hasattr(model, 'iter_per_epoch'):
        model.iter_per_epoch = len(train_loader)

    trainer.run(train_loader, max_epochs=max_epoch)

    return trainer.state.iteration, all_metrics, best
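# StopAfterIterations is a project-local handler not shown here; a minimal
# sketch consistent with its usage above (terminate once the last listed
# iteration count is reached) — an assumption, not the original class:
from ignite.engine import Events

class StopAfterIterations:
    def __init__(self, log_steps):
        self.max_iteration = log_steps[-1]

    def __call__(self, engine):
        if engine.state.iteration >= self.max_iteration:
            engine.terminate()

    def attach(self, engine):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)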
def run_training(self):
    ########## init wandb ###########
    print(GREEN + "*************** START TRAINING *******************")
    # fixme check if this works
    for (key, val) in self.config.items():
        print(GREEN + f"{key}: {val}")  # print to console
        wandb.config.update({key: val})  # update wandb config
    print(GREEN + "**************************************************" + ENDC)

    ########## checkpoints ##########
    if self.config["general"]["restart"]:
        mod_ckpt, op_ckpt = self._load_ckpt("reg_ckpt")
        # flow_ckpt, flow_op_ckpt = self._load_ckpt("flow_ckpt")
    else:
        mod_ckpt = op_ckpt = None

    dataset, image_transforms = get_dataset(self.config["data"])
    transforms = tt.Compose([tt.ToTensor()])
    train_dataset = dataset(transforms,
                            data_keys=self.data_keys,
                            mode="train",
                            label_transfer=True,
                            debug=self.config["general"]["debug"],
                            crop_app=True,
                            **self.config["data"])

    # if seq_length is pruned, use the min seq_length, so that the seq_length
    # of test_dataset is lower than or equal to that of the train dataset
    # collect_len = train_dataset.seq_length
    # self.collect_recon_loss_seq = {
    #     k: np.zeros(shape=[k])
    #     for k in range(collect_len[0], collect_len[-1])
    # }
    # self.collect_count_seq_lens = np.zeros(shape=[collect_len[-1]])
    # # adapt sequence_length
    # self.config["data"]["seq_length"] = (
    #     min(self.config["data"]["seq_length"][0],
    #         train_dataset.seq_length[0]),
    #     min(self.config["data"]["seq_length"][1],
    #         train_dataset.seq_length[1]),
    # )

    train_sampler = RandomSampler(data_source=train_dataset)
    seq_sampler_train = SequenceSampler(
        train_dataset,
        sampler=train_sampler,
        batch_size=self.config["training"]["batch_size"],
        drop_last=True,
    )
    train_loader = DataLoader(
        train_dataset,
        num_workers=0 if self.config["general"]["debug"] else
        self.config["data"]["n_data_workers"],
        batch_sampler=seq_sampler_train,
    )

    # test data
    t_datakeys = [key for key in self.data_keys] + [
        "action",
        "sample_ids",
        "intrinsics",
        "intrinsics_paired",
        "extrinsics",
        "extrinsics_paired",
    ]
    test_dataset = dataset(image_transforms,
                           data_keys=t_datakeys,
                           mode="test",
                           debug=self.config["general"]["debug"],
                           label_transfer=True,
                           **self.config["data"])
    assert test_dataset.action_id_to_action is not None

    rand_sampler_test = RandomSampler(data_source=test_dataset)
    seq_sampler_test = SequenceSampler(
        test_dataset,
        rand_sampler_test,
        batch_size=self.config["training"]["batch_size"],
        drop_last=True,
    )
    test_loader = DataLoader(
        test_dataset,
        num_workers=0 if self.config["general"]["debug"] else
        self.config["data"]["n_data_workers"],
        batch_sampler=seq_sampler_test,
    )

    rand_sampler_transfer = RandomSampler(data_source=test_dataset)
    seq_sampler_transfer = SequenceSampler(
        test_dataset,
        rand_sampler_transfer,
        batch_size=1,
        drop_last=True,
    )
    transfer_loader = DataLoader(
        test_dataset,
        batch_sampler=seq_sampler_transfer,
        num_workers=0 if self.config["general"]["debug"] else
        self.config["data"]["n_data_workers"],
    )

    # compare_dataset = dataset(
    #     transforms,
    #     data_keys=t_datakeys,
    #     mode="train",
    #     label_transfer=True,
    #     debug=self.config["general"]["debug"],
    #     crop_app=True,
    #     **self.config["data"]
    # )

    ## Classifier action
    # n_actions = len(train_dataset.action_id_to_action)
    # classifier_action = Classifier_action(len(train_dataset.dim_to_use),
    #                                       n_actions, dropout=0,
    #                                       dim=512).to(self.device)
    # optimizer_classifier = Adam(classifier_action.parameters(), lr=0.0001,
    #                             weight_decay=1e-4)
    # print("Number of parameters in classifier action",
    #       sum(p.numel() for p in classifier_action.parameters()))

    n_actions = len(train_dataset.action_id_to_action)
    # classifier_action2 = Classifier_action(len(train_dataset.dim_to_use),
    #                                        n_actions, dropout=0,
    #                                        dim=512).to(self.device)
    # classifier_action2 = Sequence_disc_michael(
    #     [2, 1, 1, 1], len(train_dataset.dim_to_use),
    #     out_dim=n_actions).to(self.device)
    # optimizer_classifier2 = Adam(classifier_action2.parameters(), lr=0.0001,
    #                              weight_decay=1e-5)
    # print("Number of parameters in classifier action",
    #       sum(p.numel() for p in classifier_action2.parameters()))

    # Classifier beta
    classifier_beta = Classifier_action_beta(512, n_actions).to(self.device)
    optimizer_classifier_beta = Adam(classifier_beta.parameters(), lr=0.001)
    print("Number of parameters in classifier on beta",
          sum(p.numel() for p in classifier_beta.parameters()))

    # # Regressor
    # regressor = Regressor_fly(self.config["architecture"]["dim_hidden_b"],
    #                           len(train_dataset.dim_to_use)).to(self.device)
    # optimizer_regressor = Adam(regressor.parameters(), lr=0.0001)
    # print("Number of parameters in regressor",
    #       sum(p.numel() for p in regressor.parameters()))

    ########## load network and optimizer ##########
    net = MTVAE(self.config["architecture"],
                len(train_dataset.dim_to_use), self.device)
    print(
        "Number of parameters in VAE model",
        sum(p.numel() for p in net.parameters()),
    )
    if self.config["general"]["restart"] and mod_ckpt is not None:
        print(BLUE + "***** Initializing VAE from checkpoint! *****" + ENDC)
        net.load_state_dict(mod_ckpt)

    net.to(self.device)

    optimizer = Adam(net.parameters(),
                     lr=self.config["training"]["lr_init"],
                     weight_decay=self.config["training"]["weight_decay"])
    wandb.watch(net, log="all", log_freq=len(train_loader))
    if self.config["general"]["restart"] and op_ckpt is not None:
        optimizer.load_state_dict(op_ckpt)

    # scheduler = torch.optim.lr_scheduler.MultiStepLR(
    #     optimizer, milestones=self.config["training"]["tau"],
    #     gamma=self.config["training"]["gamma"]
    # )
    # rec_loss = nn.MSELoss(reduction="none")

    ############## DISCRIMINATOR ##############
    n_kps = len(train_dataset.dim_to_use)

    # make gan loss weights
    print(
        f"len of train_dataset: {len(train_dataset)}, "
        f"len of train_loader: {len(train_loader)}"
    )
    # 10 epochs of fine tuning
    total_steps = (self.config["training"]["n_epochs"] - 10) * \
        len(train_loader)

    get_kl_weight = partial(
        linear_var,
        start_it=0,
        end_it=total_steps,
        start_val=1e-5,
        end_val=1,
        clip_min=0,
        clip_max=1,
    )

    def train_fn(engine, batch):
        net.train()
        # reference keypoints with label #1
        kps = batch["keypoints"].to(torch.float).to(self.device)
        # keypoints for cross label transfer, label #2
        kps_cross = batch["paired_keypoints"].to(torch.float).to(self.device)
        p_id = batch["paired_sample_ids"].to(torch.int)

        # reconstruct second sequence with inferred b
        labels = batch['action'][:, 0] - 2

        out_seq, mu, logstd, out_cycle = net(kps, kps_cross)

        ps = torch.randn_like(out_cycle, requires_grad=False)
        cycle_loss = torch.mean(torch.abs(out_cycle - ps))
        kps_loss = torch.mean(torch.abs(out_seq - kps[:, net.div:]))
        l_kl = kl_loss(mu, logstd)

        k_vel = self.config["training"]["k_vel"]
        vel_tgt = (kps[:, net.div:net.div + k_vel] -
                   kps[:, net.div - 1:net.div + k_vel - 1])
        vel_pred = out_seq[:, :k_vel] - torch.cat(
            [kps[:, net.div - 1].unsqueeze(1), out_seq[:, :k_vel - 1]], dim=1)
        motion_loss = torch.mean(torch.abs(vel_tgt - vel_pred))

        kl_weight = get_kl_weight(engine.state.iteration)

        loss = (kps_loss + kl_weight * l_kl
                + self.config["training"]["weight_motion"] * motion_loss
                + self.config["training"]["weight_cycle"] * cycle_loss)

        # if engine.state.epoch < self.config["training"]["n_epochs"] - 10:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        out_dict = {
            "loss": loss.detach().item(),
            "motion_loss": motion_loss.detach().item(),
            "rec_loss": kps_loss.detach().item(),
            "cycle_loss": cycle_loss.detach().item(),
            "kl_loss": l_kl.detach().item(),
            "kl_weight": kl_weight
        }

        # ## Train classifier on action
        # predict = classifier_action(seq_b)[0]
        # loss_classifier_action = nn.CrossEntropyLoss()(
        #     predict, labels.to(self.device))
        # optimizer_classifier.zero_grad()
        # loss_classifier_action.backward()
        # optimizer_classifier.step()
        # _, labels_pred = torch.max(nn.Sigmoid()(predict), dim=1)
        # acc_action = torch.sum(
        #     labels_pred.cpu() == labels).float() / labels_pred.shape[0]
        #
        # predict = classifier_action2(
        #     (seq_b[:, 1:] - seq_b[:, :-1]).transpose(1, 2))[0]
        # loss_classifier_action2 = nn.CrossEntropyLoss()(
        #     predict, labels.to(self.device))
        # optimizer_classifier2.zero_grad()
        # loss_classifier_action2.backward()
        # optimizer_classifier2.step()
        # _, labels_pred = torch.max(nn.Sigmoid()(predict), dim=1)
        # acc_action2 = torch.sum(
        #     labels_pred.cpu() == labels).float() / labels_pred.shape[0]

        ## Train classifier on beta
        if engine.state.epoch >= self.config["training"]["n_epochs"] - 10:
            net.eval()
            with torch.no_grad():
                _, mu, *_ = net(kps, kps_cross)
            predict = classifier_beta(mu)
            loss_classifier_action_beta = nn.CrossEntropyLoss()(
                predict, labels.to(self.device))
            optimizer_classifier_beta.zero_grad()
            loss_classifier_action_beta.backward()
            optimizer_classifier_beta.step()
            _, labels_pred = torch.max(nn.Sigmoid()(predict), dim=1)
            acc_action_beta = (torch.sum(labels_pred.cpu() == labels).float()
                               / labels_pred.shape[0])

            # add info to out_dict
            # out_dict['loss_classifier_action'] = \
            #     loss_classifier_action.detach().item()
            # out_dict['acc_classifier_action'] = acc_action.item()
            # out_dict['loss_classifier_action2'] = \
            #     loss_classifier_action2.detach().item()
            # out_dict['acc_classifier_action2'] = acc_action2.item()
            out_dict['loss_classifier_action_beta'] = \
                loss_classifier_action_beta.detach().item()
            out_dict['acc_action_beta'] = acc_action_beta.item()

        # out_dict["loss"] = loss.detach().item()
        # out_dict["kl_loss"] = kl_loss_avg.detach().item()
        # out_dict["mu_s"] = torch.mean(mu_s).item()
        # out_dict["logstd_s"] = torch.mean(logstd_s).item()
        # if self.config["training"]["use_regressor"]:
        #     out_dict["loss_regressor"] = torch.mean(loss_regressor).item()
        # out_dict["loss_recon"] = recon_loss.detach().item()
        # out_dict["loss_per_seq_recon"] = (
        #     recon_loss_per_seq.detach().cpu().numpy()
        # )
        # out_dict["seq_len"] = seq_len

        return out_dict

    ##### CREATE TRAINING RUN #####
    trainer = Engine(train_fn)
    pbar = ProgressBar()
    pbar.attach(
        trainer,
        output_transform=lambda x:
        {key: x[key] for key in x if "per_seq" not in key},
    )

    # compute averages for all outputs of the train function that are
    # specified in the list
    # fixme this can be used to log as soon as losses for mtvae are defined
    # and named
    loss_avg = Average(output_transform=lambda x: x["loss"])
    loss_avg.attach(trainer, "loss")
    recon_loss_avg = Average(output_transform=lambda x: x["rec_loss"])
    recon_loss_avg.attach(trainer, "rec_loss")
    kl_loss_avg = Average(output_transform=lambda x: x["kl_loss"])
    kl_loss_avg.attach(trainer, "kl_loss")
    motion_loss_avg = Average(output_transform=lambda x: x["motion_loss"])
    motion_loss_avg.attach(trainer, "motion_loss")
    cycle_loss_avg = Average(output_transform=lambda x: x["cycle_loss"])
    cycle_loss_avg.attach(trainer, "cycle_loss")

    # mu_s_avg = Average(output_transform=lambda x: x["mu_s"])
    # mu_s_avg.attach(trainer, "mu_s")
    # logstd_s_avg = Average(output_transform=lambda x: x["logstd_s"])
    # logstd_s_avg.attach(trainer, "logstd_s")
    #
    # loss_classifier = Average(output_transform=lambda x:
    #     x["loss_classifier_action"] if "loss_classifier_action" in x else 0)
    # loss_classifier.attach(trainer, "loss_classifier_action")
    # acc_classifier = Average(output_transform=lambda x:
    #     x["acc_classifier_action"] if "acc_classifier_action" in x else 0)
    # acc_classifier.attach(trainer, "acc_classifier_action")
    #
    # loss_classifier_action2 = Average(output_transform=lambda x:
    #     x["loss_classifier_action2"]
    #     if "loss_classifier_action2" in x else 0)
    # loss_classifier_action2.attach(trainer, "loss_classifier_action2")
    # acc_classifier_action2 = Average(output_transform=lambda x:
    #     x["acc_classifier_action2"] if "acc_classifier_action2" in x else 0)
    # acc_classifier_action2.attach(trainer, "acc_classifier_action2")

    loss_classifier_action_beta = Average(
        output_transform=lambda x: x["loss_classifier_action_beta"]
        if "loss_classifier_action_beta" in x else 0)
    loss_classifier_action_beta.attach(trainer, "loss_classifier_action_beta")
    acc_action_beta = Average(
        output_transform=lambda x: x["acc_action_beta"]
        if "acc_action_beta" in x else 0)
    acc_action_beta.attach(trainer, "acc_action_beta")

    # loss_avg = Average(output_transform=lambda x: x["loss"])
    # loss_avg.attach(trainer, "loss")

    ##### TRAINING HOOKS ######
    # @trainer.on(Events.ITERATION_COMPLETED)
    # def collect_training_info(engine):
    #     it = engine.state.iteration
    #     self.collect_recon_loss_seq[seq_len] += engine.state.output[
    #         "loss_per_seq_recon"
    #     ]
    #     self.collect_count_seq_lens[seq_len] += \
    #         self.config["training"]["batch_size"]

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def update_optimizer_params(engine):
    #     scheduler.step()

    def log_wandb(engine):
        wandb.log({
            "epoch": engine.state.epoch,
            "iteration": engine.state.iteration,
        })
        print(
            f"Logging metrics: Currently, the following metrics are "
            f"tracked: {list(engine.state.metrics.keys())}"
        )
        for key in engine.state.metrics:
            val = engine.state.metrics[key]
            wandb.log({key + "-epoch-avg": val})
            print(ENDC + f" [metrics] {key}:{val}")

        # reset
        # self.collect_recon_loss_seq = {
        #     k: np.zeros(shape=[k])
        #     for k in range(collect_len[0], collect_len[-1])
        # }
        # self.collect_count_seq_lens = np.zeros(shape=[collect_len[-1]])

        loss_avg = engine.state.metrics["loss"]
        print(GREEN + f"Epoch {engine.state.epoch} summary:")
        print(ENDC + f" [losses] loss overall:{loss_avg}")

    def eval_model(engine):
        eval_nets(net,
                  test_loader,
                  self.device,
                  engine.state.epoch,
                  cf_action_beta=classifier_beta,
                  debug=self.config["general"]["debug"])

    def transfer_behavior_test(engine):
        visualize_transfer3d(
            net,
            transfer_loader,
            self.device,
            name="Test-Set: ",
            dirs=self.dirs,
            revert_coord_space=False,
            epoch=engine.state.epoch,
            n_vid_to_generate=self.config["logging"]["n_vid_to_generate"])

    # compare predictions on train and test set
    # def eval_grid(engine):
    #     if self.config["data"]["dataset"] != "HumanEva":
    #         make_eval_grid(
    #             net,
    #             transfer_loader,
    #             self.device,
    #             dirs=self.dirs,
    #             revert_coord_space=False,
    #             epoch=engine.state.epoch,
    #             synth_ckpt=self.synth_ckpt,
    #             synth_params=self.synth_params,
    #         )

    # def latent_interpolations(engine):
    #     latent_interpolate(
    #         net,
    #         transfer_loader,
    #         self.device,
    #         dirs=self.dirs,
    #         epoch=engine.state.epoch,
    #         synth_params=self.synth_params,
    #         synth_ckpt=self.synth_ckpt,
    #         n_vid_to_generate=self.config["logging"]["n_vid_to_generate"]
    #     )

    ckpt_handler_reg = ModelCheckpoint(self.dirs["ckpt"], "reg_ckpt",
                                       n_saved=100, require_empty=False)
    save_dict = {"model": net, "optimizer": optimizer}
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                              ckpt_handler_reg, save_dict)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_wandb)

    def log_outputs(engine):
        for key in engine.state.output:
            val = engine.state.output[key]
            wandb.log({key + "-epoch-step": val})

    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(
            every=10 if self.config["general"]["debug"] else 1000),
        log_outputs)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, eval_model)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(
            every=1 if self.config["general"]["debug"] else 3),
        transfer_behavior_test,
    )

    # trainer.add_event_handler(
    #     Events.EPOCH_COMPLETED(every=10),
    #     latent_interpolations,
    # )
    # trainer.add_event_handler(
    #     Events.EPOCH_COMPLETED(every=3),
    #     eval_grid,
    # )

    ####### RUN TRAINING ##############
    print(BLUE + "*************** Train VAE *******************" + ENDC)
    trainer.run(
        train_loader,
        max_epochs=self.config["training"]["n_epochs"],
        epoch_length=10 if self.config["general"]["debug"] else
        len(train_loader),
    )
    print(BLUE + "*************** VAE training ends *******************" +
          ENDC)
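# linear_var is imported from elsewhere in the project; a sketch matching the
# partial(...) call above — a linear ramp from start_val at start_it to
# end_val at end_it, clipped to [clip_min, clip_max]. An assumption, not the
# original implementation.
def linear_var(it, start_it, end_it, start_val, end_val, clip_min, clip_max):
    if end_it <= start_it:
        return end_val
    val = start_val + (end_val - start_val) * (it - start_it) / \
        (end_it - start_it)
    return float(min(max(val, clip_min), clip_max))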
def main(
    batch_size,
    epochs,
    length_scale,
    centroid_size,
    model_output_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
):
    name = f"DUQ_{length_scale}__{l_gradient_penalty}_{gamma}_{centroid_size}"
    writer = SummaryWriter(comment=name)

    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        val_size = int(len(dataset) * 0.8)
        train_dataset = torch.utils.data.Subset(dataset, idx[:val_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[val_size:])

        # Test time preprocessing for validation
        val_dataset.transform = test_dataset.transform

    model = ResNet_DUQ(
        input_size,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    )
    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[25, 50, 75],
                                                     gamma=0.2)

    def bce_loss_fn(y_pred, y):
        bce = F.binary_cross_entropy(y_pred, y, reduction="sum").div(
            num_classes * y_pred.shape[0])
        return bce

    def output_transform_bce(output):
        y_pred, y, x = output
        y = F.one_hot(y, num_classes).float()
        return y_pred, y

    def output_transform_acc(output):
        y_pred, y, x = output
        return y_pred, y

    def output_transform_gp(output):
        y_pred, y, x = output
        return x, y_pred

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1) ** 2).mean()

        return gradient_penalty

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        if l_gradient_penalty > 0:
            x.requires_grad_(True)

        z, y_pred = model(x)
        y = F.one_hot(y, num_classes).float()

        loss = bce_loss_fn(y_pred, y)

        if l_gradient_penalty > 0:
            loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred)

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        z, y_pred = model(x)

        return y_pred, y, x

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "loss")

    metric = Accuracy(output_transform=output_transform_acc)
    metric.attach(evaluator, "accuracy")

    metric = Loss(F.binary_cross_entropy,
                  output_transform=output_transform_bce)
    metric.attach(evaluator, "bce")

    metric = Loss(calc_gradient_penalty, output_transform=output_transform_gp)
    metric.attach(evaluator, "gradient_penalty")

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1000,
                                             shuffle=False,
                                             **kwargs)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1000,
                                              shuffle=False,
                                              **kwargs)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        loss = metrics["loss"]

        print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f}")
        writer.add_scalar("Loss/train", loss, trainer.state.epoch)

        if trainer.state.epoch % 5 == 0 or trainer.state.epoch > 65:
            accuracy, auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {accuracy}, AUROC: {auroc}")
            writer.add_scalar("OoD/test_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch)

            accuracy, auroc = get_auroc_classification(val_dataset, model)
            print(f"AUROC - uncertainty: {auroc}")
            writer.add_scalar("OoD/val_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc_classification", auroc,
                              trainer.state.epoch)

        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        bce = metrics["bce"]
        GP = metrics["gradient_penalty"]
        loss = bce + l_gradient_penalty * GP

        print((f"Valid - Epoch: {trainer.state.epoch} "
               f"Acc: {acc:.4f} "
               f"Loss: {loss:.2f} "
               f"BCE: {bce:.2f} "
               f"GP: {GP:.2f}"))

        writer.add_scalar("Loss/valid", loss, trainer.state.epoch)
        writer.add_scalar("BCE/valid", bce, trainer.state.epoch)
        writer.add_scalar("GP/valid", GP, trainer.state.epoch)
        writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch)

        print(f"Centroid norm: {torch.norm(model.m / model.N, dim=0)}")

        scheduler.step()

        if trainer.state.epoch > 65:
            torch.save(model.state_dict(),
                       f"saved_models/{name}_{trainer.state.epoch}.pt")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=epochs)

    evaluator.run(test_loader)
    acc = evaluator.state.metrics["accuracy"]

    print(f"Test - Accuracy {acc:.4f}")

    writer.close()
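# Self-contained check of the two-sided gradient penalty above: for a linear
# map y = x @ w, the input gradient of each sample is w, so with ||w|| = 5
# the penalty is (5 - 1)^2 = 16. Toy tensors, not the DUQ model.
import torch

w = torch.tensor([3.0, 4.0])
x = torch.randn(8, 2, requires_grad=True)
y_pred = x @ w
grads = torch.autograd.grad(y_pred, x, torch.ones_like(y_pred),
                            create_graph=True)[0]
penalty = ((grads.flatten(start_dim=1).norm(2, dim=1) - 1) ** 2).mean()
print(penalty.item())  # 16.0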
def main(
    architecture,
    batch_size,
    length_scale,
    centroid_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
    output_dir,
):
    writer = SummaryWriter(log_dir=f"runs/{output_dir}")

    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        val_size = int(len(dataset) * 0.8)
        train_dataset = torch.utils.data.Subset(dataset, idx[:val_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[val_size:])

        # Test time preprocessing for validation
        val_dataset.transform = test_dataset.transform

    if architecture == "WRN":
        model_output_size = 640
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = WideResNet()
    elif architecture == "ResNet18":
        model_output_size = 512
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet18()
    elif architecture == "ResNet50":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet50()
    elif architecture == "ResNet110":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet110()
    elif architecture == "DenseNet121":
        model_output_size = 1024
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = densenet121()

    # Adapted resnet from:
    # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
    feature_extractor.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1,
                                              padding=1, bias=False)
    feature_extractor.maxpool = torch.nn.Identity()
    feature_extractor.fc = torch.nn.Identity()

    if centroid_size is None:
        centroid_size = model_output_size

    model = ResNet_DUQ(
        feature_extractor,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    )
    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=milestones,
                                                     gamma=0.2)

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1) ** 2).mean()

        return gradient_penalty

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        y_pred = model(x)

        y = F.one_hot(y, num_classes).float()

        loss = F.binary_cross_entropy(y_pred, y, reduction="mean")

        if l_gradient_penalty > 0:
            gp = calc_gradient_penalty(x, y_pred)
            loss += l_gradient_penalty * gp

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        y_pred = model(x)

        return {"x": x, "y": y, "y_pred": y_pred}

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "loss")

    metric = Accuracy(output_transform=lambda out: (out["y_pred"], out["y"]))
    metric.attach(evaluator, "accuracy")

    def bce_output_transform(out):
        return (out["y_pred"], F.one_hot(out["y"], num_classes).float())

    metric = Loss(F.binary_cross_entropy,
                  output_transform=bce_output_transform)
    metric.attach(evaluator, "bce")

    metric = Loss(calc_gradient_penalty,
                  output_transform=lambda out: (out["x"], out["y_pred"]))
    metric.attach(evaluator, "gradient_penalty")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             **kwargs)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              **kwargs)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        loss = metrics["loss"]

        print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f}")
        writer.add_scalar("Loss/train", loss, trainer.state.epoch)

        if trainer.state.epoch > (epochs - 5):
            accuracy, auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {accuracy}, AUROC: {auroc}")
            writer.add_scalar("OoD/test_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch)

            accuracy, auroc = get_auroc_classification(val_dataset, model)
            print(f"AUROC - uncertainty: {auroc}")
            writer.add_scalar("OoD/val_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc_classification", auroc,
                              trainer.state.epoch)

        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        bce = metrics["bce"]
        GP = metrics["gradient_penalty"]
        loss = bce + l_gradient_penalty * GP

        print((f"Valid - Epoch: {trainer.state.epoch} "
               f"Acc: {acc:.4f} "
               f"Loss: {loss:.2f} "
               f"BCE: {bce:.2f} "
               f"GP: {GP:.2f}"))

        writer.add_scalar("Loss/valid", loss, trainer.state.epoch)
        writer.add_scalar("BCE/valid", bce, trainer.state.epoch)
        writer.add_scalar("GP/valid", GP, trainer.state.epoch)
        writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch)

        scheduler.step()

    trainer.run(train_loader, max_epochs=epochs)

    evaluator.run(test_loader)
    acc = evaluator.state.metrics["accuracy"]

    print(f"Test - Accuracy {acc:.4f}")

    torch.save(model.state_dict(), f"runs/{output_dir}/model.pt")

    writer.close()
def get_train_mean_std(train_dataset, unique_id="", cache_dir="/tmp/unosat/"):
    # # Ensure that only process 0 in distributed performs the computation,
    # # and the others will use the cache
    # if dist.get_rank() > 0:
    #     # synchronization point for all processes > 0
    #     torch.distributed.barrier()

    cache_dir = Path(cache_dir)
    if not cache_dir.exists():
        cache_dir.mkdir(parents=True)

    if len(unique_id) > 0:
        unique_id += "_"

    fp = cache_dir / "train_mean_std_{}{}.pth".format(len(train_dataset),
                                                      unique_id)
    if fp.exists():
        mean_std = torch.load(fp.as_posix())
    else:
        if dist.is_available() and dist.is_initialized():
            raise RuntimeError(
                "Current implementation of Mean/Std computation is not "
                "working in distrib config")

        from ignite.engine import Engine
        from ignite.metrics import Average
        from ignite.contrib.handlers import ProgressBar
        from albumentations.pytorch import ToTensorV2

        train_dataset = TransformedDataset(train_dataset,
                                           transform_fn=ToTensorV2())
        train_loader = DataLoader(train_dataset, shuffle=False,
                                  drop_last=False, batch_size=16,
                                  num_workers=10, pin_memory=False)

        def compute_mean_std(engine, batch):
            b, c, *_ = batch['image'].shape
            data = batch['image'].reshape(b, c, -1).to(dtype=torch.float64)
            mean = torch.mean(data, dim=-1)
            mean2 = torch.mean(data ** 2, dim=-1)
            return {
                "mean": mean,
                "mean^2": mean2,
            }

        compute_engine = Engine(compute_mean_std)
        ProgressBar(desc="Compute Mean/Std").attach(compute_engine)
        img_mean = Average(output_transform=lambda output: output['mean'])
        img_mean2 = Average(output_transform=lambda output: output['mean^2'])
        img_mean.attach(compute_engine, 'mean')
        img_mean2.attach(compute_engine, 'mean2')
        state = compute_engine.run(train_loader)
        state.metrics['std'] = torch.sqrt(state.metrics['mean2'] -
                                          state.metrics['mean'] ** 2)
        mean_std = {'mean': state.metrics['mean'],
                    'std': state.metrics['std']}

        # if dist.get_rank() < 1:
        torch.save(mean_std, fp.as_posix())

    # if dist.get_rank() < 1:
    #     torch.distributed.barrier()  # synchronization point for process 0

    return mean_std['mean'].tolist(), mean_std['std'].tolist()
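# Quick check of the identity used above, std = sqrt(E[x^2] - E[x]^2), on
# dummy per-channel image data (population std, matching the engine's maths):
import torch

data = torch.randn(100, 3, 8, 8).to(torch.float64)
mean = data.mean(dim=(0, 2, 3))
mean2 = (data ** 2).mean(dim=(0, 2, 3))
std = torch.sqrt(mean2 - mean ** 2)
ref = data.permute(1, 0, 2, 3).reshape(3, -1).std(dim=1, unbiased=False)
assert torch.allclose(std, ref)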