def test_disabled_n_saved(dirname): h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=None) engine = Engine(lambda e, b: None) engine.state = State(epoch=0, iteration=0) model = DummyModel() to_save = {"model": model} num_iters = 100 for i in range(num_iters): engine.state.iteration = i h(engine, to_save) saved_files = sorted(os.listdir(dirname)) assert len(saved_files) == num_iters, "{}".format(saved_files) expected = sorted( ["{}_{}_{}.pth".format(_PREFIX, "model", i) for i in range(num_iters)]) assert saved_files == expected, "{} vs {}".format(saved_files, expected)
def test_removes_each_score_at_most_once(dirname): scores = [0, 1, 1, 2, 3] scores_iter = iter(scores) def score_function(_): return next(scores_iter) h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function) engine = Engine(lambda e, b: None) engine.state = State(epoch=0, iteration=0) model = DummyModel() to_save = {"model": model} for _ in range(len(scores)): h(engine, to_save)
def save_model(self, model, save_interval=None, n_saved=1): """Extension method for saving model. This method saves model as a PyTorch model filetype (.pth). Saved file will be saved on `self.res_dir / model / {model_class_name}.pth`. Args: trainer (ignite.Engine): trainer model (torch.nn.Module): model class. save_interval (int): Number of epoch interval in which model should be kept on disk. n_saved (int): Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept. """ if isinstance(model, torch.nn.DataParallel): model = model.module save_handler = ModelCheckpoint(self.res_dir / 'model', model.__class__.__name__, save_interval=save_interval, n_saved=n_saved) self.trainer.add_event_handler(Events.EPOCH_COMPLETED, save_handler, {'epoch': model})
def test_best_k(dirname): scores = iter([1.0, -2., 3.0, -4.0]) def score_function(engine): return next(scores) h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function, save_as_state_dict=False) to_save = {'name': 42} for _ in range(4): h(None, to_save) expected = ['{}_{}_{}.pth'.format(_PREFIX, 'name', i) for i in [1, 3]] assert sorted(os.listdir(dirname)) == expected
def _test_tpu_saves_to_cpu(device, dirname): torch.manual_seed(0) h = ModelCheckpoint(dirname, _PREFIX) engine = Engine(lambda e, b: None) engine.state = State(epoch=0, iteration=1) model = DummyModel().to(device) to_save = {"model": model} h(engine, to_save) idist.barrier() fname = h.last_checkpoint assert isinstance(fname, str) assert os.path.join(dirname, _PREFIX) in fname assert os.path.exists(fname) loaded_objects = torch.load(fname) assert loaded_objects == model.cpu().state_dict()
def warp_common_handler(engine, option, networks_to_save, monitoring_metrics, add_message, use_folder_pathes): # attach progress bar pbar = ProgressBar() pbar.attach(engine, metric_names=monitoring_metrics) timer = Timer(average=True) timer.attach(engine, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) create_plots = make_handle_create_plots(option.output_dir, LOGS_FNAME, PLOT_FNAME) checkpoint_handler = ModelCheckpoint(option.output_dir, CKPT_PREFIX, save_interval=option.save_interval, n_saved=option.n_saved, require_empty=False, create_dir=True, save_as_state_dict=True) engine.add_event_handler(Events.ITERATION_COMPLETED, checkpoint_handler, to_save=networks_to_save) engine.add_event_handler(Events.ITERATION_COMPLETED, create_plots) engine.add_event_handler( Events.EXCEPTION_RAISED, make_handle_handle_exception(checkpoint_handler, networks_to_save, create_plots)) engine.add_event_handler( Events.STARTED, make_handle_make_dirs(option.output_dir, use_folder_pathes)) engine.add_event_handler(Events.STARTED, make_move_html(option.output_dir)) engine.add_event_handler(Events.STARTED, make_create_option_data(option)) engine.add_event_handler(Events.EPOCH_COMPLETED, make_handle_print_times(timer, pbar)) engine.add_event_handler( Events.ITERATION_COMPLETED, make_handle_print_logs(option.output_dir, option.epochs, option.print_freq, pbar, add_message)) return engine
def save_best_model_by_val_score(output_path, evaluator, model, metric_name, n_saved=3, trainer=None, tag="val"): """Method adds a handler to `evaluator` to save best models based on the score (named by `metric_name`) provided by `evaluator`. Args: output_path (str): output path to indicate where to save best models evaluator (Engine): evaluation engine used to provide the score model (nn.Module): model to store metric_name (str): metric name to use for score evaluation. This metric should be present in `evaluator.state.metrics`. n_saved (int, optional): number of best models to store trainer (Engine, optional): trainer engine to fetch the epoch when saving the best model. tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val". Returns: A :class:`~ignite.handlers.checkpoint.ModelCheckpoint` handler. """ global_step_transform = None if trainer is not None: global_step_transform = global_step_from_engine(trainer) best_model_handler = ModelCheckpoint( dirname=output_path, filename_prefix="best", n_saved=n_saved, global_step_transform=global_step_transform, score_name="{}_{}".format(tag, metric_name.lower()), score_function=get_default_score_fn(metric_name), ) evaluator.add_event_handler(Events.COMPLETED, best_model_handler, { "model": model, }) return best_model_handler
def test_best_k_with_suffix(dirname): scores = [0.3456789, 0.1234, 0.4567, 0.134567] scores_iter = iter(scores) def score_function(engine): return next(scores_iter) h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function, score_name="val_loss") engine = Engine(lambda e, b: None) engine.state = State(epoch=0, iteration=0) model = DummyModel() to_save = {'model': model} for _ in range(4): engine.state.epoch += 1 h(engine, to_save) expected = ['{}_{}_val_loss={:.4}.pth'.format(_PREFIX, 'model', scores[e - 1]) for e in [1, 3]] assert sorted(os.listdir(dirname)) == expected
def _test(ext, require_empty, archived): previous_fname = os.path.join(dirname, '{}_{}_{}{}'.format(_PREFIX, 'obj', 1, ext)) with open(previous_fname, 'w') as f: f.write("test") h = ModelCheckpoint(dirname, _PREFIX, create_dir=True, require_empty=require_empty, archived=archived) engine = Engine(lambda e, b: None) engine.state = State(epoch=0, iteration=1) model = DummyModel() to_save = {'model': model} h(engine, to_save) fname = h.last_checkpoint ext = ".pth.tar" if archived else ".pth" assert isinstance(fname, str) assert os.path.join(dirname, '{}_{}_{}{}'.format(_PREFIX, 'model', 1, ext)) == fname assert os.path.exists(fname) assert os.path.exists(previous_fname) loaded_objects = torch.load(fname) assert loaded_objects == model.state_dict() os.remove(fname)
def test_best_k(dirname): scores = iter([1.2, -2.0, 3.1, -4.0]) def score_function(_): return next(scores) h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function) engine = Engine(lambda e, b: None) engine.state = State(epoch=0, iteration=0) model = DummyModel() to_save = {"model": model} for _ in range(4): h(engine, to_save) expected = ["{}_{}_{}.pth".format(_PREFIX, "model", i) for i in [1.2, 3.1]] assert sorted(os.listdir(dirname)) == expected
def test_simple_recovery_from_existing_non_empty(dirname): previous_fname = os.path.join(dirname, '{}_{}_{}.pth'.format(_PREFIX, 'obj', 1)) with open(previous_fname, 'w') as f: f.write("test") h = ModelCheckpoint(dirname, _PREFIX, create_dir=True, require_empty=False) engine = Engine(lambda e, b: None) engine.state = State(epoch=0, iteration=1) model = DummyModel() to_save = {'model': model} h(engine, to_save) fname = h.last_checkpoint assert isinstance(fname, str) assert os.path.join(dirname, '{}_{}_{}.pth'.format(_PREFIX, 'model', 1)) == fname assert os.path.exists(fname) assert os.path.exists(previous_fname) loaded_objects = torch.load(fname) assert "model" in loaded_objects assert loaded_objects['model'] == model.state_dict()
def train(): set_seed(train_param.seed) model = Model(model_param) optimizer = AdamW(model.parameters(), lr=train_param.lr, eps=1e-8) update_steps = train_param.epoch * len(train_loader) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=0, num_training_steps=update_steps) loss_fn = [translate, MSELoss()] device = torch.device(f'cuda:{train_param.device}') trainer = create_trainer(model, optimizer, scheduler, loss_fn, train_param.grad_norm, device) train_evaluator = create_evaluator(model, metric, device) dev_evaluator = create_evaluator(model, metric, device) trainer.add_event_handler( Events.ITERATION_COMPLETED(every=train_param.interval), log_training_loss) trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results, *(train_evaluator, train_loader, 'Train')) trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results, *(dev_evaluator, dev_loader, 'Dev')) es_handler = EarlyStopping(patience=train_param.patience, score_function=score_fn, trainer=trainer) dev_evaluator.add_event_handler(Events.COMPLETED, es_handler) ckpt_handler = ModelCheckpoint(train_param.save_path, '', score_function=score_fn, score_name='score', require_empty=False) dev_evaluator.add_event_handler(Events.COMPLETED, ckpt_handler, { 'model': model, 'param': model_param }) print( f'Start running {train_param.save_path.split("/")[-1]} at device: {train_param.device}\t' f'lr: {train_param.lr}') trainer.run(train_loader, max_epochs=train_param.epoch)
def test_best_k_with_suffix(dirname): scores = [0.3456789, 0.1234, 0.4567, 0.134567] scores_iter = iter(scores) def score_function(engine): return next(scores_iter) h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function, score_name="val_loss") to_save = {'name': 42} for _ in range(4): h(None, to_save) expected = [ '{}_{}_{}_val_loss={:.7}.pth'.format(_PREFIX, 'name', i, scores[i - 1]) for i in [1, 3] ] assert sorted(os.listdir(dirname)) == expected
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir): train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size) model = Net() device = "cpu" if torch.cuda.is_available(): device = "cuda" model.to(device) # Move model before creating optimizer optimizer = SGD(model.parameters(), lr=lr, momentum=momentum) criterion = nn.CrossEntropyLoss() trainer = create_supervised_trainer(model, optimizer, criterion, device=device) trainer.logger = setup_logger("Trainer") if sys.version_info > (3, ): from ignite.contrib.metrics.gpu_info import GpuInfo try: GpuInfo().attach(trainer) except RuntimeError: print( "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). " "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please " "install it : `pip install pynvml`") metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)} train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device) train_evaluator.logger = setup_logger("Train Evaluator") validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device) validation_evaluator.logger = setup_logger("Val Evaluator") @trainer.on(Events.EPOCH_COMPLETED) def compute_metrics(engine): train_evaluator.run(train_loader) validation_evaluator.run(val_loader) tb_logger = TensorboardLogger(log_dir=log_dir) tb_logger.attach_output_handler( trainer, event_name=Events.ITERATION_COMPLETED(every=100), tag="training", output_transform=lambda loss: {"batchloss": loss}, metric_names="all", ) for tag, evaluator in [("training", train_evaluator), ("validation", validation_evaluator)]: tb_logger.attach_output_handler( evaluator, event_name=Events.EPOCH_COMPLETED, tag=tag, metric_names=["loss", "accuracy"], global_step_transform=global_step_from_engine(trainer), ) tb_logger.attach_opt_params_handler( trainer, event_name=Events.ITERATION_COMPLETED(every=100), optimizer=optimizer) tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100)) tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100)) tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100)) tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100)) def score_function(engine): return engine.state.metrics["accuracy"] model_checkpoint = ModelCheckpoint( log_dir, n_saved=2, filename_prefix="best", score_function=score_function, score_name="validation_accuracy", global_step_transform=global_step_from_engine(trainer), ) validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model}) # kick everything off trainer.run(train_loader, max_epochs=epochs) tb_logger.close()
def train(): config_file = "configs/train_daily_dialog_emotion_action_config.json" config = Config.from_json_file(config_file) # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN) logger.warning( "Running process %d", config.local_rank ) # This is a logger.warning: it will be printed by all distributed processes logger.info("Arguments: %s", pformat(config)) # Initialize distributed training if needed config.distributed = (config.local_rank != -1) if config.distributed: torch.cuda.set_device(config.local_rank) config.device = torch.device("cuda", config.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info( "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning" ) tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint) model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel model = model_class.from_pretrained(config.model_checkpoint) tokenizer.set_special_tokens(SPECIAL_TOKENS) model.set_num_special_tokens(len(SPECIAL_TOKENS)) model.to(config.device) optimizer = OpenAIAdam(model.parameters(), lr=config.lr) # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) if config.fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=config.fp16) if config.distributed: model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank) logger.info("Prepare datasets") train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders( config, tokenizer) # Training function and trainer def update(engine, batch): model.train() input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = tuple( input_tensor.to(config.device) for input_tensor in batch) lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids) loss = (lm_loss * config.lm_coef + mc_loss * config.mc_coef) / config.gradient_accumulation_steps if config.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.max_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm) if engine.state.iteration % config.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() trainer = Engine(update) # Evaluation function and evaluator (evaluator output is the input of the metrics) def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple( input_tensor.to(config.device) for input_tensor in batch) input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = batch #logger.info(tokenizer.decode(input_ids[0, -1, :].tolist())) model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids, token_emotion_ids=token_emotion_ids, token_action_ids=token_action_ids) lm_logits, mc_logits = model_outputs[0], model_outputs[ 1] # So we can also use GPT2 outputs lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view( -1, 
lm_logits.size(-1)) lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels) evaluator = Engine(inference) # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if config.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) if config.eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader)) # Make sure distributed data samplers split the dataset nicely between the distributed processes if config.distributed: trainer.add_event_handler( Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) evaluator.add_event_handler( Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, config.lr), (config.n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = { "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])), "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1])) } metrics.update({ "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], config), "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], config) }) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train if config.local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) tb_logger = TensorboardLogger(log_dir=config.log_dir) tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list( metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model) }) # "getattr" take care of distributed encapsulation torch.save(config, tb_logger.writer.log_dir + '/model_training_args.bin') getattr(model, 'module', model).config.to_json_file( os.path.join(tb_logger.writer.log_dir, CONFIG_NAME)) tokenizer.save_vocabulary(tb_logger.writer.log_dir) # Run the training trainer.run(train_loader, max_epochs=config.n_epochs) # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) if config.local_rank in [-1, 0] and config.n_epochs > 0: os.rename( checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME) ) # TODO: PR in ignite to 
have better access to saved file paths (cleaner) tb_logger.close()
def train(model, train_dl, val_dl, test_dl, loss_str, optimizer_name, lr, max_epochs, metrics, val_metric_to_monitor, print_freq, epoch_per_metric, plateau_patience, plateau_terminate, gpu_if_available, gpu_idx, custom_metrics=None, save_dir=None): """Simple model training framework setup with ignite. This builds and runs a standard training process using the ignite framework. Given train/val/test dataloaders, attaches specified metrics and runs over the training set with LR scheduling, early stopping, and model check-pointing all built in. Args: model (nn.Module): A network built in standard PyTorch. train_dl (DataLoader): Train data. val_dl (DataLoader): Val data. test_dl (DataLoader): Test data. optimizer_name (str): Name of the optimizer to use. lr (float): The initial value of the learning rate. loss_str (function): The loss function. max_epochs (int): Max epochs to run the algorithm for. metrics (list): A list of metric strings to be monitored. val_metric_to_monitor (str): The metric to monitor for LR scheduling and early stopping. print_freq (int): Frequency of printing train/val results to console. epoch_per_metric (int): Number of epochs before next computation of val metrics. plateau_patience (int): Number of epochs with no improvement before LR reduction. plateau_terminate (int): Number of epochs with no improvement before stopping. gpu_if_available (bool): Run on the gpu if one exists. gpu_idx (int): The index of the gpu to run on. custom_metrics (dict): Dictionary of custom metrics. save_dir (str): Location to save the model checkpoints. Returns: (results:dict, validation_history:dict): The results of the best model and the full training history. """ device = set_device(gpu_if_available, gpu_idx=gpu_idx) loss_fn = set_loss(loss_str) lr = set_lr(train_dl) if lr is None else lr optimizer = setup_optimizer(model, optimizer_name, lr) # Choose metrics given the string list binary = True if isinstance(loss_fn, torch.nn.BCEWithLogitsLoss) else False metrics, train_metrics, val_metrics = setup_metrics( metrics, loss_fn, binary=binary, custom_metrics=custom_metrics) # Build engines trainer_output_tfm = lambda x, y, y_pred, loss: (loss.item(), y, y_pred) trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device, output_transform=trainer_output_tfm) evaluator = create_supervised_evaluator(model, device=device, metrics=val_metrics) # Attach running average metrics to trainer for name, metric in train_metrics.items(): metric.attach(trainer, name) # Progress bar pbar = tqdm(range(max_epochs)) # Validation loop @trainer.on(Events.EPOCH_COMPLETED) def log_validation_metrics(engine): epoch = engine.state.epoch pbar.update(1) if (epoch % epoch_per_metric == 0) or (epoch == 0): evaluator.run(val_dl, max_epochs=1) add_metrics_to_dict(trainer.state.metrics, validation_history, '.train') add_metrics_to_dict(evaluator.state.metrics, validation_history, '.val') if (epoch % print_freq == 0) or (epoch == 0): print_val_results(epoch, validation_history, pbar=pbar) # Score to monitor for early stopping and check-pointing sign = -1 if val_metric_to_monitor is 'loss' else 1 score_function = lambda engine: engine.state.metrics[val_metric_to_monitor ] * sign # LR scheduling (monitors validation loss), early stopping and check-pointing scheduler = ReduceLROnPlateau(optimizer, patience=plateau_patience, threshold=1e-6, min_lr=1e-7) evaluator.add_event_handler( Events.EPOCH_COMPLETED, lambda engine: scheduler.step(engine.state.metrics['loss'])) # Early stopping stopping = 
EarlyStopping(patience=plateau_terminate, score_function=score_function, trainer=trainer) evaluator.add_event_handler(Events.EPOCH_COMPLETED, stopping) # Checkpoint save_best_model = ModelCheckpoint(save_dir, '', score_function=score_function) evaluator.add_event_handler(Events.EPOCH_COMPLETED, save_best_model, {'best_model': model}) # History validation_history = OrderedDict() for type in ('train', 'val'): for name in metrics: validation_history[name + '.' + type] = [] # Train the model start, start_memory = time.time(), get_memory(device, reset=True) trainer.run(train_dl, max_epochs=max_epochs) elapsed = time.time() - start memory_usage = get_memory(device) - start_memory # Score on test model.load_state_dict(torch.load(save_best_model.last_checkpoint)) evaluator.run(test_dl, max_epochs=1) # Final model results results = OrderedDict(**{ 'elapsed_time': elapsed, 'memory_usage': memory_usage }) # Best metric/value func = np.argmax if sign == 1 else np.argmin best_idx = func(validation_history[val_metric_to_monitor + '.val']) for key, value in validation_history.items(): results[key] = value[best_idx] for metric, value in evaluator.state.metrics.items(): results[metric + '.test'] = value print_final_results(results) return model, results, validation_history
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size, epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, learn_top, y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers, cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer, warmup, fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every, ld_on_samples, weight_gan, weight_prior, weight_logdet, jac_reg_lambda, affine_eps, no_warm_up, optim_name, clamp, svd_every, eval_only, no_actnorm, affine_scale_eps, actnorm_max_scale, no_conv_actnorm, affine_max_scale, actnorm_eps, init_sample, no_split, disc_arch, weight_entropy_reg, db): check_manual_seed(seed) ds = check_dataset(dataset, dataroot, augment, download) image_shape, num_classes, train_dataset, test_dataset = ds # Note: unsupported for now multi_class = False train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True) test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=n_workers, drop_last=False) model = Glow(image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, learn_top, y_condition, logittransform, sn, affine_eps, no_actnorm, affine_scale_eps, actnorm_max_scale, no_conv_actnorm, affine_max_scale, actnorm_eps, no_split) model = model.to(device) if disc_arch == 'mine': discriminator = mine.Discriminator(image_shape[-1]) elif disc_arch == 'biggan': discriminator = cgan_models.Discriminator( image_channels=image_shape[-1], conditional_D=False) elif disc_arch == 'dcgan': discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1]) elif disc_arch == 'inv': discriminator = InvDiscriminator( image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, learn_top, y_condition, logittransform, sn, affine_eps, no_actnorm, affine_scale_eps, actnorm_max_scale, no_conv_actnorm, affine_max_scale, actnorm_eps, no_split) discriminator = discriminator.to(device) D_optimizer = optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=disc_lr, betas=(.5, .99), weight_decay=0) if optim_name == 'adam': optimizer = optim.Adam(model.parameters(), lr=lr, betas=(.5, .99), weight_decay=0) elif optim_name == 'adamax': optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5) if not no_warm_up: lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) iteration_fieldnames = [ 'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd', 'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc' ] iteration_logger = CSVLogger(fieldnames=iteration_fieldnames, filename=os.path.join(output_dir, 'iteration_log.csv')) iteration_fieldnames = [ 'global_iteration', 'condition_num', 'max_sv', 'min_sv', 'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv' ] svd_logger = CSVLogger(fieldnames=iteration_fieldnames, filename=os.path.join(output_dir, 'svd_log.csv')) # test_iter = test_loader.__iter__() N_inception = 1000 x_real_inception = torch.cat([ test_iter.__next__()[0].to(device) for _ in range(N_inception // args.batch_size + 1) ], 0)[:N_inception] x_real_inception = x_real_inception + .5 x_for_recon = test_iter.__next__()[0].to(device) def gan_step(engine, batch): assert not y_condition if 'iter_ind' in dir(engine): engine.iter_ind += 1 else: engine.iter_ind = -1 losses = 
{} model.train() discriminator.train() x, y = batch x = x.to(device) def run_noised_disc(discriminator, x): x = uniform_binning_correction(x)[0] return discriminator(x) real_acc = fake_acc = acc = 0 if weight_gan > 0: fake = generate_from_noise(model, x.size(0), clamp=clamp) D_real_scores = run_noised_disc(discriminator, x.detach()) D_fake_scores = run_noised_disc(discriminator, fake.detach()) ones_target = torch.ones((x.size(0), 1), device=x.device) zeros_target = torch.zeros((x.size(0), 1), device=x.device) D_real_accuracy = torch.sum( torch.round(F.sigmoid(D_real_scores)) == ones_target).float() / ones_target.size(0) D_fake_accuracy = torch.sum( torch.round(F.sigmoid(D_fake_scores)) == zeros_target).float() / zeros_target.size(0) D_real_loss = F.binary_cross_entropy_with_logits( D_real_scores, ones_target) D_fake_loss = F.binary_cross_entropy_with_logits( D_fake_scores, zeros_target) D_loss = (D_real_loss + D_fake_loss) / 2 gp = gradient_penalty( x.detach(), fake.detach(), lambda _x: run_noised_disc(discriminator, _x)) D_loss_plus_gp = D_loss + 10 * gp D_optimizer.zero_grad() D_loss_plus_gp.backward() D_optimizer.step() # Train generator fake = generate_from_noise(model, x.size(0), clamp=clamp, guard_nans=False) G_loss = F.binary_cross_entropy_with_logits( run_noised_disc(discriminator, fake), torch.ones((x.size(0), 1), device=x.device)) # Trace real_acc = D_real_accuracy.item() fake_acc = D_fake_accuracy.item() acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item()) z, nll, y_logits, (prior, logdet) = model.forward(x, None, return_details=True) train_bpd = nll.mean().item() loss = 0 if weight_gan > 0: loss = loss + weight_gan * G_loss if weight_prior > 0: loss = loss + weight_prior * -prior.mean() if weight_logdet > 0: loss = loss + weight_logdet * -logdet.mean() if weight_entropy_reg > 0: _, _, _, (sample_prior, sample_logdet) = model.forward(fake, None, return_details=True) # notice this is actually "decreasing" sample likelihood. 
loss = loss + weight_entropy_reg * (sample_prior.mean() + sample_logdet.mean()) # Jac Reg if jac_reg_lambda > 0: # Sample x_samples = generate_from_noise(model, args.batch_size, clamp=clamp).detach() x_samples.requires_grad_() z = model.forward(x_samples, None, return_details=True)[0] other_zs = torch.cat([ split._last_z2.view(x.size(0), -1) for split in model.flow.splits ], -1) all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1) sample_foward_jac = compute_jacobian_regularizer(x_samples, all_z, n_proj=1) _, c2, h, w = model.prior_h.shape c = c2 // 2 zshape = (batch_size, c, h, w) randz = torch.randn(zshape).to(device) randz = torch.autograd.Variable(randz, requires_grad=True) images = model(z=randz, y_onehot=None, temperature=1, reverse=True, batch_size=0) other_zs = [split._last_z2 for split in model.flow.splits] all_z = [randz] + other_zs sample_inverse_jac = compute_jacobian_regularizer_manyinputs( all_z, images, n_proj=1) # Data x.requires_grad_() z = model.forward(x, None, return_details=True)[0] other_zs = torch.cat([ split._last_z2.view(x.size(0), -1) for split in model.flow.splits ], -1) all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1) data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1) _, c2, h, w = model.prior_h.shape c = c2 // 2 zshape = (batch_size, c, h, w) z.requires_grad_() images = model(z=z, y_onehot=None, temperature=1, reverse=True, batch_size=0) other_zs = [split._last_z2 for split in model.flow.splits] all_z = [z] + other_zs data_inverse_jac = compute_jacobian_regularizer_manyinputs( all_z, images, n_proj=1) # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac ) loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac + data_foward_jac + data_inverse_jac) if not eval_only: optimizer.zero_grad() loss.backward() if not db: assert max_grad_clip == max_grad_norm == 0 if max_grad_clip > 0: torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) if max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Replace NaN gradient with 0 for p in model.parameters(): if p.requires_grad and p.grad is not None: g = p.grad.data g[g != g] = 0 optimizer.step() if engine.iter_ind % 100 == 0: with torch.no_grad(): fake = generate_from_noise(model, x.size(0), clamp=clamp) z = model.forward(fake, None, return_details=True)[0] print("Z max min") print(z.max().item(), z.min().item()) if (fake != fake).float().sum() > 0: title = 'NaNs' else: title = "Good" grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]), nrow=6).permute(1, 2, 0) plt.figure(figsize=(10, 10)) plt.imshow(grid) plt.axis('off') plt.title(title) plt.savefig( os.path.join(output_dir, f'sample_{engine.iter_ind}.png')) if engine.iter_ind % eval_every == 0: def check_all_zero_except_leading(x): return x % 10**np.floor(np.log10(x)) == 0 if engine.iter_ind == 0 or check_all_zero_except_leading( engine.iter_ind): torch.save( model.state_dict(), os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt')) model.eval() with torch.no_grad(): # Plot recon fpath = os.path.join(output_dir, '_recon', f'recon_{engine.iter_ind}.png') sample_pad = run_recon_evolution( model, generate_from_noise(model, args.batch_size, clamp=clamp).detach(), fpath) print( f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}") pad = run_recon_evolution(model, x_for_recon, fpath) print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}") pad = pad.item() sample_pad = sample_pad.item() # Inception score sample = torch.cat([ generate_from_noise(model, 
args.batch_size, clamp=clamp) for _ in range(N_inception // args.batch_size + 1) ], 0)[:N_inception] sample = sample + .5 if (sample != sample).float().sum() > 0: print("Sample NaNs") raise else: fid = run_fid(x_real_inception.clamp_(0, 1), sample.clamp_(0, 1)) print(f'fid: {fid}, global_iter: {engine.iter_ind}') # Eval BPD eval_bpd = np.mean([ model.forward(x.to(device), None, return_details=True)[1].mean().item() for x, _ in test_loader ]) stats_dict = { 'global_iteration': engine.iter_ind, 'fid': fid, 'train_bpd': train_bpd, 'pad': pad, 'eval_bpd': eval_bpd, 'sample_pad': sample_pad, 'batch_real_acc': real_acc, 'batch_fake_acc': fake_acc, 'batch_acc': acc } iteration_logger.writerow(stats_dict) plot_csv(iteration_logger.filename) model.train() if engine.iter_ind + 2 % svd_every == 0: model.eval() svd_dict = {} ret = utils.computeSVDjacobian(x_for_recon, model) D_for, D_inv = ret['D_for'], ret['D_inv'] cn = float(D_for.max() / D_for.min()) cn_inv = float(D_inv.max() / D_inv.min()) svd_dict['global_iteration'] = engine.iter_ind svd_dict['condition_num'] = cn svd_dict['max_sv'] = float(D_for.max()) svd_dict['min_sv'] = float(D_for.min()) svd_dict['inverse_condition_num'] = cn_inv svd_dict['inverse_max_sv'] = float(D_inv.max()) svd_dict['inverse_min_sv'] = float(D_inv.min()) svd_logger.writerow(svd_dict) # plot_utils.plot_stability_stats(output_dir) # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv') model.train() if eval_only: sys.exit() # Dummy losses['total_loss'] = torch.mean(nll).item() return losses def eval_step(engine, batch): model.eval() x, y = batch x = x.to(device) with torch.no_grad(): if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class, reduction='none') else: z, nll, y_logits = model(x, None) losses = compute_loss(nll, reduction='none') return losses trainer = Engine(gan_step) # else: # trainer = Engine(step) checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=5, n_saved=1, require_empty=False) trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, { 'model': model, 'optimizer': optimizer }) monitoring_metrics = ['total_loss'] RunningAverage(output_transform=lambda x: x['total_loss']).attach( trainer, 'total_loss') evaluator = Engine(eval_step) # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach( evaluator, 'total_loss') if y_condition: monitoring_metrics.extend(['nll']) RunningAverage(output_transform=lambda x: x['nll']).attach( trainer, 'nll') # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['nll'], torch.empty(x['nll'].shape[0]))).attach( evaluator, 'nll') pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) # load pre-trained model if given if saved_model: print("Loading...") print(saved_model) loaded = torch.load(saved_model) # if 'Glow' in str(type(loaded)): # model = loaded # else: # raise # # if 'Glow' in str(type(loaded)): # # loaded = loaded.state_dict() model.load_state_dict(loaded) model.set_actnorm_init() if saved_optimizer: optimizer.load_state_dict(torch.load(saved_optimizer)) file_name, ext = os.path.splitext(saved_model) resume_epoch = int(file_name.split('_')[-1]) @trainer.on(Events.STARTED) def resume_training(engine): engine.state.epoch = resume_epoch engine.state.iteration = 
resume_epoch * len( engine.state.dataloader) @trainer.on(Events.STARTED) def init(engine): if saved_model: return model.train() print("Initializing Actnorm...") init_batches = [] init_targets = [] if n_init_batches == 0: model.set_actnorm_init() return with torch.no_grad(): if init_sample: generate_from_noise(model, args.batch_size * args.n_init_batches) else: for batch, target in islice(train_loader, None, n_init_batches): init_batches.append(batch) init_targets.append(target) init_batches = torch.cat(init_batches).to(device) assert init_batches.shape[0] == n_init_batches * batch_size if y_condition: init_targets = torch.cat(init_targets).to(device) else: init_targets = None model(init_batches, init_targets) @trainer.on(Events.EPOCH_COMPLETED) def evaluate(engine): evaluator.run(test_loader) if not no_warm_up: scheduler.step() metrics = evaluator.state.metrics losses = ', '.join( [f"{key}: {value:.2f}" for key, value in metrics.items()]) print(f'Validation Results - Epoch: {engine.state.epoch} {losses}') timer = Timer(average=True) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message( f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]' ) timer.reset() trainer.run(train_loader, epochs)
def test_simple_recovery(dirname): h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, save_interval=1) h(None, {'obj': 42}) fname = os.path.join(dirname, '{}_{}_{}.pth'.format(_PREFIX, 'obj', 1)) assert torch.load(fname) == 42
# Evaluation metrics = { 'loss': Loss(loss_fn), 'acc': Accuracy() } def score_fn(engine): acc = engine.state.metrics['acc'] return acc evaluator = create_evaluator(model, metrics, device=device) def log_metrics(engine): metrics = evaluator.run(valoader).metrics print('[INFO] Compute metrics...') print(' Validation Results - Average Loss: {:.4f} | Accuracy: {:.4f}'.format(metrics['loss'], metrics['acc'])) print('[INFO] Complete metrics...') trainer.add_event_handler(Events.EPOCH_COMPLETED, log_metrics) # save the model checkpoints saver = ModelCheckpoint(snapshots, 'r101', n_saved=10, score_name='acc', score_function=score_fn) evaluator.add_event_handler(Events.COMPLETED, saver, {'model': model.module}) # start training print('[INFO] Start training...') trainer.run(trainloader, epochs) print('[INFO] Complete training...')
def train(): parser = ArgumentParser() parser.add_argument("--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.") parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache") parser.add_argument("--model_checkpoint", type=str, default="openai-gpt", help="Path, url or short name of the model") parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training") parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history") parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training") parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation") parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps") parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate") parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient") parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient") parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm") parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs") parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences") parser.add_argument("--eval_before_start", action='store_true', help="If true start with a first evaluation before training") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)") parser.add_argument("--fp16", type=str, default="", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)") parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)") args = parser.parse_args() # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. 
logger.info => log main process only, logger.warning => log all processes logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) logger.warning("Running process %d", args.local_rank) # This is a logger.warning: it will be printed by all distributed processes logger.info("Arguments: %s", pformat(args)) # Initialize distributed training if needed args.distributed = (args.local_rank != -1) if args.distributed: torch.cuda.set_device(args.local_rank) args.device = torch.device("cuda", args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning") tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint) model_class = GPT2LMHeadModel if "gpt2" in args.model_checkpoint else OpenAIGPTLMHeadModel model = model_class.from_pretrained(args.model_checkpoint) tokenizer.set_special_tokens(SPECIAL_TOKENS) model.set_num_special_tokens(len(SPECIAL_TOKENS)) model.to(args.device) optimizer = OpenAIAdam(model.parameters(), lr=args.lr) # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) if args.fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16) if args.distributed: model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) logger.info("Prepare datasets") train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer) # Training function and trainer def update(engine, batch): model.train() batch = tuple(input_tensor.to(args.device) for input_tensor in batch) lm_loss, mc_loss = model(*batch) loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) if engine.state.iteration % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() trainer = Engine(update) # Evaluation function and evaluator (evaluator output is the input of the metrics) def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple(input_tensor.to(args.device) for input_tensor in batch) input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch logger.info(tokenizer.decode(input_ids[0, -1, :].tolist())) model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids) lm_logits, mc_logits = model_outputs[0], model_outputs[1] # So we can also use GPT2 outputs lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1)) lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels) evaluator = Engine(inference) # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if args.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) if args.eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: 
evaluator.run(val_loader)) # Make sure distributed data samplers split the dataset nicely between the distributed processes if args.distributed: trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])), "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))} metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args), "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)}) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train if args.local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics))) tb_logger = TensorboardLogger(log_dir=None) tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3) trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)}) # "getattr" take care of distributed encapsulation torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin') getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME)) tokenizer.save_vocabulary(tb_logger.writer.log_dir) # Run the training trainer.run(train_loader, max_epochs=args.n_epochs) # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) if args.local_rank in [-1, 0] and args.n_epochs > 0: os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)) # TODO: PR in ignite to have better access to saved file paths (cleaner) tb_logger.close()
def train_gan(logger: Logger, experiment_dir: Path, data_dir: Path, batch_size: int, z_dim: int, g_filters: int, d_filters: int, learning_rate: float, beta_1: float, epochs: int, saved_g: bool = False, saved_d: bool = False, seed: Optional[int] = None, g_extra_layers: int = 0, d_extra_layers: int = 0, scheduler: bool = False) -> None: seed = fix_random_seed(seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Train started with seed: {seed}") dataset = HDF5ImageDataset(image_dir=data_dir) desired_minkowski = pickle.load( (data_dir / 'minkowski.pkl').open(mode='rb')) loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True) iterations = epochs * len(loader) img_size = dataset.shape[-1] num_channels = dataset.shape[0] # networks net_g = Generator(img_size=img_size, z_dim=z_dim, num_channels=num_channels, num_filters=g_filters, num_extra_layers=g_extra_layers).to(device) net_d = Discriminator(img_size=img_size, num_channels=num_channels, num_filters=d_filters, num_extra_layers=d_extra_layers).to(device) summary(net_g, (z_dim, 1, 1, 1)) summary(net_d, (num_channels, img_size, img_size, img_size)) if saved_g: net_g.load_state_dict(torch.load(experiment_dir / G_CHECKPOINT_NAME)) logger.info("Loaded generator checkpoint") if saved_d: net_d.load_state_dict(torch.load(experiment_dir / D_CHECKPOINT_NAME)) logger.info("Loaded discriminator checkpoint") # criterion criterion = nn.BCELoss() optimizer_g = optim.Adam(net_g.parameters(), lr=learning_rate, betas=(beta_1, 0.999)) optimizer_d = optim.Adam(net_d.parameters(), lr=learning_rate, betas=(beta_1, 0.999)) patience = int(3000 / len(loader)) scheduler_g = optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, min_lr=1e-6, verbose=True, patience=patience) scheduler_d = optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, min_lr=1e-6, verbose=True, patience=patience) # labels smoothing real_labels = torch.full((batch_size, ), fill_value=0.9, device=device) fake_labels = torch.zeros((batch_size, ), device=device) fixed_noise = torch.randn(1, z_dim, 1, 1, 1, device=device) def step(engine: Engine, batch: torch.Tensor) -> Dict[str, float]: """ Train step function :param engine: pytorch ignite train engine :param batch: batch to process :return batch metrics """ # get batch of fake images from generator fake_batch = net_g( torch.randn(batch_size, z_dim, 1, 1, 1, device=device)) # 1. Update D network: maximize log(D(x)) + log(1 - D(G(z))) batch = batch.to(device) optimizer_d.zero_grad() # train D with real and fake batches d_out_real = net_d(batch) d_out_fake = net_d(fake_batch.detach()) loss_d_real = criterion(d_out_real, real_labels) loss_d_fake = criterion(d_out_fake, fake_labels) # mean probabilities p_real = d_out_real.mean().item() p_fake = d_out_fake.mean().item() loss_d = (loss_d_real + loss_d_fake) / 2 loss_d.backward() optimizer_d.step() # 2. 
Update G network: maximize log(D(G(z))) loss_g = None p_gen = None for _ in range(1): fake_batch = net_g( torch.randn(batch_size, z_dim, 1, 1, 1, device=device)) optimizer_g.zero_grad() d_out_fake = net_d(fake_batch) loss_g = criterion(d_out_fake, real_labels) # mean fake generator probability p_gen = d_out_fake.mean().item() loss_g.backward() optimizer_g.step() # minkowski functional measures cube = net_g(fixed_noise).detach().squeeze().cpu() cube = cube.mul(0.5).add(0.5).numpy() cube = postprocess_cube(cube) cube = np.pad(cube, ((1, 1), (1, 1), (1, 1)), mode='constant', constant_values=0) v, s, b, xi = compute_minkowski(cube) return { 'loss_d': loss_d.item(), 'loss_g': loss_g.item(), 'p_real': p_real, 'p_fake': p_fake, 'p_gen': p_gen, 'V': v, 'S': s, 'B': b, 'Xi': xi } # ignite objects trainer = Engine(step) checkpoint_handler = ModelCheckpoint(dirname=str(experiment_dir), filename_prefix=CKPT_PREFIX, save_interval=5, n_saved=50, require_empty=False) # attach running average metrics monitoring_metrics = [ 'loss_d', 'loss_g', 'p_real', 'p_fake', 'p_gen', 'V', 'S', 'B', 'Xi' ] RunningAverage(alpha=ALPHA, output_transform=lambda x: x['loss_d']).attach( trainer, 'loss_d') RunningAverage(alpha=ALPHA, output_transform=lambda x: x['loss_g']).attach( trainer, 'loss_g') RunningAverage(alpha=ALPHA, output_transform=lambda x: x['p_real']).attach( trainer, 'p_real') RunningAverage(alpha=ALPHA, output_transform=lambda x: x['p_fake']).attach( trainer, 'p_fake') RunningAverage(alpha=ALPHA, output_transform=lambda x: x['p_gen']).attach( trainer, 'p_gen') RunningAverage(alpha=ALPHA, output_transform=lambda x: x['V']).attach(trainer, 'V') RunningAverage(alpha=ALPHA, output_transform=lambda x: x['S']).attach(trainer, 'S') RunningAverage(alpha=ALPHA, output_transform=lambda x: x['B']).attach(trainer, 'B') RunningAverage(alpha=ALPHA, output_transform=lambda x: x['Xi']).attach(trainer, 'Xi') # attach progress bar pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) @trainer.on(Events.ITERATION_COMPLETED) def print_logs(engine): if (engine.state.iteration - 1) % PRINT_FREQ == 0: fname = experiment_dir / LOGS_FNAME columns = ['iter'] + list(engine.state.metrics.keys()) values = [str(engine.state.iteration)] + [ str(round(value, 7)) for value in engine.state.metrics.values() ] with fname.open(mode='a') as f: if f.tell() == 0: print('\t'.join(columns), file=f) print('\t'.join(values), file=f) message = f"[{engine.state.epoch}/{epochs}][{engine.state.iteration:04d}/{iterations}]" for name, value in zip(engine.state.metrics.keys(), engine.state.metrics.values()): message += f" | {name}: {value:0.5f}" pbar.log_message(message) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'net_g': net_g, 'net_d': net_d }) @trainer.on(Events.EPOCH_COMPLETED) def create_plots(engine): df = pd.read_csv(experiment_dir / LOGS_FNAME, delimiter='\t') fig_1 = plt.figure(figsize=(18, 12)) plt.plot(df['iter'], df['loss_d'], label='loss_d', linestyle='dashed') plt.plot(df['iter'], df['loss_g'], label='loss_g') plt.xlabel('Iteration number') plt.legend() fig_1.savefig(experiment_dir / ('loss_' + PLOT_FNAME)) plt.close(fig_1) fig_2 = plt.figure(figsize=(18, 12)) plt.plot(df['iter'], df['p_real'], label='p_real', linestyle='dashed') plt.plot(df['iter'], df['p_fake'], label='p_fake', linestyle='dashdot') plt.plot(df['iter'], df['p_gen'], label='p_gen') plt.xlabel('Iteration number') plt.legend() fig_2.savefig(experiment_dir / PLOT_FNAME) plt.close(fig_2) desired_v = 
[desired_minkowski[0]] * len(df['iter']) desired_s = [desired_minkowski[1]] * len(df['iter']) desired_b = [desired_minkowski[2]] * len(df['iter']) desired_xi = [desired_minkowski[3]] * len(df['iter']) fig_3 = plt.figure(figsize=(18, 12)) plt.plot(df['iter'], df['V'], label='V', color='b') plt.plot(df['iter'], desired_v, color='b', linestyle='dashed') plt.xlabel('Iteration number') plt.ylabel('Minkowski functional V') plt.legend() fig_3.savefig(experiment_dir / ('minkowski_V_' + PLOT_FNAME)) plt.close(fig_3) fig_4 = plt.figure(figsize=(18, 12)) plt.plot(df['iter'], df['S'], label='S', color='r') plt.plot(df['iter'], desired_s, color='r', linestyle='dashed') plt.xlabel('Iteration number') plt.ylabel('Minkowski functional S') plt.legend() fig_4.savefig(experiment_dir / ('minkowski_S_' + PLOT_FNAME)) plt.close(fig_4) fig_5 = plt.figure(figsize=(18, 12)) plt.plot(df['iter'], df['B'], label='B', color='g') plt.plot(df['iter'], desired_b, color='g', linestyle='dashed') plt.xlabel('Iteration number') plt.ylabel('Minkowski functional B') plt.legend() fig_5.savefig(experiment_dir / ('minkowski_B_' + PLOT_FNAME)) plt.close(fig_5) fig_6 = plt.figure(figsize=(18, 12)) plt.plot(df['iter'], df['Xi'], label='Xi', color='y') plt.plot(df['iter'], desired_xi, color='y', linestyle='dashed') plt.xlabel('Iteration number') plt.ylabel('Minkowski functional Xi') plt.legend() fig_6.savefig(experiment_dir / ('minkowski_Xi_' + PLOT_FNAME)) plt.close(fig_6) if scheduler: @trainer.on(Events.EPOCH_COMPLETED) def lr_scheduler(engine): desired_b = desired_minkowski[2] desired_xi = desired_minkowski[3] current_b = engine.state.metrics['B'] current_xi = engine.state.metrics['Xi'] delta = abs(desired_b - current_b) + abs(desired_xi - current_xi) scheduler_d.step(delta) scheduler_g.step(delta) @trainer.on(Events.EXCEPTION_RAISED) def handle_exception(engine, e): if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): engine.terminate() warnings.warn('KeyboardInterrupt caught. Exiting gracefully.') create_plots(engine) checkpoint_handler(engine, { 'net_g_exception': net_g, 'net_d_exception': net_d }) else: raise e trainer.run(loader, epochs)
def main(): monai.config.print_config() logging.basicConfig(stream=sys.stdout, level=logging.INFO) # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ images = [ "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz", ] # 2 binary labels for gender classification: man and woman labels = np.array( [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]) train_files = [{ "img": img, "label": label } for img, label in zip(images[:10], labels[:10])] val_files = [{ "img": img, "label": label } for img, label in zip(images[-10:], labels[-10:])] # define transforms for image train_transforms = Compose([ LoadNiftid(keys=["img"]), AddChanneld(keys=["img"]), ScaleIntensityd(keys=["img"]), Resized(keys=["img"], spatial_size=(96, 96, 96)), RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]), ToTensord(keys=["img"]), ]) val_transforms = Compose([ LoadNiftid(keys=["img"]), AddChanneld(keys=["img"]), ScaleIntensityd(keys=["img"]), Resized(keys=["img"], spatial_size=(96, 96, 96)), ToTensord(keys=["img"]), ]) # define dataset, data loader check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) check_data = monai.utils.misc.first(check_loader) print(check_data["img"].shape, check_data["label"]) # create DenseNet121, CrossEntropyLoss and Adam optimizer net = monai.networks.nets.densenet.densenet121( spatial_dims=3, in_channels=1, out_channels=2, ) loss = torch.nn.CrossEntropyLoss() lr = 1e-5 opt = torch.optim.Adam(net.parameters(), lr) device = torch.device("cuda:0") # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration, # user can add output_transform to return other values, like: y_pred, y, etc. 
def prepare_batch(batch, device=None, non_blocking=False): return _prepare_batch((batch["img"], batch["label"]), device, non_blocking) trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ "net": net, "opt": opt }) # StatsHandler prints loss at every iteration and print metrics at every epoch, # we don't set metrics for trainer here, so just print loss, user can also customize print functions # and can use output_transform to convert engine.state.output if it's not loss value train_stats_handler = StatsHandler(name="trainer") train_stats_handler.attach(trainer) # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler train_tensorboard_stats_handler = TensorBoardStatsHandler() train_tensorboard_stats_handler.attach(trainer) # set parameters for validation validation_every_n_epochs = 1 metric_name = "Accuracy" # add evaluation metric to the evaluator engine val_metrics = { metric_name: Accuracy(), "AUC": ROCAUC(to_onehot_y=True, add_softmax=True) } # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, # user can add output_transform to return other values evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) # add stats event handler to print validation stats via evaluator val_stats_handler = StatsHandler( name="evaluator", output_transform=lambda x: None, # no need to print loss value, so disable per iteration output global_epoch_transform=lambda x: trainer.state.epoch, ) # fetch global epoch number from trainer val_stats_handler.attach(evaluator) # add handler to record metrics to TensorBoard at every epoch val_tensorboard_stats_handler = TensorBoardStatsHandler( output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output global_epoch_transform=lambda x: trainer.state.epoch, ) # fetch global epoch number from trainer val_tensorboard_stats_handler.attach(evaluator) # add early stopping handler to evaluator early_stopper = EarlyStopping( patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) # create a validation data loader val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) def run_validation(engine): evaluator.run(val_loader) # create a training data loader train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available()) train_epochs = 30 state = trainer.run(train_loader, train_epochs)
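In the ignite versions this script targets, the handler above writes one `.pth` file per entry of `to_save` (newer releases bundle everything into a single checkpoint file). Restoring the network for inference is then a plain `load_state_dict`; the filename below is purely illustrative and depends on the prefix, key and save step:

import torch
import monai

net = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2)
# hypothetical filename; check the actual contents of ./runs/ produced by ModelCheckpoint
state = torch.load("./runs/net_net_30.pth", map_location="cpu")
net.load_state_dict(state)
net.eval()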
def add_handlers(trainer, evaluator, data, models_dict, cfg): """ :param trainer: ignite trainer object :param evaluator: ignite evaluator object :param data: tuple containing train and test dataloader :param models_dict: dict containing all models & optimizers to save :param cfg: configuration dict """ train_loader, test_loader = data # add progressbar progbar = ProgressBar(trainer, train_loader) trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED, handler=progbar) # initialize checkpoint savings funtion checkpoint = ModelCheckpoint(cfg.DIRS.CHKP_DIR, cfg.DIRS.CHKP_PREFIX, require_empty=False, save_interval=1, n_saved=100000, save_as_state_dict=True) writer = Tensorboard(create_dir(cfg.DIRS.CHKP_DIR, 'summaries')) trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED, handler=writer) evaluator.add_event_handler(event_name=Events.ITERATION_COMPLETED, handler=writer) # if models were loaded, resume training from left off epoch # otherwise start at epoch 0. @trainer.on(Events.STARTED) def epoch_start(engine): if (cfg.SOLVER.RESUME_EPOCH != '0') and \ (cfg.SOLVER.COMPLETED_EPOCHS != 0): mess = 'TRAINING COMPLETE' if \ ((cfg.SOLVER.COMPLETED_EPOCHS - cfg.SOLVER.EPOCHS) >= 0) \ else 'RESUME TRAINING' engine.state.iteration = cfg.SOLVER.TRAINER_ITERATION engine.state.epoch = checkpoint._iteration = cfg.SOLVER.COMPLETED_EPOCHS print(' --- LOADED MODEL FOR EPOCH: {comp_epochs} / {conf} ---\ \n --------- {mess} ---------'.format( comp_epochs=cfg.SOLVER.COMPLETED_EPOCHS, conf=cfg.SOLVER.EPOCHS, mess=mess)) @trainer.on(Events.EPOCH_COMPLETED) def save_models(engine): cfg.SOLVER.COMPLETED_EPOCHS = engine.state.epoch cfg.SOLVER.TRAINER_ITERATION = engine.state.iteration if cfg.SOLVER.COMPLETED_EPOCHS % cfg.MODEL.SAVE_INTERVAL == 0: # checkpoint only counts nr of checkpoint calls, not epochs checkpoint._iteration = cfg.SOLVER.COMPLETED_EPOCHS - 1 checkpoint(engine, models_dict) save_ignite_params(engine, engine_name='trainer', cfg=cfg) @trainer.on(Events.EPOCH_COMPLETED) def classification_validation(engine): cfg.RESULTS.LATENTS, cfg.RESULTS.CLF_ACC, cfg.RESULTS.MEAN_DISTANCE, \ cfg.RESULTS.SMOOTHNESS, cfg.RESULTS.CLUSTER_ACC = [], [], [], [], [] print('--- Evaluating model on validation set ---') evaluator.run(test_loader) clf_acc_str = calc_mean_non_empty('clf_acc', cfg.RESULTS.CLF_ACC) cluster_acc_str = calc_mean_non_empty('cluster_acc', cfg.RESULTS.CLUSTER_ACC) mean_dist_str = calc_mean_non_empty('mean_distance', cfg.RESULTS.MEAN_DISTANCE) smoothness_str = calc_mean_non_empty('smoothness', cfg.RESULTS.SMOOTHNESS) print('{clf}{cluster}{mean_dist}{smooth}'.format( clf=clf_acc_str, cluster=cluster_acc_str, mean_dist=mean_dist_str, smooth=smoothness_str)) @evaluator.on(Events.STARTED) def continue_validation(engine): # create dict keys at first epoch if 'EVAL_ITERATION' not in cfg.SOLVER.keys(): cfg.SOLVER.EVAL_ITERATION = 0 cfg.SOLVER.EVAL_EPOCH = 0 else: # after each iteration, dict gets updated # needed so that it continues counter at start of # every eval run (bc for every validation run # the evaluator is newly initialized) engine.state.iteration = cfg.SOLVER.EVAL_ITERATION # always put to 0 to run evaluation once # can't specify max epochs bc it would run # validation set several times after each other engine.state.epoch = 0 @evaluator.on(Events.EPOCH_COMPLETED) def save_eval_state(engine): # save iteration after each validation run # continue at this counter for next run # (as number otherwise resets), # save number of eval epochs for saving/loading params 
cfg.SOLVER.EVAL_ITERATION = engine.state.iteration cfg.SOLVER.EVAL_EPOCH += 1 if cfg.SOLVER.COMPLETED_EPOCHS % cfg.MODEL.SAVE_INTERVAL == 0: save_ignite_params(engine, engine_name='eval', cfg=cfg) return trainer, evaluator
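Overwriting `checkpoint._iteration` relies on a private counter to make filenames line up with epochs after resuming. Recent ignite versions expose `global_step_transform` for exactly this, so the saved filename tracks the trainer epoch directly; a minimal sketch of that alternative (directory and prefix are placeholders, `n_saved=None` keeps every file):

import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, global_step_from_engine

model = nn.Linear(4, 2)
trainer = Engine(lambda engine, batch: None)

checkpoint = ModelCheckpoint(
    "./checkpoints", "run", n_saved=None, require_empty=False,
    # filename suffix follows trainer.state.epoch instead of the handler's call counter
    global_step_transform=global_step_from_engine(trainer))
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint, {"model": model})
trainer.run([0] * 5, max_epochs=3)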
def run(args): train_loader, val_loader = get_data_loaders(args.dir, args.batch_size, args.num_workers) if args.seed is not None: torch.manual_seed(args.seed) device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') num_classes = CityscapesDataset.num_instance_classes() + 1 model = models.box2pix(num_classes=num_classes) model.init_from_googlenet() writer = create_summary_writer(model, train_loader, args.log_dir) if torch.cuda.device_count() > 1: print("Using %d GPU(s)" % torch.cuda.device_count()) model = nn.DataParallel(model) model = model.to(device) semantics_criterion = nn.CrossEntropyLoss(ignore_index=255) offsets_criterion = nn.MSELoss() box_criterion = BoxLoss(num_classes, gamma=2) multitask_criterion = MultiTaskLoss().to(device) box_coder = BoxCoder() optimizer = optim.Adam([{ 'params': model.parameters(), 'weight_decay': 5e-4 }, { 'params': multitask_criterion.parameters() }], lr=args.lr) if args.resume: if os.path.isfile(args.resume): print("Loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) multitask_criterion.load_state_dict(checkpoint['multitask']) print("Loaded checkpoint '{}' (Epoch {})".format( args.resume, checkpoint['epoch'])) else: print("No checkpoint found at '{}'".format(args.resume)) def _prepare_batch(batch, non_blocking=True): x, instance, boxes, labels = batch return (convert_tensor(x, device=device, non_blocking=non_blocking), convert_tensor(instance, device=device, non_blocking=non_blocking), convert_tensor(boxes, device=device, non_blocking=non_blocking), convert_tensor(labels, device=device, non_blocking=non_blocking)) def _update(engine, batch): model.train() optimizer.zero_grad() x, instance, boxes, labels = _prepare_batch(batch) boxes, labels = box_coder.encode(boxes, labels) loc_preds, conf_preds, semantics_pred, offsets_pred = model(x) semantics_loss = semantics_criterion(semantics_pred, instance) offsets_loss = offsets_criterion(offsets_pred, instance) box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds, labels) loss = multitask_criterion(semantics_loss, offsets_loss, box_loss, conf_loss) loss.backward() optimizer.step() return { 'loss': loss.item(), 'loss_semantics': semantics_loss.item(), 'loss_offsets': offsets_loss.item(), 'loss_ssdbox': box_loss.item(), 'loss_ssdclass': conf_loss.item() } trainer = Engine(_update) checkpoint_handler = ModelCheckpoint(args.output_dir, 'checkpoint', save_interval=1, n_saved=10, require_empty=False, create_dir=True, save_as_state_dict=False) timer = Timer(average=True) # attach running average metrics train_metrics = [ 'loss', 'loss_semantics', 'loss_offsets', 'loss_ssdbox', 'loss_ssdclass' ] for m in train_metrics: transform = partial(lambda x, metric: x[metric], metric=m) RunningAverage(output_transform=transform).attach(trainer, m) # attach progress bar pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=train_metrics) checkpoint = { 'model': model.state_dict(), 'epoch': trainer.state.epoch, 'optimizer': optimizer.state_dict(), 'multitask': multitask_criterion.state_dict() } trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'checkpoint': checkpoint}) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) def _inference(engine, batch): model.eval() with torch.no_grad(): 
x, instance, boxes, labels = _prepare_batch(batch) loc_preds, conf_preds, semantics, offsets_pred = model(x) boxes_preds, labels_preds, scores_preds = box_coder.decode( loc_preds, F.softmax(conf_preds, dim=1), score_thresh=0.01) semantics_loss = semantics_criterion(semantics, instance) offsets_loss = offsets_criterion(offsets_pred, instance) box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds, labels) semantics_pred = semantics.argmax(dim=1) instances = helper.assign_pix2box(semantics_pred, offsets_pred, boxes_preds, labels_preds) return { 'loss': (semantics_loss, offsets_loss, { 'box_loss': box_loss, 'conf_loss': conf_loss }), 'objects': (boxes_preds, labels_preds, scores_preds, boxes, labels), 'semantics': semantics_pred, 'instances': instances } train_evaluator = Engine(_inference) Loss(multitask_criterion, output_transform=lambda x: x['loss']).attach(train_evaluator, 'loss') MeanAveragePrecision(num_classes, output_transform=lambda x: x['objects']).attach( train_evaluator, 'objects') IntersectionOverUnion(num_classes, output_transform=lambda x: x['semantics']).attach( train_evaluator, 'semantics') evaluator = Engine(_inference) Loss(multitask_criterion, output_transform=lambda x: x['loss']).attach(evaluator, 'loss') MeanAveragePrecision(num_classes, output_transform=lambda x: x['objects']).attach( evaluator, 'objects') IntersectionOverUnion(num_classes, output_transform=lambda x: x['semantics']).attach( evaluator, 'semantics') @trainer.on(Events.STARTED) def initialize(engine): if args.resume: engine.state.epoch = args.start_epoch @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message( "Epoch [{}/{}] done. Time per batch: {:.3f}[s]".format( engine.state.epoch, engine.state.max_epochs, timer.value())) timer.reset() @trainer.on(Events.ITERATION_COMPLETED) def log_training_loss(engine): iteration = (engine.state.iteration - 1) % len(train_loader) + 1 if iteration % args.log_interval == 0: writer.add_scalar("training/loss", engine.state.output['loss'], engine.state.iteration) @trainer.on(Events.EPOCH_COMPLETED) def log_training_results(engine): train_evaluator.run(train_loader) metrics = train_evaluator.state.metrics loss = metrics['loss'] mean_ap = metrics['objects'] iou = metrics['semantics'] pbar.log_message( 'Training results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}' .format(loss, evaluator.state.epochs, evaluator.state.max_epochs, mean_ap, iou * 100.0)) writer.add_scalar("train-val/loss", loss, engine.state.epoch) writer.add_scalar("train-val/mAP", mean_ap, engine.state.epoch) writer.add_scalar("train-val/IoU", iou, engine.state.epoch) @trainer.on(Events.EPOCH_COMPLETED) def log_validation_results(engine): evaluator.run(val_loader) metrics = evaluator.state.metrics loss = metrics['loss'] mean_ap = metrics['objects'] iou = metrics['semantics'] pbar.log_message( 'Validation results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}' .format(loss, evaluator.state.epochs, evaluator.state.max_epochs, mean_ap, iou * 100.0)) writer.add_scalar("validation/loss", loss, engine.state.epoch) writer.add_scalar("validation/mAP", mean_ap, engine.state.epoch) writer.add_scalar("validation/IoU", iou, engine.state.epoch) @trainer.on(Events.EXCEPTION_RAISED) def handle_exception(engine, e): if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): engine.terminate() warnings.warn("KeyboardInterrupt caught. 
Exiting gracefully.") checkpoint_handler(engine, {'model_exception': model}) else: raise e @trainer.on(Events.COMPLETED) def save_final_model(engine): checkpoint_handler(engine, {'final': model}) trainer.run(train_loader, max_epochs=args.epochs) writer.close()
def train(): logger.info('*' * 64) logger.info('token:%s' % current_time) logger.info('*' * 64) parser = ArgumentParser() parser.add_argument( "--train_file", type=str, default="./my_test/data/student/part1.txt", help="Path or url of the dataset. If empty download from S3.") parser.add_argument("--dataset_cache", type=str, default='./cache/', help="Path or url of the dataset cache") parser.add_argument("--batch_size", type=int, default=2, help="Batch size for validation") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Accumulate gradients on several steps") parser.add_argument("--lr", type=float, default=6.25e-4, help="Learning rate") # parser.add_argument("--train_precent", type=float, default=0.7, help="Batch size for validation") parser.add_argument("--n_epochs", type=int, default=1, help="Number of training epochs") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)") # parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm") parser.add_argument("--log_step", type=int, default=1, help="Multiple-choice loss coefficient") parser.add_argument("--base_model", type=str, default="bert-base-uncased") parser.add_argument( "--on_memory", action='store_true', help="Whether to load train samples into memory or use disk") parser.add_argument( "--max_seq_length", default=128, type=int, help= "The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument( "--do_lower_case", action='store_true', help= "Whether to lower case the input text. True for uncased models, False for cased models." ) args = parser.parse_args() logger.info(args) device = torch.device(args.device) tokenizer = BertTokenizer.from_pretrained(args.base_model) train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length, corpus_lines=None, on_memory=args.on_memory) train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size) model = BertForPreTraining.from_pretrained(args.base_model) optimizer = optim.Adam(model.parameters(), lr=args.lr) steps = len(train_data_loader.dataset) // train_data_loader.batch_size steps = steps if steps > 0 else 1 logger.info('steps:%d' % steps) lr_warmup = get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1500, num_training_steps=steps * args.n_epochs) if torch.cuda.device_count() > 1: print("Let's use", torch.cuda.device_count(), "GPUs!") gpu_num = torch.cuda.device_count() gpu_list = [int(i) for i in range(gpu_num)] model = DataParallel(model, device_ids=gpu_list) multi_gpu = True if torch.cuda.is_available(): model.cuda() # model.to(device) # criterion.to(device) def update(engine, batch): model.train() # input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch """ input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, masked_lm_labels=None, next_sentence_label=None, """ # loss = model(input_ids=batch[0],input_mask=batch[1],segment_ids=batch[2],lm_label_ids=batch[3],is_next=batch[4]) loss = model(input_ids=batch[0], attention_mask=batch[1], position_ids=batch[2], masked_lm_labels=batch[3], next_sentence_label=batch[4]) if engine.state.iteration % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() lr_warmup.step() if multi_gpu: loss = loss.mean() loss.backward() return loss.cpu().item() 
trainer = Engine(update) # def inference(engine, batch): # model.eval() # with torch.no_grad(): # input_ids = batch[0].to(device) # attention_mask = batch[1].to(device) # labels = batch[2].to(device) # output = model(input_ids=input_ids, attention_mask=attention_mask) # # predict = output.permute(1, 2, 0) # trg = labels.permute(1, 0) # loss = criterion(predict.to(device), trg.to(device)) # return predict, trg # # evaluator = Engine(inference) # metrics = {"nll": Loss(criterion, output_transform=lambda x: (x[0], x[1])), # "accuracy": Accuracy(output_transform=lambda x: (x[0], x[1]))} # for name, metric in metrics.items(): # metric.attach(evaluator, name) # # @trainer.on(Events.EPOCH_COMPLETED) # def log_validation_results(trainer): # evaluator.run(valid_data_loader) # ms = evaluator.state.metrics # logger.info("Validation Results - Epoch: [{}/{}] Avg accuracy: {:.6f} Avg loss: {:.6f}" # .format(trainer.state.epoch, trainer.state.max_epochs, ms['accuracy'], ms['nll'])) # '''======================early stopping ==========================''' # def score_function(engine): # val_loss = engine.state.metrics['nll'] # return -val_loss # handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer) # evaluator.add_event_handler(Events.COMPLETED, handler) '''==================print information by iterator=========================''' @trainer.on(Events.ITERATION_COMPLETED) def log_training_loss(trainer): if trainer.state.iteration % args.log_step == 0: logger.info("Epoch[{}/{}] Step[{}/{}] Loss: {:.6f}".format( trainer.state.epoch, trainer.state.max_epochs, trainer.state.iteration % steps, steps, trainer.state.output * args.gradient_accumulation_steps)) '''================add check point========================''' checkpoint_handler = ModelCheckpoint(checkpoint_dir, 'checkpoint', n_saved=3) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {'BertClassificationModel': getattr(model, 'module', model) }) # "getattr" take care of distributed encapsulation '''==============run trainer=============================''' trainer.run(train_data_loader, max_epochs=args.n_epochs)
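For comparison, the more common gradient-accumulation ordering calls backward() on every iteration (with the loss scaled down) and only steps and zeroes the optimizer every `gradient_accumulation_steps` iterations. A toy, self-contained sketch of that variant, not the script's own update function:

import torch
import torch.nn as nn
from ignite.engine import Engine

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accumulation_steps = 4  # plays the role of --gradient_accumulation_steps

def update(engine, batch):
    model.train()
    x, y = batch
    loss = nn.functional.mse_loss(model(x), y)
    # scale so the accumulated gradient matches a single large-batch step
    (loss / accumulation_steps).backward()
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()

trainer = Engine(update)
data = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]
trainer.run(data, max_epochs=1)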
def train(self, config, **kwargs): """Trains a model on the given configurations. :param config: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE :param **kwargs: parameters to overwrite yaml config """ from pycocoevalcap.cider.cider import Cider config_parameters = train_util.parse_config_or_kwargs(config, **kwargs) config_parameters["seed"] = self.seed outputdir = os.path.join( config_parameters["outputpath"], config_parameters["model"], "{}_{}".format( datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%m'), uuid.uuid1().hex)) # Early init because of creating dir checkpoint_handler = ModelCheckpoint( outputdir, "run", n_saved=1, require_empty=False, create_dir=True, score_function=lambda engine: engine.state.metrics["score"], score_name="score") logger = train_util.genlogger(os.path.join(outputdir, "train.log")) # print passed config parameters logger.info("Storing files in: {}".format(outputdir)) train_util.pprint_dict(config_parameters, logger.info) zh = config_parameters["zh"] vocabulary = torch.load(config_parameters["vocab_file"]) train_loader, cv_loader, info = self._get_dataloaders( config_parameters, vocabulary) config_parameters["inputdim"] = info["inputdim"] cv_key2refs = info["cv_key2refs"] logger.info("<== Estimating Scaler ({}) ==>".format( info["scaler"].__class__.__name__)) logger.info("Feature: {} Input dimension: {} Vocab Size: {}".format( config_parameters["feature_file"], info["inputdim"], len(vocabulary))) model = self._get_model(config_parameters, len(vocabulary)) if "pretrained_word_embedding" in config_parameters: embeddings = np.load( config_parameters["pretrained_word_embedding"]) model.load_word_embeddings( embeddings, tune=config_parameters["tune_word_embedding"], projection=True) model = model.to(self.device) train_util.pprint_dict(model, logger.info, formatter="pretty") optimizer = getattr(torch.optim, config_parameters["optimizer"])( model.parameters(), **config_parameters["optimizer_args"]) train_util.pprint_dict(optimizer, logger.info, formatter="pretty") criterion = torch.nn.CrossEntropyLoss().to(self.device) crtrn_imprvd = train_util.criterion_improver( config_parameters['improvecriterion']) def _train_batch(engine, batch): model.train() with torch.enable_grad(): optimizer.zero_grad() output = self._forward(model, batch, "train") loss = criterion(output["packed_logits"], output["targets"]).to(self.device) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) optimizer.step() output["loss"] = loss.item() return output trainer = Engine(_train_batch) RunningAverage(output_transform=lambda x: x["loss"]).attach( trainer, "running_loss") pbar = ProgressBar(persist=False, ascii=True, ncols=100) pbar.attach(trainer, ["running_loss"]) key2pred = {} def _inference(engine, batch): model.eval() keys = batch[2] with torch.no_grad(): output = self._forward(model, batch, "validation") seqs = output["seqs"].cpu().numpy() for (idx, seq) in enumerate(seqs): if keys[idx] in key2pred: continue candidate = self._convert_idx2sentence(seq, vocabulary, zh) key2pred[keys[idx]] = [ candidate, ] return output metrics = { "loss": Loss(criterion, output_transform=lambda x: (x["packed_logits"], x["targets"])) } evaluator = Engine(_inference) def eval_cv(engine, key2pred, key2refs): scorer = Cider(zh=zh) score, scores = scorer.compute_score(key2refs, key2pred) engine.state.metrics["score"] = score key2pred.clear() evaluator.add_event_handler(Events.EPOCH_COMPLETED, eval_cv, key2pred, cv_key2refs) for name, 
metric in metrics.items(): metric.attach(evaluator, name) trainer.add_event_handler(Events.EPOCH_COMPLETED, train_util.log_results, evaluator, cv_loader, logger.info, ["loss", "score"]) evaluator.add_event_handler( Events.EPOCH_COMPLETED, train_util.save_model_on_improved, crtrn_imprvd, "score", { "model": model.state_dict(), "config": config_parameters, "scaler": info["scaler"] }, os.path.join(outputdir, "saved.pth")) scheduler = getattr(torch.optim.lr_scheduler, config_parameters["scheduler"])( optimizer, **config_parameters["scheduler_args"]) evaluator.add_event_handler(Events.EPOCH_COMPLETED, train_util.update_lr, scheduler, "score") evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, { "model": model, }) trainer.run(train_loader, max_epochs=config_parameters["epochs"]) return outputdir
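The checkpoint handler above keeps only the best run according to the CIDEr value that `eval_cv` stores in `engine.state.metrics["score"]`. A stripped-down sketch of that score-driven pattern, with the evaluator simply faking a score (directory, prefix and model are placeholders):

import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

model = nn.Linear(4, 2)
evaluator = Engine(lambda engine, batch: None)

@evaluator.on(Events.EPOCH_COMPLETED)
def fake_score(engine):
    engine.state.metrics["score"] = 0.5  # stand-in for the CIDEr score

best_ckpt = ModelCheckpoint(
    "./checkpoints", "run", n_saved=1, require_empty=False, create_dir=True,
    score_function=lambda engine: engine.state.metrics["score"],
    score_name="score")
# registered after fake_score, so the score is already set when the handler fires
evaluator.add_event_handler(Events.EPOCH_COMPLETED, best_ckpt, {"model": model})
evaluator.run([0], max_epochs=1)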
def fit(self, dataset, fold=0, train_split='train', valid_split='val'): """Fit the predictor model. Args: - dataset: temporal, static, label, time, treatment information - fold: Cross validation fold - train_split: training set splitting parameter - valid_split: validation set splitting parameter Returns: - self.predictor_model: trained predictor model """ train_x, train_y = self._data_preprocess(dataset, fold, train_split) valid_x, valid_y = self._data_preprocess(dataset, fold, valid_split) train_dataset = torch.utils.data.dataset.TensorDataset( self._make_tensor(train_x), self._make_tensor(train_y)) valid_dataset = torch.utils.data.dataset.TensorDataset( self._make_tensor(valid_x), self._make_tensor(valid_y)) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) val_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=self.batch_size, shuffle=True) if self.predictor_model is None: self.predictor_model = TransformerModule( self.task, dataset.problem, train_x.shape[-1], self.h_dim, train_y.shape[-1], self.n_head, self.n_layer).to(self.device) self.optimizer = torch.optim.Adam( self.predictor_model.parameters(), lr=self.learning_rate) self.predictor_model.train() # classification vs regression # static vs dynamic trainer = create_supervised_trainer(self.predictor_model, self.optimizer, self.predictor_model.loss_fn) evaluator = create_supervised_evaluator( self.predictor_model, metrics={'loss': Loss(self.predictor_model.loss_fn)}) # model check point checkpoint_handler = ModelCheckpoint(self.model_path, self.model_id, n_saved=1, create_dir=True, require_empty=False) trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), checkpoint_handler, {'model': self.predictor_model}) # early stopping def score_function(engine): val_loss = engine.state.metrics['loss'] return -val_loss early_stopping_handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer) evaluator.add_event_handler(Events.COMPLETED, early_stopping_handler) # evaluation loss @trainer.on(Events.EPOCH_COMPLETED) def log_validation_results(trainer): evaluator.run(val_loader) metrics = evaluator.state.metrics print("Validation Results - Epoch[{}] Avg loss: {:.2f}".format( trainer.state.epoch, metrics['loss'])) trainer.run(train_loader, max_epochs=self.epoch) return self.predictor_model
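`EarlyStopping` always treats a higher score as better, which is why `score_function` returns the negated validation loss above. A compact, self-contained illustration of the same wiring with a stalling loss sequence:

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)
losses = iter([1.0, 0.9, 0.95, 0.97, 0.99])  # validation loss stops improving

@evaluator.on(Events.COMPLETED)
def set_loss(engine):
    engine.state.metrics["loss"] = next(losses)

# higher score == better, so negate the loss
stopper = EarlyStopping(patience=2,
                        score_function=lambda e: -e.state.metrics["loss"],
                        trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, stopper)

@trainer.on(Events.EPOCH_COMPLETED)
def validate(engine):
    evaluator.run([0])

trainer.run([0], max_epochs=5)  # terminated early after two non-improving evaluations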
def run(output_path, config): device = "cuda" local_rank = config['local_rank'] distributed = backend is not None if distributed: torch.cuda.set_device(local_rank) device = "cuda" rank = dist.get_rank() if distributed else 0 # Rescale batch_size and num_workers ngpus_per_node = torch.cuda.device_count() ngpus = dist.get_world_size() if distributed else 1 batch_size = config['batch_size'] // ngpus num_workers = int( (config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node) train_labelled_loader, test_loader = \ get_train_test_loaders(path=config['data_path'], batch_size=batch_size, distributed=distributed, num_workers=num_workers) model = get_model(config['model']) model = model.to(device) if distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[ local_rank, ], output_device=local_rank) optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=config['momentum'], weight_decay=config['weight_decay'], nesterov=True) criterion = nn.CrossEntropyLoss().to(device) le = len(train_labelled_loader) milestones_values = [(0, 0.0), (le * config['num_warmup_epochs'], config['learning_rate']), (le * config['num_epochs'], 0.0)] lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values) def _prepare_batch(batch, device, non_blocking): x, y = batch return (convert_tensor(x, device=device, non_blocking=non_blocking), convert_tensor(y, device=device, non_blocking=non_blocking)) def process_function(engine, labelled_batch): x, y = _prepare_batch(labelled_batch, device=device, non_blocking=True) model.train() # Supervised part y_pred = model(x) loss = criterion(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step() return { 'batch loss': loss.item(), } trainer = Engine(process_function) if not hasattr(lr_scheduler, "step"): trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler) else: trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step()) metric_names = [ 'batch loss', ] def output_transform(x, name): return x[name] for n in metric_names: # We compute running average values on the output (batch loss) across all devices RunningAverage(output_transform=partial(output_transform, name=n), epoch_bound=False, device=device).attach(trainer, n) if rank == 0: checkpoint_handler = ModelCheckpoint(dirname=output_path, filename_prefix="checkpoint") trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), checkpoint_handler, { 'model': model, 'optimizer': optimizer }) ProgressBar(persist=True, bar_format="").attach(trainer, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED) if config['display_iters']: ProgressBar(persist=False, bar_format="").attach(trainer, metric_names=metric_names) tb_logger = TensorboardLogger(log_dir=output_path) tb_logger.attach(trainer, log_handler=tbOutputHandler( tag="train", metric_names=metric_names), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=tbOptimizerParamsHandler(optimizer, param_name="lr"), event_name=Events.ITERATION_STARTED) metrics = { "accuracy": Accuracy(device=device if distributed else None), "loss": Loss(criterion, device=device if distributed else None) } evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True) train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True) def run_validation(engine): torch.cuda.synchronize() train_evaluator.run(train_labelled_loader) evaluator.run(test_loader) 
trainer.add_event_handler(Events.EPOCH_STARTED(every=3), run_validation) trainer.add_event_handler(Events.COMPLETED, run_validation) if rank == 0: if config['display_iters']: ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator) ProgressBar(persist=False, desc="Test evaluation").attach(evaluator) tb_logger.attach(train_evaluator, log_handler=tbOutputHandler(tag="train", metric_names=list( metrics.keys()), another_engine=trainer), event_name=Events.COMPLETED) tb_logger.attach(evaluator, log_handler=tbOutputHandler(tag="test", metric_names=list( metrics.keys()), another_engine=trainer), event_name=Events.COMPLETED) # Store the best model def default_score_fn(engine): score = engine.state.metrics['accuracy'] return score score_function = default_score_fn if not hasattr( config, "score_function") else config.score_function best_model_handler = ModelCheckpoint( dirname=output_path, filename_prefix="best", n_saved=3, global_step_transform=global_step_from_engine(trainer), score_name="val_accuracy", score_function=score_function) evaluator.add_event_handler(Events.COMPLETED, best_model_handler, { 'model': model, }) trainer.run(train_labelled_loader, max_epochs=config['num_epochs']) if rank == 0: tb_logger.close()
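The warm-up schedule above is just three (event count, lr) anchor points; `PiecewiseLinear` interpolates linearly between them each time it is called. A small standalone sketch with arbitrary values, assuming the scheduler is imported from `ignite.contrib.handlers` as in this vintage of ignite:

import torch
import torch.nn as nn
from ignite.contrib.handlers import PiecewiseLinear
from ignite.engine import Engine, Events

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0)

iters_per_epoch, warmup_epochs, total_epochs, base_lr = 10, 2, 5, 0.1
milestones_values = [(0, 0.0),
                     (iters_per_epoch * warmup_epochs, base_lr),
                     (iters_per_epoch * total_epochs, 0.0)]
scheduler = PiecewiseLinear(optimizer, param_name="lr",
                            milestones_values=milestones_values)

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@trainer.on(Events.EPOCH_COMPLETED)
def log_lr(engine):
    print(engine.state.epoch, optimizer.param_groups[0]["lr"])

trainer.run([0] * iters_per_epoch, max_epochs=total_epochs)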
def _loss_fn(i, j): return loss(i[0], j) # Create trainer device = torch.device("cuda:0") trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'net': net, 'opt': opt }) dice_metric = MeanDice(add_sigmoid=True, output_transform=lambda output: (output[0][0], output[2])) dice_metric.attach(trainer, "Training Dice") logging.basicConfig(stream=sys.stdout, level=logging.INFO) stats_logger = StatsHandler()
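Because of the custom `output_transform`, `engine.state.output` for this trainer is the list `[y_pred, loss_value, y]` rather than a bare loss, so every metric attached to it has to index into that list, as `MeanDice` does above. A generic sketch of the same indexing pattern with a toy model and a running loss average:

import torch
import torch.nn as nn
from ignite.engine import create_supervised_trainer
from ignite.metrics import RunningAverage

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# engine.state.output becomes [y_pred, loss_value, y] instead of the default loss scalar
trainer = create_supervised_trainer(
    model, optimizer, loss_fn,
    output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])

# index 1 of the output list is the scalar loss
RunningAverage(output_transform=lambda output: output[1]).attach(trainer, "avg_loss")

data = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(5)]
trainer.run(data, max_epochs=1)
print(trainer.state.metrics["avg_loss"])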
print("TEST EVAL") evaluator.run(test_ld) test_wra_vle = round(evaluator.state.metrics["WRA"], 3) report = f"{RUN_NAME};{test_wra_vle}\n" with EVALUATION_RESULTS_FILE_PATH.open(mode='a') as f: f.writelines(report) print(f"TRAINING IS DONE FOR {RUN_NAME} RUN.") pbar = ProgressBar() checkpointer = ModelCheckpoint( CHECKPOINTS_RUN_DIR_PATH, filename_prefix=RUN_NAME.lower(), n_saved=None, score_function=lambda engine: round(engine.state.metrics['WRA'], 3), score_name='WRA', atomic=True, require_empty=True, create_dir=True, archived=False, global_step_transform=global_step_from_engine(trainer)) nan_handler = TerminateOnNan() coslr = CosineAnnealingScheduler(opt, "lr", start_value=LR, end_value=LR / 4, cycle_size=TOTAL_UPDATE_STEPS // 1) evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'_': mude})