def _test_distrib_integration(device):
    """Check that MAE aggregates correctly across distributed workers."""
    import numpy as np

    from ignite.engine import Engine

    rank = idist.get_rank()
    n_iters = 80
    batch = 50
    offset = n_iters * batch  # samples handled by each rank

    y_true = torch.arange(0, offset * idist.get_world_size(), dtype=torch.float).to(device)
    y_preds = torch.ones(offset * idist.get_world_size(), dtype=torch.float).to(device)

    def update(engine, i):
        # Each rank consumes its own contiguous slice of the global tensors.
        lo = i * batch + offset * rank
        hi = (i + 1) * batch + offset * rank
        return y_preds[lo:hi], y_true[lo:hi]

    engine = Engine(update)
    MeanAbsoluteError().attach(engine, "mae")

    engine.run(data=list(range(n_iters)), max_epochs=1)

    assert "mae" in engine.state.metrics
    res = engine.state.metrics["mae"]
    true_res = np.mean(np.abs((y_true - y_preds).cpu().numpy()))
    assert pytest.approx(res) == true_res
def test_zero_div():
    """compute() on a never-updated metric must raise with the documented message."""
    mae = MeanAbsoluteError()
    expected = r"MeanAbsoluteError must have at least one example before it can be computed"
    with pytest.raises(NotComputableError, match=expected):
        mae.compute()
def test_accumulator_detached():
    """The running error sum must not track gradients from y_pred."""
    mae = MeanAbsoluteError()
    preds = torch.tensor([[2.0], [-2.0]], requires_grad=True)
    targets = torch.zeros(2)
    mae.update((preds, targets))
    assert not mae._sum_of_absolute_errors.requires_grad
def test_compute():
    """MAE over symmetric +/-m errors against zero targets equals m."""
    mae = MeanAbsoluteError()
    for magnitude in (2.0, 3.0):
        mae.reset()
        preds = torch.tensor([[magnitude], [-magnitude]])
        mae.update((preds, torch.zeros(2)))
        assert mae.compute() == magnitude
def _test(metric_device):
    # NOTE: relies on `update`, `n_iters`, `y_true` and `y_preds` from the
    # enclosing scope (not visible in this chunk).
    engine = Engine(update)
    metric = MeanAbsoluteError(device=metric_device)
    metric.attach(engine, "mae")

    engine.run(data=list(range(n_iters)), max_epochs=1)

    assert "mae" in engine.state.metrics
    computed = engine.state.metrics["mae"]
    expected = np.mean(np.abs((y_true - y_preds).cpu().numpy()))
    assert pytest.approx(computed) == expected
def metrics_selector(mode, loss):
    """Return the evaluation metrics dict for the given task mode.

    Args:
        mode: one of "classification", "multiclass-multilabel" or
            "regression" (case-insensitive).
        loss: a loss metric, attached under the "loss" key.

    Returns:
        dict mapping metric names to ignite metric objects.

    Raises:
        RuntimeError: if ``mode`` is not a supported task mode.
    """
    mode = mode.lower()
    if mode == "classification":
        metrics = {
            "loss": loss,
            "accuracy": Accuracy(),
            "accuracy_topk": TopKCategoricalAccuracy(),
            "precision": Precision(average=True),
            "recall": Recall(average=True),
        }
    elif mode == "multiclass-multilabel":
        metrics = {
            "loss": loss,
            "accuracy": Accuracy(),
        }
    elif mode == "regression":
        metrics = {
            "loss": loss,
            "mse": MeanSquaredError(),
            "mae": MeanAbsoluteError(),
        }
    else:
        # Fixed message: it previously omitted the supported
        # "multiclass-multilabel" mode.
        raise RuntimeError(
            "Invalid task mode, select classification, "
            "multiclass-multilabel or regression")
    return metrics
def create_sr_evaluator(
    model,
    device=None,
    non_blocking=True,
    denormalize=True,
    mean=None,
):
    """Build an evaluation Engine for a super-resolution model with L1/L2/PNSR metrics."""
    # Transfer the per-channel mean to the device and reshape it to
    # (1, C, 1, 1) so that it is broadcastable to the BCHW format.
    mean = mean.to(device).reshape(1, -1, 1, 1)

    def denorm_fn(x):
        return torch.clamp(x + mean, min=0., max=1.)

    def _evaluate_model(engine, batch):
        model.eval()
        x, y = _prepare_batch(batch, device=device, non_blocking=non_blocking)
        with torch.no_grad():
            y_pred = model(x)
            if denormalize:
                y_pred, y = map(denorm_fn, [y_pred, y])
            return y_pred, y

    engine = Engine(_evaluate_model)
    MeanAbsoluteError().attach(engine, 'l1')
    MeanSquaredError().attach(engine, 'l2')
    PNSR(max_value=1.0).attach(engine, 'pnsr')
    return engine
def _test_distrib_accumulator_device(device):
    """Accumulator must live on the requested device both before and after update()."""
    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())

    for metric_device in metric_devices:
        mae = MeanAbsoluteError(device=metric_device)

        def _assert_devices():
            for dev in (mae._device, mae._sum_of_absolute_errors.device):
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

        _assert_devices()
        mae.update((torch.tensor([[2.0], [-2.0]]), torch.zeros(2)))
        _assert_devices()
def test_compute():
    """MAE matches a NumPy reference on random int tensors, whole-batch and chunked."""
    mae = MeanAbsoluteError()

    def _test(y_pred, y, batch_size):
        mae.reset()
        if batch_size > 1:
            n_iters = y.shape[0] // batch_size + 1
            for i in range(n_iters):
                idx = i * batch_size
                mae.update((y_pred[idx:idx + batch_size], y[idx:idx + batch_size]))
        else:
            # Fixed: the metric output contract is (y_pred, y); batch_size
            # does not belong in the update payload.
            mae.update((y_pred, y))

        np_y = y.numpy()
        np_y_pred = y_pred.numpy()
        np_res = (np.abs(np_y_pred - np_y)).sum() / np_y.shape[0]

        assert isinstance(mae.compute(), float)
        assert mae.compute() == np_res

    def get_test_cases():
        test_cases = [
            (torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1),
            (torch.randint(-10, 10, size=(100, 5)), torch.randint(-10, 10, size=(100, 5)), 1),
            # updated batches
            (torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16),
            (torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16),
        ]
        return test_cases

    for _ in range(5):
        # check multiple random inputs as random exact occurencies are rare
        test_cases = get_test_cases()
        for y_pred, y, batch_size in test_cases:
            _test(y_pred, y, batch_size)
def _test_distrib_accumulator_device(device):
    """Accumulator device must match the requested device before and after update()."""
    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())

    for metric_device in metric_devices:
        mae = MeanAbsoluteError(device=metric_device)
        assert mae._device == metric_device

        msg = "{}:{} vs {}:{}".format(
            type(mae._sum_of_absolute_errors.device),
            mae._sum_of_absolute_errors.device,
            type(metric_device),
            metric_device,
        )
        assert mae._sum_of_absolute_errors.device == metric_device, msg

        mae.update((torch.tensor([[2.0], [-2.0]]), torch.zeros(2)))

        msg = "{}:{} vs {}:{}".format(
            type(mae._sum_of_absolute_errors.device),
            mae._sum_of_absolute_errors.device,
            type(metric_device),
            metric_device,
        )
        assert mae._sum_of_absolute_errors.device == metric_device, msg
def __init__(self, output_transform=binary_transform): self._y = None self._pred = None super().__init__(output_transform=output_transform) def reset(self): self._y = list() self._pred = list() super().reset() def update(self, output): y_pred, y = output self._y.append(y) self._pred.append(y_pred) def compute(self): y_pred = torch.cat(self._pred, 0).cpu() y = torch.cat(self._y, 0).cpu() score = f1_score(y, y_pred, average='weighted') return score metric = { 'acc2': Accuracy(output_transform=binary_transform), 'acc5': Accuracy(output_transform=five_transform), 'acc7': Accuracy(output_transform=seven_transform), 'f1': F1(), 'corr': Pearson(), 'mae': MeanAbsoluteError() }
def run(opt):
    """Train a model with L1 loss, validating and checkpointing every epoch."""
    # Logging setup: to a file if requested, otherwise to the default stream.
    if opt.log_file is not None:
        logging.basicConfig(filename=opt.log_file, level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger = logger.info
    writer = SummaryWriter(log_dir=opt.log_dir)
    model_timer, data_timer = Timer(average=True), Timer(average=True)

    # Model / optimizer / scheduler setup.
    logger('Loading models')
    model, parameters, mean, std = generate_model(opt)
    optimizer = SGD(parameters,
                    lr=opt.lr,
                    momentum=opt.momentum,
                    weight_decay=opt.weight_decay,
                    nesterov=opt.nesterov)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=opt.lr_patience)

    # Resume from checkpoint if given.
    if opt.checkpoint:
        logger('loading checkpoint {}'.format(opt.checkpoint))
        checkpoint = torch.load(opt.checkpoint)
        opt.begin_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    logger('Loading dataset')
    train_transform = get_transform(mean, std, opt.face_size, mode='training')
    train_data = get_training_set(opt, transform=train_transform)
    train_loader = DataLoader(train_data,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.n_threads,
                              pin_memory=True)
    val_transform = get_transform(mean, std, opt.face_size, mode='validation')
    val_data = get_validation_set(opt, transform=val_transform)
    val_loader = DataLoader(val_data,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            num_workers=opt.n_threads,
                            pin_memory=True)

    trainer = create_supervised_trainer(model, optimizer, nn.L1Loss().cuda(),
                                        cuda=True)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'distance': MeanPairwiseDistance(),
                                                'loss': MeanAbsoluteError()
                                            },
                                            cuda=True)

    # Timers: model_timer measures the forward/backward work per iteration,
    # data_timer measures the gap between iterations (data loading).
    model_timer.attach(trainer,
                       start=Events.EPOCH_STARTED,
                       resume=Events.ITERATION_STARTED,
                       pause=Events.ITERATION_COMPLETED,
                       step=Events.ITERATION_COMPLETED)
    data_timer.attach(trainer,
                      start=Events.EPOCH_STARTED,
                      resume=Events.ITERATION_COMPLETED,
                      pause=Events.ITERATION_STARTED,
                      step=Events.ITERATION_STARTED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % opt.log_interval == 0:
            logger(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.2f} Model Process: {:.3f}s/batch "
                "Data Preparation: {:.3f}s/batch".format(
                    engine.state.epoch, iter, len(train_loader),
                    engine.state.output, model_timer.value(),
                    data_timer.value()))
            writer.add_scalar("training/loss", engine.state.output,
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_STARTED)
    def log_learning_rate(engine):
        lr = optimizer.param_groups[0]['lr']
        logger('Epoch[{}] Starts with lr={}'.format(engine.state.epoch, lr))
        writer.add_scalar("learning_rate", lr, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        if engine.state.epoch % opt.save_interval == 0:
            save_file_path = os.path.join(
                opt.result_path, 'save_{}.pth'.format(engine.state.epoch))
            states = {
                'epoch': engine.state.epoch,
                'arch': opt.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_file_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        logger("Validation Results - Epoch: {} ".format(engine.state.epoch) +
               ' '.join(['{}: {:.4f}'.format(m, val)
                         for m, val in metrics.items()]))
        for m, val in metrics.items():
            writer.add_scalar('validation/{}'.format(m), val,
                              engine.state.epoch)
        # Let the plateau scheduler react to the validation L1 loss.
        scheduler.step(metrics['loss'])

    # kick everything off
    logger('Start training')
    trainer.run(train_loader, max_epochs=opt.n_epochs)
    writer.close()
import torch.nn.functional as F from ignite.metrics import CategoricalAccuracy, Loss, MeanAbsoluteError from attributer.attributes import FaceAttributes from training.metric_utils import ScaledError _metrics = { FaceAttributes.AGE: ScaledError(MeanAbsoluteError(), 50), FaceAttributes.GENDER: CategoricalAccuracy(), FaceAttributes.EYEGLASSES: CategoricalAccuracy(), FaceAttributes.RECEDING_HAIRLINES: CategoricalAccuracy(), FaceAttributes.SMILING: CategoricalAccuracy(), FaceAttributes.HEAD_YAW_BIN: CategoricalAccuracy(), FaceAttributes.HEAD_PITCH_BIN: CategoricalAccuracy(), FaceAttributes.HEAD_ROLL_BIN: CategoricalAccuracy(), FaceAttributes.HEAD_YAW: MeanAbsoluteError(), FaceAttributes.HEAD_PITCH: MeanAbsoluteError(), FaceAttributes.HEAD_ROLL: MeanAbsoluteError(), } _losses = { FaceAttributes.AGE: F.l1_loss, FaceAttributes.GENDER: F.cross_entropy, FaceAttributes.EYEGLASSES: F.cross_entropy, FaceAttributes.RECEDING_HAIRLINES: F.cross_entropy, FaceAttributes.SMILING: F.cross_entropy, FaceAttributes.HEAD_YAW_BIN: F.cross_entropy, FaceAttributes.HEAD_PITCH_BIN: F.cross_entropy, FaceAttributes.HEAD_ROLL_BIN: F.cross_entropy, FaceAttributes.HEAD_YAW: F.l1_loss, FaceAttributes.HEAD_PITCH: F.l1_loss,
def run(args, seed):
    """Train a ConvNet on the configured data, logging metrics and figures to TensorBoard."""
    config.make_paths()
    torch.random.manual_seed(seed)
    train_loader, val_loader, shape = get_data_loaders(
        config.Training.batch_size,
        proportion=config.Training.proportion,
        test_batch_size=config.Training.batch_size * 2,
    )
    n, d, t = shape
    model = models.ConvNet(d, seq_len=t)
    writer = tb.SummaryWriter(log_dir=config.TENSORBOARD)

    model.to(config.device)  # Move model before creating optimizer
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.MSELoss()
    trainer = create_supervised_trainer(model, optimizer, criterion,
                                        device=config.device)
    trainer.logger = setup_logger("trainer")

    # Periodic checkpointing of the model.
    checkpointer = ModelCheckpoint(
        config.MODEL,
        model.__class__.__name__,
        n_saved=2,
        create_dir=True,
        save_as_state_dict=True,
    )
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config.Training.save_every),
        checkpointer,
        {"model": model},
    )

    val_metrics = {
        "mse": Loss(criterion),
        "mae": MeanAbsoluteError(),
        "rmse": RootMeanSquaredError(),
    }
    evaluator = create_supervised_evaluator(model, metrics=val_metrics,
                                            device=config.device)
    evaluator.logger = setup_logger("evaluator")
    ar_evaluator = create_ar_evaluator(model, metrics=val_metrics,
                                       device=config.device)
    ar_evaluator.logger = setup_logger("ar")

    @trainer.on(Events.EPOCH_COMPLETED(every=config.Training.save_every))
    def log_ar(engine):
        # Autoregressive evaluation: plot predictions vs targets.
        ar_evaluator.run(val_loader)
        y_pred, y = ar_evaluator.state.output
        fig = plot_output(y, y_pred)
        writer.add_figure("eval/ar", fig, engine.state.epoch)
        plt.close()

    @trainer.on(Events.ITERATION_COMPLETED(every=config.Training.log_every))
    def log_training_loss(engine):
        if args.verbose:
            grad_norm = torch.stack(
                [p.grad.norm() for p in model.parameters()]).sum()
            writer.add_scalar("train/grad_norm", grad_norm,
                              engine.state.iteration)
        writer.add_scalar("train/loss", engine.state.output,
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED(every=config.Training.eval_every))
    def log_training_results(engine):
        evaluator.run(train_loader)
        for k, v in evaluator.state.metrics.items():
            writer.add_scalar(f"train/{k}", v, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=config.Training.eval_every))
    def log_validation_results(engine):
        evaluator.run(val_loader)
        for k, v in evaluator.state.metrics.items():
            writer.add_scalar(f"eval/{k}", v, engine.state.epoch)
        y_pred, y = evaluator.state.output
        fig = plot_output(y, y_pred)
        writer.add_figure("eval/preds", fig, engine.state.epoch)
        plt.close()

    # Optionally warm-start from an explicit checkpoint file.
    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)
        ModelCheckpoint.load_objects({"model": model}, ckpt)

    try:
        trainer.run(train_loader, max_epochs=config.Training.max_epochs)
    except Exception:
        import traceback
        print(traceback.format_exc())

    writer.close()
test_dataset = TrainValTestDataset(image_dataset, mode="test")
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=args.batchsize,
                         num_workers=num_workers)

model = Model(number_of_classes=number_of_classes)
optimizer = optim.Adam(model.parameters(), lr=args.learningrate)
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)


def _predicted_classes(out):
    # Argmax over class logits so MAE/MSE compare class indices to targets.
    return torch.max(out[0], dim=1)[1], out[1]


metrics = {
    "accuracy": Accuracy(),
    "MAE": MeanAbsoluteError(output_transform=_predicted_classes),
    "MSE": MeanSquaredError(output_transform=_predicted_classes),
    "loss": Loss(loss_fn=criterion),
}
evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)


@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(trainer):
    print(
        f"Training (Epoch {trainer.state.epoch}): {trainer.state.output:.3f}")
def test_zero_div():
    """A never-updated metric cannot be computed."""
    with pytest.raises(NotComputableError):
        MeanAbsoluteError().compute()
def train(self, config, **kwargs): config_parameters = parse_config_or_kwargs(config, **kwargs) outputdir = os.path.join( config_parameters['outputpath'], config_parameters['model'], "{}_{}".format( datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%m'), uuid.uuid1().hex)) checkpoint_handler = ModelCheckpoint( outputdir, 'run', n_saved=1, require_empty=False, create_dir=True, score_function=lambda engine: -engine.state.metrics['Loss'], save_as_state_dict=False, score_name='loss') train_kaldi_string = parsecopyfeats( config_parameters['trainfeatures'], **config_parameters['feature_args']) dev_kaldi_string = parsecopyfeats(config_parameters['devfeatures'], **config_parameters['feature_args']) logger = genlogger(os.path.join(outputdir, 'train.log')) logger.info("Experiment is stored in {}".format(outputdir)) for line in pformat(config_parameters).split('\n'): logger.info(line) scaler = getattr( pre, config_parameters['scaler'])(**config_parameters['scaler_args']) inputdim = -1 logger.info("<== Estimating Scaler ({}) ==>".format( scaler.__class__.__name__)) for _, feat in kaldi_io.read_mat_ark(train_kaldi_string): scaler.partial_fit(feat) inputdim = feat.shape[-1] assert inputdim > 0, "Reading inputstream failed" logger.info("Features: {} Input dimension: {}".format( config_parameters['trainfeatures'], inputdim)) logger.info("<== Labels ==>") train_label_df = pd.read_csv( config_parameters['trainlabels']).set_index('Participant_ID') dev_label_df = pd.read_csv( config_parameters['devlabels']).set_index('Participant_ID') train_label_df.index = train_label_df.index.astype(str) dev_label_df.index = dev_label_df.index.astype(str) # target_type = ('PHQ8_Score', 'PHQ8_Binary') target_type = ('PHQ8_Score', 'PHQ8_Binary') n_labels = len(target_type) # PHQ8 + Binary # Scores and their respective PHQ8 train_labels = train_label_df.loc[:, target_type].T.apply( tuple).to_dict() dev_labels = dev_label_df.loc[:, target_type].T.apply(tuple).to_dict() train_dataloader = 
create_dataloader( train_kaldi_string, train_labels, transform=scaler.transform, shuffle=True, **config_parameters['dataloader_args']) cv_dataloader = create_dataloader( dev_kaldi_string, dev_labels, transform=scaler.transform, shuffle=False, **config_parameters['dataloader_args']) model = getattr(models, config_parameters['model'])( inputdim=inputdim, output_size=n_labels, **config_parameters['model_args']) if 'pretrain' in config_parameters: logger.info("Loading pretrained model {}".format( config_parameters['pretrain'])) pretrained_model = torch.load(config_parameters['pretrain'], map_location=lambda st, loc: st) if 'Attn' in pretrained_model.__class__.__name__: model.lstm.load_state_dict(pretrained_model.lstm.state_dict()) else: model.net.load_state_dict(pretrained_model.net.state_dict()) logger.info("<== Model ==>") for line in pformat(model).split('\n'): logger.info(line) criterion = getattr( losses, config_parameters['loss'])(**config_parameters['loss_args']) optimizer = getattr(torch.optim, config_parameters['optimizer'])( list(model.parameters()) + list(criterion.parameters()), **config_parameters['optimizer_args']) poolingfunction = parse_poolingfunction( config_parameters['poolingfunction']) criterion = criterion.to(device) model = model.to(device) def _train_batch(_, batch): model.train() with torch.enable_grad(): optimizer.zero_grad() outputs, targets = Runner._forward(model, batch, poolingfunction) loss = criterion(outputs, targets) loss.backward() optimizer.step() return loss.item() def _inference(_, batch): model.eval() with torch.no_grad(): return Runner._forward(model, batch, poolingfunction) def meter_transform(output): y_pred, y = output # y_pred is of shape [Bx2] (0 = MSE, 1 = BCE) # y = is of shape [Bx2] (0=Mse, 1 = BCE) return torch.sigmoid(y_pred[:, 1]).round(), y[:, 1].long() precision = Precision(output_transform=meter_transform, average=False) recall = Recall(output_transform=meter_transform, average=False) F1 = (precision * recall * 2 / 
(precision + recall)).mean() metrics = { 'Loss': Loss(criterion), 'Recall': Recall(output_transform=meter_transform, average=True), 'Precision': Precision(output_transform=meter_transform, average=True), 'MAE': MeanAbsoluteError( output_transform=lambda out: (out[0][:, 0], out[1][:, 0])), 'F1': F1 } train_engine = Engine(_train_batch) inference_engine = Engine(_inference) for name, metric in metrics.items(): metric.attach(inference_engine, name) RunningAverage(output_transform=lambda x: x).attach( train_engine, 'run_loss') pbar = ProgressBar(persist=False) pbar.attach(train_engine, ['run_loss']) scheduler = getattr(torch.optim.lr_scheduler, config_parameters['scheduler'])( optimizer, **config_parameters['scheduler_args']) early_stop_handler = EarlyStopping( patience=5, score_function=lambda engine: -engine.state.metrics['Loss'], trainer=train_engine) inference_engine.add_event_handler(Events.EPOCH_COMPLETED, early_stop_handler) inference_engine.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, { 'model': model, 'scaler': scaler, 'config': config_parameters }) @train_engine.on(Events.EPOCH_COMPLETED) def compute_metrics(engine): inference_engine.run(cv_dataloader) validation_string_list = [ "Validation Results - Epoch: {:<3}".format(engine.state.epoch) ] for metric in metrics: validation_string_list.append("{}: {:<5.2f}".format( metric, inference_engine.state.metrics[metric])) logger.info(" ".join(validation_string_list)) pbar.n = pbar.last_print_n = 0 @inference_engine.on(Events.COMPLETED) def update_reduce_on_plateau(engine): val_loss = engine.state.metrics['Loss'] if 'ReduceLROnPlateau' == scheduler.__class__.__name__: scheduler.step(val_loss) else: scheduler.step() train_engine.run(train_dataloader, max_epochs=config_parameters['epochs']) # Return for further processing return outputdir
def MAEMetric(key):
    """Create mean absolute error metric on key."""
    return DictMetric(key, MeanAbsoluteError())
def run(
    train_batch_size: int,
    val_batch_size: int,
    epochs: int,
    lr: float,
    model_name: str,
    architecture: str,
    momentum: float,
    log_interval: int,
    log_dir: str,
    save_dir: str,
    save_step: int,
    val_step: int,
    num_workers: int,
    patience: int,
    eval_only: bool = False,
    overfit_on_few_samples: bool = False,
):
    """Train (or just evaluate) a ConvMOS model, write predictions and a results file."""
    train_loader, val_loader, test_loader = get_data_loaders(
        train_batch_size,
        val_batch_size,
        num_workers=num_workers,
        overfit_on_few_samples=overfit_on_few_samples,
    )
    models_available = {'convmos': ConvMOS}
    model = models_available[model_name](architecture=architecture)
    writer = create_summary_writer(model, train_loader, log_dir)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device=device)

    # E-OBS only provides observational data for land so we need to use a
    # mask to avoid fitting on the sea.
    land_mask_np = np.load('remo_eobs_land_mask.npy')
    # Convert booleans to 1 and 0, and convert numpy array to torch Tensor.
    land_mask = torch.from_numpy(1 * land_mask_np).to(device)
    print('Land mask:')
    print(land_mask)

    loss_fn = partial(masked_mse_loss, mask=land_mask)
    optimizer = Adam(model.parameters(), lr=lr)
    trainer = create_supervised_trainer(model, optimizer, loss_fn,
                                        device=device)
    metrics = {
        'rmse': RootMeanSquaredError(),
        'mae': MeanAbsoluteError(),
        'mse': Loss(loss_fn),
    }
    train_evaluator = create_supervised_evaluator(model, metrics=metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model, metrics=metrics,
                                                device=device)

    to_save = {'model': model, 'optimizer': optimizer, 'trainer': trainer}
    checkpoint_handler = Checkpoint(
        to_save,
        DiskSaver(save_dir, create_dir=True, require_empty=False),
        n_saved=2,
        global_step_transform=global_step_from_engine(trainer),
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=save_step),
                              checkpoint_handler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    def score_function(engine):
        # Handlers maximise the score, so negate the validation loss.
        return -engine.state.metrics['mse']

    best_checkpoint_handler = Checkpoint(
        to_save,
        DiskSaver(save_dir, create_dir=True, require_empty=False),
        n_saved=2,
        filename_prefix='best',
        score_function=score_function,
        score_name='val_loss',
        global_step_transform=global_step_from_engine(trainer),
    )
    val_evaluator.add_event_handler(Events.COMPLETED, best_checkpoint_handler)

    earlystop_handler = EarlyStopping(patience=patience,
                                      score_function=score_function,
                                      trainer=trainer)
    val_evaluator.add_event_handler(Events.COMPLETED, earlystop_handler)

    # Resume from the newest on-disk checkpoint, if any exists.
    checkpoint_files = glob(join(save_dir, 'checkpoint_*.pt'))
    if len(checkpoint_files) > 0:
        epoch_list = [
            int(c.split('.')[0].split('_')[-1]) for c in checkpoint_files
        ]
        last_epoch = sorted(epoch_list)[-1]
        latest_checkpoint_file = join(save_dir, f'checkpoint_{last_epoch}.pt')
        print('Loading last checkpoint', latest_checkpoint_file)
        last_epoch = int(latest_checkpoint_file.split('.')[0].split('_')[-1])
        if last_epoch >= epochs:
            print('Training was already completed')
            eval_only = True
        checkpoint = torch.load(latest_checkpoint_file, map_location=device)
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                  "".format(engine.state.epoch, iter, len(train_loader),
                            engine.state.output))
            writer.add_scalar("training/loss", engine.state.output,
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        avg_rmse = metrics['rmse']
        avg_mae = metrics['mae']
        avg_mse = metrics['mse']
        print(
            "Training Results - Epoch: {} Avg RMSE: {:.2f} Avg loss: {:.2f} Avg MAE: {:.2f}"
            .format(engine.state.epoch, avg_rmse, avg_mse, avg_mae))
        writer.add_scalar("training/avg_loss", avg_mse, engine.state.epoch)
        writer.add_scalar("training/avg_rmse", avg_rmse, engine.state.epoch)
        writer.add_scalar("training/avg_mae", avg_mae, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=val_step))
    def log_validation_results(engine):
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_rmse = metrics['rmse']
        avg_mae = metrics['mae']
        avg_mse = metrics['mse']
        print(
            "Validation Results - Epoch: {} Avg RMSE: {:.2f} Avg loss: {:.2f} Avg MAE: {:.2f}"
            .format(engine.state.epoch, avg_rmse, avg_mse, avg_mae))
        writer.add_scalar("validation/avg_loss", avg_mse, engine.state.epoch)
        writer.add_scalar("validation/avg_rmse", avg_rmse, engine.state.epoch)
        writer.add_scalar("validation/avg_mae", avg_mae, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=save_step))
    def log_model_weights(engine):
        for name, param in model.named_parameters():
            writer.add_histogram(f"model/weights_{name}", param,
                                 engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=save_step))
    def regularly_predict_val_data(engine):
        predict_data(engine.state.epoch, val_loader)

    def predict_data(epoch: int, data_loader) -> xr.Dataset:
        # Predict all data points of the loader and write them to NetCDF.
        print(f'Predicting {data_loader.dataset.mode} data...')
        data_loader_iter = iter(data_loader)
        pred_np = None
        for i in range(len(data_loader)):
            x, y = next(data_loader_iter)
            pred = (model.forward(x.to(device=device)).to(
                device='cpu').detach().numpy()[:, 0, :, :])
            if pred_np is None:
                pred_np = pred
            else:
                pred_np = np.concatenate((pred_np, pred), axis=0)

        preds = xr.Dataset(
            {
                'pred': (['time', 'lat', 'lon'], pred_np),
                'input': (['time', 'lat', 'lon'], data_loader.dataset.X),
                'target': (['time', 'lat', 'lon'],
                           data_loader.dataset.Y[:, :, :, 0]),
            },
            coords={
                'time': data_loader.dataset.times,
                'lon_var': (('lat', 'lon'), data_loader.dataset.lons[0]),
                'lat_var': (('lat', 'lon'), data_loader.dataset.lats[0]),
            },
        )
        preds.to_netcdf(
            join(save_dir,
                 f'predictions_{data_loader.dataset.mode}_{epoch}.nc'))
        return preds

    # kick everything off
    if not eval_only:
        trainer.run(train_loader, max_epochs=epochs)

    # Load best model
    best_checkpoint = best_checkpoint_handler.last_checkpoint
    print('Loading best checkpoint from', best_checkpoint)
    checkpoint = torch.load(join(save_dir,
                                 best_checkpoint_handler.last_checkpoint),
                            map_location=device)
    Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    writer.close()

    val_preds = predict_data(trainer.state.epoch, val_loader)
    test_preds = predict_data(trainer.state.epoch, test_loader)

    val_res = mean_metrics(calculate_metrics(val_preds.pred, val_preds.target))
    test_res = mean_metrics(
        calculate_metrics(test_preds.pred, test_preds.target))

    results = {}
    # Store the config, ...
    results.update({
        section_name: dict(config[section_name])
        for section_name in config.sections()
    })
    # ... the last training metrics,
    results.update(
        {f'train_{k}': v for k, v in train_evaluator.state.metrics.items()})
    # ... the last validation metrics from torch,
    results.update(
        {f'val_torch_{k}': v for k, v in val_evaluator.state.metrics.items()})
    # ... the validation metrics that I calculate,
    results.update({f'val_{k}': v for k, v in val_res.items()})
    # ... and the test metrics that I calculate
    results.update({f'test_{k}': v for k, v in test_res.items()})
    write_results_file(join('results', 'results.json'),
                       pd.json_normalize(results))