Example #1
    def init_function(h_model):
        h_criterion = torch.nn.CrossEntropyLoss()
        h_evaluator = SupervisedEvaluator(model=h_model, criterion=h_criterion, device=device)
        h_train_evaluator = SupervisedEvaluator(model=h_model, criterion=h_criterion, device=device)
        h_optimizer = torch.optim.Adam(params=h_model.parameters(), lr=1e-3)
        h_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(h_optimizer, 'max', verbose=True, patience=5,
                                                                    factor=0.5)
        h_trainer = SupervisedTrainer(model=h_model, optimizer=h_optimizer, criterion=h_criterion, device=device)

        # Tqdm logger
        h_pbar = ProgressBar(persist=False, bar_format=config.IGNITE_BAR_FORMAT)
        h_pbar.attach(h_trainer.engine, metric_names='all')
        h_tqdm_logger = TqdmLogger(pbar=h_pbar)
        # noinspection PyTypeChecker
        h_tqdm_logger.attach_output_handler(
            h_evaluator.engine,
            event_name=Events.COMPLETED,
            tag="validation",
            global_step_transform=global_step_from_engine(h_trainer.engine),
        )
        # noinspection PyTypeChecker
        h_tqdm_logger.attach_output_handler(
            h_train_evaluator.engine,
            event_name=Events.COMPLETED,
            tag="train",
            global_step_transform=global_step_from_engine(h_trainer.engine),
        )

        # Learning rate scheduling
        # The PyTorch Ignite LRScheduler class does not work with ReduceLROnPlateau
        h_evaluator.engine.add_event_handler(Events.COMPLETED,
                                             lambda engine: h_lr_scheduler.step(engine.state.metrics['accuracy']))

        # Model checkpoints
        h_handler = ModelCheckpoint(config.MODELS_DIR, run.replace('/', '-'), n_saved=1, create_dir=True,
                                    require_empty=False, score_name='acc',
                                    score_function=lambda engine: engine.state.metrics['accuracy'],
                                    global_step_transform=global_step_from_engine(h_trainer.engine))
        h_evaluator.engine.add_event_handler(Events.EPOCH_COMPLETED, h_handler, {'m': h_model})

        # Early stopping
        h_es_handler = EarlyStopping(patience=15,
                                     min_delta=0.0001,
                                     score_function=lambda engine: engine.state.metrics['accuracy'],
                                     trainer=h_trainer.engine, cumulative_delta=True)
        h_es_handler.logger.setLevel(logging.DEBUG)
        h_evaluator.engine.add_event_handler(Events.COMPLETED, h_es_handler)

        return h_trainer, h_train_evaluator, h_evaluator
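The comment above about ReduceLROnPlateau refers to the fact that Ignite's LRScheduler wrapper only accepts plain torch schedulers whose step() takes no metric value, which ReduceLROnPlateau does not satisfy, hence the manual step from the evaluator's COMPLETED event. For contrast, a minimal sketch of the wrapper route with a supported scheduler (StepLR and the dummy model/trainer are assumptions for illustration, not part of the example above):

import torch
from ignite.engine import Engine, Events
from ignite.handlers import LRScheduler  # in older Ignite versions this lives in ignite.contrib.handlers

# Dummy model and trainer only so the sketch is self-contained.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trainer = Engine(lambda engine, batch: None)

# StepLR steps without a metric, so it can be wrapped and attached directly.
step_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
trainer.add_event_handler(Events.EPOCH_COMPLETED, LRScheduler(step_scheduler))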
Example #2
def test_global_step_from_engine():

    iteration = 12
    epoch = 23

    trainer = Engine(lambda e, b: None)
    trainer.state.iteration = iteration
    trainer.state.epoch = epoch

    gst = global_step_from_engine(trainer)
    assert gst(MagicMock(), Events.EPOCH_COMPLETED) == epoch

    gst = global_step_from_engine(trainer,
                                  custom_event_name=Events.ITERATION_COMPLETED)
    assert gst(MagicMock(), Events.EPOCH_COMPLETED) == iteration
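The test shows the two behaviors the snippets on this page rely on: the returned callable ignores the engine it is called with and reads the step from the engine it was built from, and custom_event_name picks which attribute (epoch vs. iteration) is returned. A minimal, self-contained sketch of the typical wiring (all names here are illustrative):

from ignite.engine import Engine, Events
from ignite.handlers import global_step_from_engine

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)
gst = global_step_from_engine(trainer)

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run([0, 1, 2])

@evaluator.on(Events.COMPLETED)
def report(engine):
    # Reports the trainer's epoch, not the evaluator's, as the global step.
    print("validation finished at trainer epoch", gst(engine, Events.EPOCH_COMPLETED))

trainer.run([0, 1, 2], max_epochs=2)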
Example #3
    def update_hparams(engine, finished=False):
        hparam_dict['total_iterations'] = global_step_from_engine(engine)(engine, Events.ITERATION_COMPLETED)
        hparam_dict['total_epochs'] = global_step_from_engine(engine)(engine, Events.EPOCH_COMPLETED)
        hparam_dict['timeout'] = not finished

        if hparam_dict['train_set_size'] is None:
            hparam_dict['train_set_size'] = hparam_dict['training_set_size']

        try:
            shutil.copyfile(os.path.join(output_dir, 'hparams.pickle'),
                            os.path.join(output_dir, 'hparams.pickle.backup'))
            with open(os.path.join(output_dir, 'hparams.pickle'), 'wb') as f:
                pickle.dump(hparam_dict, f)
        except AttributeError as e:
            print('Could not pickle one of the total vars.', e)
            os.replace(os.path.join(output_dir, 'hparams.pickle.backup'), os.path.join(output_dir, 'hparams.pickle'))
Example #4
        def dev_fn(engine, batch):
            model.eval()
            optimizer.zero_grad()
            with torch.no_grad():
                batch = tuple(t.to(self.device) for t in batch)
                labels = batch[3]

                inputs = {
                    "input_ids": batch[0],
                    "token_type_ids": batch[1],
                    "attention_mask": batch[2],
                    "label_ids": labels
                }

                loss, sequence_tags = model(**inputs)

            score = f1_score(labels.detach().cpu().numpy(),
                             y_pred=sequence_tags.detach().cpu().numpy(),
                             average="macro")

            if self.n_gpu > 1:
                loss = loss.mean()

            # TensorBoard logging: resolve the current global step from the last event
            # fired on this engine (here it resolves to the current iteration).
            global_step = global_step_from_engine(engine)(
                engine, engine.last_event_name)
            # tb_writer.add_scalar('learning_rate', scheduler.get_lr()[0], global_step)
            tb_writer.add_scalar('dev_loss', loss.item(), global_step)
            tb_writer.add_scalar('dev_score', score, global_step)

            return loss.item(), score
Example #5
    def _build_objects(acc_list):

        model = DummyModel().to(device)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.5)

        def update_fn(engine, batch):
            x = torch.rand((4, 1)).to(device)
            optim.zero_grad()
            y = model(x)
            loss = y.pow(2.0).sum()
            loss.backward()
            if idist.has_xla_support:
                import torch_xla.core.xla_model as xm

                xm.optimizer_step(optim, barrier=True)
            else:
                optim.step()
            lr_scheduler.step()

        trainer = Engine(update_fn)

        evaluator = Engine(lambda e, b: None)
        acc_iter = iter(acc_list)

        @evaluator.on(Events.EPOCH_COMPLETED)
        def setup_result():
            evaluator.state.metrics["accuracy"] = next(acc_iter)

        @trainer.on(Events.EPOCH_COMPLETED)
        def run_eval():
            evaluator.run([0, 1, 2])

        def score_function(engine):
            return engine.state.metrics["accuracy"]

        save_handler = DiskSaver(dirname, create_dir=True, require_empty=False)
        early_stop = EarlyStopping(score_function=score_function,
                                   patience=2,
                                   trainer=trainer)
        evaluator.add_event_handler(Events.COMPLETED, early_stop)

        checkpointer = Checkpoint(
            {
                "trainer": trainer,
                "model": model,
                "optim": optim,
                "lr_scheduler": lr_scheduler,
                "early_stop": early_stop,
            },
            save_handler,
            include_self=True,
            global_step_transform=global_step_from_engine(trainer),
        )
        evaluator.add_event_handler(Events.COMPLETED, checkpointer)

        return trainer, evaluator, model, optim, lr_scheduler, early_stop, checkpointer
Example #6
    def setup_checkpoint_saver(self, to_save):
        if self.hparams.checkpoint_params is not None:
            from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

            handler = Checkpoint(
                to_save,
                DiskSaver(self.hparams.checkpoint_params["save_dir"], require_empty=False),
                n_saved=self.hparams.checkpoint_params["n_saved"],
                filename_prefix=self.hparams.checkpoint_params["prefix_name"],
                score_function=self.score_function,
                score_name="score",
                global_step_transform=global_step_from_engine(self.trainer))

            self.evaluator.add_event_handler(Events.COMPLETED, handler)
Example #7
        def train_fn(engine, batch):
            model.train()
            optimizer.zero_grad()
            batch = tuple(t.to(self.device) for t in batch)
            labels = batch[3]
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "labels": labels,
                "is_nested": self.is_nested
            }

            loss, sequence_tags = model(**inputs)
            if not self.is_nested:
                score = (
                    sequence_tags == labels).float().detach().cpu().numpy()

                condition_1 = (labels != self.label_list.index("O")
                               ).detach().cpu().numpy()
                condition_2 = (labels != self.label_list.index("<PAD>")
                               ).detach().cpu().numpy()
                pattern = np.logical_and(condition_1, condition_2)
                score = score[pattern].mean()
            else:
                '''
                y_pred = sequence_tags.detach().cpu().numpy()
                labels_np = labels.detach().cpu().numpy()
                score = ((y_pred > self.multi_label_threshold) == (labels_np > 0)).mean()
                '''
                score = ((sequence_tags > self.multi_label_threshold) == (
                    labels > 0)).float().detach().cpu().numpy()
                condition_1 = (labels != self.label_list.index("O")
                               ).detach().cpu().numpy()
                condition_2 = (labels != self.label_list.index("<PAD>")
                               ).detach().cpu().numpy()
                pattern = np.logical_and(condition_1, condition_2)
                score = score[pattern].mean()

            if self.n_gpu > 1:
                loss = loss.mean()

            ## tensorboard
            global_step = global_step_from_engine(engine)(
                engine, engine.last_event_name)
            tb_writer.add_scalar('learning_rate',
                                 scheduler.get_lr()[0], global_step)
            tb_writer.add_scalar('train_loss', loss.item(), global_step)
            tb_writer.add_scalar('train_score', score.item(), global_step)

            loss.backward()
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            return loss.item(), score.item()
Example #8
def save_checkpoint(trainer, evaluator, to_save, score_function, save_dir,
                    n_saved, prefix_name):
    if save_dir is not None and score_function is not None:
        handler = Checkpoint(
            to_save,
            DiskSaver(save_dir, require_empty=False),
            n_saved=n_saved,
            filename_prefix=prefix_name,
            score_function=score_function,
            score_name="score",
            global_step_transform=global_step_from_engine(trainer))

        evaluator.add_event_handler(Events.COMPLETED, handler)
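A hedged call-site sketch for the helper above; the engines, model, and values are placeholders, and the score function assumes an 'accuracy' metric has been attached to the evaluator:

save_checkpoint(trainer, evaluator,
                to_save={"model": model},
                score_function=lambda engine: engine.state.metrics["accuracy"],
                save_dir="./checkpoints",
                n_saved=2,
                prefix_name="best")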
Example #9
    def store(self, engine: ignite.engine.Engine):
        """Send the evaluation engine's computed metrics to the main logger."""
        metrics = {}
        if not hasattr(engine.state, 'metrics') or len(engine.state.metrics) == 0:
            return
        kwargs = dict(current_step=global_step_from_engine(self.train_engine)(
            self.train_engine,
            self.train_engine.last_event_name)) if self.train_engine else {}
        for key, val in engine.state.metrics.items():
            metrics[key] = val
        self.liveplot.update(metrics, **kwargs)
        self.liveplot.send()
Example #10
        # Validset evaluation
        valid_state = valid_evaluator.run(validset_testset[0])
        valid_metrics = {f'valid_{n}': float(v) for n, v in valid_state.metrics.items()}
        for n, v in valid_metrics.items():
            mlflow.log_metric(n, v, step=engine.state.epoch)

        if not is_nni_run_standalone():
            # TODO: make sure `valid_state.metrics` is ordered so that reported default metric to NNI is always the same
            nni.report_intermediate_result({'default': list(valid_state.metrics.values())[0], **train_metrics, **valid_metrics})

    if backend_conf.rank == 0:
        event = Events.ITERATION_COMPLETED(every=hp['log_progress_every_iters'] if hp['log_progress_every_iters'] else None)
        ProgressBar(persist=False, desc='Train evaluation').attach(train_evaluator, event_name=event)
        ProgressBar(persist=False, desc='Test evaluation').attach(valid_evaluator)

        log_handler = OutputHandler(tag='train', metric_names=list(metrics.keys()), global_step_transform=global_step_from_engine(trainer))
        tb_logger.attach(train_evaluator, log_handler=log_handler, event_name=Events.COMPLETED)

        log_handler = OutputHandler(tag='test', metric_names=list(metrics.keys()), global_step_transform=global_step_from_engine(trainer))
        tb_logger.attach(valid_evaluator, log_handler=log_handler, event_name=Events.COMPLETED)

        # Store the best model by validation accuracy:
        common.save_best_model_by_val_score(str(output_path), valid_evaluator, model=model, metric_name='accuracy', n_saved=3, trainer=trainer, tag='val')

        if hp['log_grads_every_iters'] is not None and hp['log_grads_every_iters'] > 0:
            tb_logger.attach(trainer, log_handler=GradsHistHandler(model, tag=model.__class__.__name__), event_name=Events.ITERATION_COMPLETED(every=hp['log_grads_every_iters']))

    if hp['crash_iteration'] is not None and hp['crash_iteration'] >= 0:
        @trainer.on(Events.ITERATION_STARTED(once=hp['crash_iteration']))
        def _(engine):
            raise Exception('STOP at iteration: {}'.format(engine.state.iteration))
Example #11
def run(model,
        dataset,
        val_dataset,
        device=None,
        optimizer=None,
        criterion=None,
        epochs=10,
        batch_size=1,
        log_interval=10,
        model_name='unknown',
        log_dir=None,
        save=True,
        model_name_prefix='',
        path=None):
    start_time = time.time()
    if path is None:
        path = join(MODELS_PATH, model_name)

    # writer = None
    # if log_dir is not None and log_dir != '':
    #     current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    #     writer = SummaryWriter(log_dir=log_dir+current_time)

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f'Model device: {device}')

    model.to(device)

    train_loader = DataLoader(dataset=dataset,
                              batch_size=batch_size,
                              collate_fn=batcher(device))

    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            collate_fn=batcher(device))

    if optimizer is None:
        # create the optimizer
        optimizer = torch.optim.Adagrad(model.parameters(),
                                        lr=0.01,
                                        weight_decay=1e-4)
    if criterion is None:
        criterion = nn.CrossEntropyLoss(
            weight=torch.tensor([1.0, 7.0], device=device))

    def update_model(engine, batch):
        g = batch.graph

        optimizer.zero_grad()
        outputs = model(g)
        loss = criterion(outputs, batch.labels)

        # if torch.isnan(loss).any():
        #     print(f'Loss is NAN at step: {engine.state.iteration}')
        #     return 0
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(update_model)

    training_history = {'accuracy': [], 'loss': []}
    whole_training_history = {'loss': []}
    validation_history = {'accuracy': [], 'loss': []}
    val_metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}
    train_evaluator = create_my_supervised_evaluator(model,
                                                     metrics=val_metrics)
    val_evaluator = create_my_supervised_evaluator(model, metrics=val_metrics)

    handler = EarlyStopping(patience=50,
                            score_function=score_function,
                            trainer=trainer)
    val_evaluator.add_event_handler(Events.COMPLETED, handler)

    to_save = {'model': model}
    if save:
        handler = Checkpoint(
            to_save,
            DiskSaver(path, create_dir=True, require_empty=False),
            n_saved=1,
            score_function=score_function,
            score_name="loss",
            filename_prefix=model_name_prefix,
            global_step_transform=global_step_from_engine(trainer))

        val_evaluator.add_event_handler(Events.COMPLETED, handler)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        whole_training_history['loss'].append(engine.state.output)
        # if writer is not None:
        #     writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_loss = metrics["loss"]
        print(
            f'Epoch {engine.state.epoch:05d} | '
            f'Train: loss {avg_loss:7.4f}, acc {avg_accuracy:7.4f} | ',
            end='')
        training_history['accuracy'].append(avg_accuracy)
        training_history['loss'].append(avg_loss)
        # if writer is not None:
        #     writer.add_scalar("training/avg_loss", avg_loss, engine.state.epoch)
        #     writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_loss = metrics["loss"]
        print(f'Val: loss {avg_loss:7.4f}, acc {avg_accuracy:7.4f} |')
        # print(f'Val results   | '
        #       f'Epoch {engine.state.epoch:05d} | '
        #       f'Avg loss {avg_loss:.4f} | '
        #       f'Avg accuracy {avg_accuracy:.4f} |')
        validation_history['accuracy'].append(avg_accuracy)
        validation_history['loss'].append(avg_loss)
        # if writer is not None:
        #     writer.add_scalar("valdation/avg_loss", avg_loss, engine.state.epoch)
        #     writer.add_scalar("valdation/avg_accuracy", avg_accuracy, engine.state.epoch)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    # if writer is not None:
    #     writer.close()

    print(f'Model trained in {(time.time() - start_time):.1f}s')
    return training_history, validation_history, whole_training_history
Example #12
def adv_prune_train_loop(model, params, ds, dset, min_y, base_data, model_id, prune_type, device, batch_size, tpa, max_epochs=5):
    #assert prune_type in ['global_unstructured', 'structured']
    total_prune_amount = tpa
    remove_amount = tpa
    ds_train, ds_valid = ds
    train_set, valid_set = dset
    min_y_train, min_y_val = min_y
    original_model = copy.deepcopy(model)
    original_model.eval()
    model_id = f'{model_id}_{prune_type}_pruning_{tpa}_l1'
    valid_freq = 200 * 500 // batch_size // 3

    conv_layers = [model.conv1]
    for sequential in [model.layer1, model.layer2, model.layer3, model.layer4]:
        for bottleneck in sequential:
            conv_layers.extend([bottleneck.conv1, bottleneck.conv2, bottleneck.conv3])
    conv_layers = conv_layers[:22]
    def prune_model(model):
        print(f'pruned model by {total_prune_amount}')
        if prune_type == 'global_unstructured':
            parameters_to_prune = [(layer, 'weight') for layer in conv_layers]
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=total_prune_amount,
            )
        else:
            for layer in conv_layers:
                # Assumption: per-layer L1-norm structured pruning
                # (matching the 'structured' / 'l1' naming used for model_id).
                prune.ln_structured(layer, name='weight',
                                    amount=total_prune_amount, n=1, dim=0)

    prune_model(model)

    def valid_eval(model, dataset, dataloader, device, label):
        right = 0
        total = 0
        model.eval()
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader), total=len(dataset) / dataloader.batch_size):
                data, y = data
                data = data.to(device)
                y = y.to(device) - label
                ans = model.forward(data)
                right += torch.sum(torch.eq(torch.argmax(ans, dim=1), y))
                total += y.shape[0]
        return right/total
    valid_acc = valid_eval(model, valid_set, ds_valid, device, min_y_val)
    print('initial accuracy:', valid_acc.item())
    
    with create_summary_writer(model, ds_train, base_data, model_id, device=device) as writer:
        lr = params['lr']
        mom = params['momentum']
        wd = params['l2_wd']
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=mom, weight_decay=wd)
        sched = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
        funcs = {'accuracy': Accuracy(), 'loss': Loss(F.cross_entropy)}
        loss = funcs['loss']._loss_fn

        acc_metric = Accuracy(device=device)
        loss_metric = Loss(F.cross_entropy, device=device)

        acc_val_metric = Accuracy(device=device)
        loss_val_metric = Loss(F.cross_entropy, device=device)

#         attack = GradientSignAttack(original_model, loss_fn=loss, eps=0.2)

        def train_step(engine, batch):
            model.train()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_train
#             with ctx_noparamgrad_and_eval(model):
#                 x_adv = attack.perturb(x, y)
#             optimizer.zero_grad()
#             x = torch.cat((x, x_adv))
#             y = torch.cat((y, y))
            ans = model.forward(x)
            l = loss(ans, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            with torch.no_grad():
                for layer in conv_layers:
                    layer.weight *= layer.weight_mask
            return l.item()

        trainer = Engine(train_step)

        def train_eval_step(engine, batch):
            model.eval()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_train
#             x_adv = attack.perturb(x, y)
#             x = torch.cat((x, x_adv))
#             y = torch.cat((y, y))
            with torch.no_grad():
                ans = model.forward(x)
            return ans, y

        train_evaluator = Engine(train_eval_step)
        acc_metric.attach(train_evaluator, "accuracy")
        loss_metric.attach(train_evaluator, 'loss')

        def validation_step(engine, batch):
            model.eval()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_val
#             x_adv = attack.perturb(x, y)
#             x = torch.cat((x, x_adv))
#             y = torch.cat((y, y))
            with torch.no_grad():
                ans = model.forward(x)
            return ans, y

        valid_evaluator = Engine(validation_step)
        acc_val_metric.attach(valid_evaluator, "accuracy")
        loss_val_metric.attach(valid_evaluator, 'loss')

        @trainer.on(Events.ITERATION_COMPLETED(every=valid_freq))
        def log_validation_results(engine):
            valid_evaluator.run(ds_valid)
            metrics = valid_evaluator.state.metrics
            valid_avg_accuracy = metrics['accuracy']
            avg_nll = metrics['loss']
            print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
                  .format(engine.state.epoch, valid_avg_accuracy, avg_nll))
            writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
            writer.add_scalar("validation/avg_accuracy", valid_avg_accuracy, engine.state.epoch)
            writer.add_scalar("validation/avg_error", 1. - valid_avg_accuracy, engine.state.epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_scheduler(engine):
            metrics = valid_evaluator.state.metrics
            avg_nll = metrics['accuracy']
            sched.step(avg_nll)

        @trainer.on(Events.ITERATION_COMPLETED(every=100))
        def log_training_loss(engine):
            batch = engine.state.batch
            ds = DataLoader(TensorDataset(*batch), batch_size=batch_size)
            train_evaluator.run(ds)
            metrics = train_evaluator.state.metrics
            accuracy = metrics['accuracy']
            nll = metrics['loss']
            iter = (engine.state.iteration - 1) % len(ds_train) + 1
            if (iter % 50) == 0:
                print("Epoch[{}] Iter[{}/{}] Accuracy: {:.2f} Loss: {:.2f}"
                      .format(engine.state.epoch, iter, len(ds_train), accuracy, nll))
            writer.add_scalar("batchtraining/detloss", nll, engine.state.epoch)
            writer.add_scalar("batchtraining/accuracy", accuracy, engine.state.iteration)
            writer.add_scalar("batchtraining/error", 1. - accuracy, engine.state.iteration)
            writer.add_scalar("batchtraining/loss", engine.state.output, engine.state.iteration)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_lr(engine):
            writer.add_scalar("lr", optimizer.param_groups[0]['lr'], engine.state.epoch)

        @trainer.on(Events.ITERATION_COMPLETED(every=valid_freq))
        def validation_value(engine):
            metrics = valid_evaluator.state.metrics
            valid_avg_accuracy = metrics['accuracy']
            return valid_avg_accuracy

        to_save = {'model': model}
        handler = Checkpoint(to_save, DiskSaver(os.path.join(base_data, model_id),
                                                create_dir=True),
                             score_function=validation_value, score_name="val_acc",
                             global_step_transform=global_step_from_engine(trainer),
                             n_saved=None)

        # kick everything off
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=valid_freq), handler)
        trainer.run(ds_train, max_epochs=max_epochs)
Example #13
    def __init__(self, trainer_args={}):

        self.log_prefix = trainer_args.get("model_name")

        self.save_path = os.path.join(trainer_args.get("fold_dir"))

        create_dirs(self.save_path)

        self.device = trainer_args.get("device")
        self.epochs = trainer_args.get("epochs")
        self.early_stopping = trainer_args.get("early_stopping", None)

        self.exclude_anomalies = trainer_args.get("exclude_anomalies")
        self.include_metrics = trainer_args.get("include_metrics")

        self.model_class = trainer_args.get("model_class")
        self.model_args = trainer_args.get("model_args", {})

        self.optimizer_class = trainer_args.get("optimizer_class")
        self.optimizer_args = trainer_args.get("optimizer_args", {})

        self.training_data_stats = trainer_args.get("training_data_stats")

        self.loss = trainer_args.get("loss_func")
        self.node_classes = trainer_args.get("node_classes")
        self.graph_classes = trainer_args.get("graph_classes")

        self.score_function = score_function

        self.resume_from_checkpoint = trainer_args.get(
            "resume_from_checkpoint", {})

        # create model, optimizer
        self.model = self.model_class({
            **self.model_args, "pred_collector_function":
            self._pred_collector_function
        }).to(self.device).double()

        self.optimizer = self.optimizer_class(self.model.parameters(),
                                              **self.optimizer_args)
        self.trainer = create_supervised_trainer(
            self.model,
            self.optimizer,
            loss_fn=self.loss,
            device=self.device,
            non_blocking=True,
            output_transform=lambda x, y, y_pred, loss:
            (y_pred, y)  # so that loss-metric can work with transformed output
        )

        ##################### log some values ################################
        if not len(
                trainer_args.get("resume_from_checkpoint", {})
        ):  # only print all this on initial training start, not on resume
            self.custom_print("Device:", self.device)
            self.custom_print("Max. epochs:", self.epochs)
            self.custom_print("Early stopping:", self.early_stopping)
            self.custom_print("Excluded Anomalies:", self.exclude_anomalies)
            self.custom_print("Included Metrics:", self.include_metrics)
            self.custom_print("Loss class:", self.loss)
            self.custom_print("Node anomaly classes:", self.node_classes)
            self.custom_print("Graph anomaly classes:", self.graph_classes)
            self.custom_print("Model class:", self.model_class)
            self.custom_print("Model args:", json.dumps(self.model_args))
            self.custom_print("Model - All Parameters:", self.model.all_params)
            self.custom_print("Model - Trainable Parameters:",
                              self.model.all_trainable_params)
            self.custom_print("Optimizer args:",
                              json.dumps(self.optimizer_args))
            self.custom_print("Training Data Statistics:",
                              self.training_data_stats)
            self.custom_print("Train indices",
                              trainer_args.get("train_indices"))
            self.custom_print("Val indices", trainer_args.get("val_indices"))
            self.custom_print("Test indices", trainer_args.get("test_indices"))
        ######################################################################

        # configure behavior for early stopping
        self.stopper = None
        if self.early_stopping:
            self.stopper = EarlyStopping(patience=self.early_stopping,
                                         score_function=self.score_function,
                                         trainer=self.trainer)

        # configure behavior for checkpointing
        to_save: dict = {
            "model_state_dict": self.model,
            "optimizer_state_dict": self.optimizer,
            "trainer_state_dict": self.trainer
        }
        if self.stopper:
            to_save["stopper_state_dict"] = self.stopper

        save_handler = DiskSaver(self.save_path,
                                 create_dir=True,
                                 require_empty=False,
                                 atomic=True)

        # save the best checkpoints
        self.best_checkpoint_handler = Checkpoint(
            to_save,
            save_handler,
            filename_prefix=f"{self.log_prefix}_best",
            score_name="val_loss",
            score_function=self.score_function,
            include_self=True,
            global_step_transform=global_step_from_engine(self.trainer),
            n_saved=5)

        # save the latest checkpoint (important for resuming training)
        self.latest_checkpoint_handler = Checkpoint(
            to_save,
            save_handler,
            filename_prefix=f"{self.log_prefix}_latest",
            include_self=True,
            global_step_transform=global_step_from_engine(self.trainer),
            n_saved=1)

        # resume from checkpoint
        if len(self.resume_from_checkpoint):
            self.model, self.optimizer, self.trainer, self.stopper, self.best_checkpoint_handler, self.latest_checkpoint_handler = self._load_checkpoint(
                self.model,
                self.optimizer,
                self.trainer,
                self.stopper,
                self.best_checkpoint_handler,
                self.latest_checkpoint_handler,
                checkpoint_path_dict=self.resume_from_checkpoint)

        self.persist_collection = False
        self.persist_collection_dict: OrderedDict = OrderedDict()
Example #14
def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader, params):
    # Metrics
    UnitConvergence(model[0], learning_rule.norm).attach(trainer.engine, 'unit_conv')

    # Tqdm logger
    pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
    pbar.attach(trainer.engine, metric_names='all')
    tqdm_logger = TqdmLogger(pbar=pbar)
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(trainer.engine),
    )

    # Evaluator
    evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED(every=100), train_loader, val_loader)

    # Learning rate scheduling
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                     lr_lambda=lambda epoch: 1 - epoch / params['epochs'])
    lr_scheduler = LRScheduler(lr_scheduler)
    trainer.engine.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler)

    # Model checkpointing
    mc_handler = ModelCheckpoint(config.MODELS_DIR, run.replace('/', '-'), n_saved=1, create_dir=True,
                                 require_empty=False,
                                 global_step_transform=global_step_from_engine(trainer.engine))
    trainer.engine.add_event_handler(Events.EPOCH_COMPLETED, mc_handler, {'m': model})

    # Create a TensorBoard logger
    tb_logger = TensorboardLogger(log_dir=os.path.join(config.TENSORBOARD_DIR, run))
    images, labels = next(iter(train_loader))
    tb_logger.writer.add_graph(copy.deepcopy(model).cpu(), images)
    tb_logger.writer.add_hparams(params, {})

    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        trainer.engine,
        event_name=Events.EPOCH_COMPLETED,
        tag="train",
        metric_names=["unit_conv"]
    )
    input_shape = tuple(next(iter(train_loader))[0].shape[1:])
    tb_logger.attach(trainer.engine,
                     log_handler=WeightsImageHandler(model, input_shape),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer.engine, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.EPOCH_STARTED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=WeightsScalarHandler(model, layer_names=['linear1', 'linear2']),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=WeightsHistHandler(model, layer_names=['linear1', 'linear2']),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsHistHandler(model, layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=NumActivationsScalarHandler(model, layer_names=['repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.mean,
    #                                                       layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.std,
    #                                                       layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)

    return tb_logger