def test_train_fn_once():
    model, optimizer, device, loss_fn, batch = set_up()
    config = Namespace(use_amp=False)
    engine = Engine(lambda e, b: train_function(config, e, b, model, loss_fn, optimizer, device))
    engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
    backward = MagicMock()
    optim = MagicMock()
    engine.add_event_handler(TrainEvents.BACKWARD_COMPLETED(once=3), backward)
    engine.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED(once=3), optim)
    engine.run([batch] * 5)
    assert hasattr(engine.state, "backward_completed")
    assert hasattr(engine.state, "optim_step_completed")
    assert engine.state.backward_completed == 5
    assert engine.state.optim_step_completed == 5
    assert backward.call_count == 1
    assert optim.call_count == 1
    assert backward.called
    assert optim.called
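
The custom TrainEvents, train_events_to_attr and train_function referenced by this test are defined elsewhere; a minimal sketch of what they might look like, using ignite's EventEnum/fire_event API (the bodies below are assumptions, and the use_amp flag is ignored), is:

from ignite.engine import EventEnum

class TrainEvents(EventEnum):
    BACKWARD_COMPLETED = "backward_completed"
    OPTIM_STEP_COMPLETED = "optim_step_completed"

train_events_to_attr = {
    TrainEvents.BACKWARD_COMPLETED: "backward_completed",
    TrainEvents.OPTIM_STEP_COMPLETED: "optim_step_completed",
}

def train_function(config, engine, batch, model, loss_fn, optimizer, device):
    model.train()
    x, y = batch[0].to(device), batch[1].to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    # firing a registered custom event also increments the mapped state counter
    engine.fire_event(TrainEvents.BACKWARD_COMPLETED)
    optimizer.step()
    engine.fire_event(TrainEvents.OPTIM_STEP_COMPLETED)
    return loss.item()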
Example #2
def main(args):
    if args.g >= 0 and torch.cuda.is_available():
        device = torch.device(f"cuda:{args.g:d}")
        print(f"GPU mode: {args.g}")
    else:
        device = torch.device("cpu")
        print("CPU mode")

    result_dir = Path(args.r)
    result_dir.mkdir(parents=True, exist_ok=True)

    mnist_train = MNIST(root=".",
                        train=True,
                        download=True,
                        transform=lambda x: np.expand_dims(
                            np.asarray(x, dtype=np.float32), 0) / 255)
    train_loader = data.DataLoader(mnist_train,
                                   batch_size=args.b,
                                   shuffle=True)

    model = GAN(args.z).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.9))

    trainer = Engine(GANTrainer(model, opt, device))
    logger = GANLogger(model, train_loader, device)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, logger)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, print_loss(logger))
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              plot_loss(logger, result_dir / "loss.pdf"))
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        save_img(model.gen, result_dir / "output_images", args.z, device))
    if args.save_model:
        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  save_model(model, result_dir / "models"))

    trainer.run(train_loader, max_epochs=args.e)
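
The epoch-end handlers used above (print_loss, plot_loss, save_img, save_model) come from the surrounding project; as an illustration, a hedged sketch of a save_img factory with the same call signature (output file names and the sample count are assumptions) could be:

import torch
from torchvision.utils import save_image

def save_img(gen, out_dir, zdim, device, n_samples=64):
    # Hypothetical epoch-end handler factory: sample the generator and write an image grid
    out_dir.mkdir(parents=True, exist_ok=True)

    def _handler(engine):
        gen.eval()
        with torch.no_grad():
            z = torch.randn(n_samples, zdim, device=device)
            samples = gen(z).cpu()
        save_image(samples, out_dir / f"epoch_{engine.state.epoch:03d}.png", nrow=8)
        gen.train()

    return _handler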
Example #3
 def attach(self, engine: Engine):
     engine.add_event_handler(Events.ITERATION_COMPLETED, self)
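
Registering the handler object itself like this pairs with a __call__ method; a small hypothetical handler built around the same attach pattern:

from ignite.engine import Engine, Events

class OutputHistory:
    # Hypothetical handler: records engine.state.output on every completed iteration
    def __init__(self):
        self.history = []

    def attach(self, engine: Engine):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)

    def __call__(self, engine: Engine):
        self.history.append(engine.state.output)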
Example #4
def train(model,
          model_name,
          train_dataloader,
          eval_dataloader,
          labels_name,
          trainer_name='ocr',
          backbone_url=None):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()

    def _prepare_batch(batch, device=None, non_blocking=False):
        """Prepare batch for training: pass to a device with options.
        """
        images, labels = batch
        images = images.to(device)
        labels = [label.to(device) for label in labels]
        return (images, labels)

    writer = SummaryWriter(log_dir=f'logs/{trainer_name}/{model_name}')
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              factor=0.5,
                                                              patience=250,
                                                              cooldown=100,
                                                              min_lr=1e-6)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = _prepare_batch(batch, device=device)
        # loss = model(x, y)
        # loss.backward()
        # optimizer.step()
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            loss = model(x, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        return loss.item()

    trainer = Engine(_update)
    evaluator = create_supervised_evaluator(
        model,
        prepare_batch=_prepare_batch,
        metrics={'edit_distance': EditDistanceMetric()},
        device=device)

    if path.exists(f'{trainer_name}_{model_name}_checkpoint.pt'):
        checkpoint = torch.load(f'{trainer_name}_{model_name}_checkpoint.pt')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scaler.load_state_dict(checkpoint['scaler'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        logging.info(
            f'load checkpoint {trainer_name}_{model_name}_checkpoint.pt')
    elif path.exists(f'{model_name}_backbone.pt'):
        pretrained_dict = torch.load(f'{model_name}_backbone.pt')['model']
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and 'neck.' not in k and 'fc.' not in k
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        logging.info(f'load transfer parameters from {model_name}_backbone.pt')
    elif backbone_url is not None:
        pretrained_dict = torch.hub.load_state_dict_from_url(backbone_url,
                                                             progress=False)
        model_dict = model.backbone.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model.backbone.load_state_dict(model_dict)
        logging.info(f'load backbone from {backbone_url}')

    early_stop_arr = [0.0]

    def early_stop_score_function(engine):
        val_acc = engine.state.metrics['edit_distance']
        if val_acc < 0.8:  # do not early stop while acc is less than 0.8
            early_stop_arr[0] += 0.000001
            return early_stop_arr[0]
        return val_acc

    early_stop_handler = EarlyStopping(
        patience=20, score_function=early_stop_score_function, trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, early_stop_handler)

    checkpoint_handler = ModelCheckpoint(f'models/{trainer_name}/{model_name}',
                                         model_name,
                                         n_saved=10,
                                         create_dir=True)
    # trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), checkpoint_handler,
    #                           {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler})
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=1000), checkpoint_handler, {
            'model': model,
            'optimizer': optimizer,
            'lr_scheduler': lr_scheduler,
            'scaler': scaler
        })

    @trainer.on(Events.ITERATION_COMPLETED(every=10))
    def log_training_loss(trainer):
        lr = optimizer.param_groups[0]['lr']
        logging.info("Epoch[{}]: {} - Loss: {:.4f}, Lr: {}".format(
            trainer.state.epoch, trainer.state.iteration, trainer.state.output,
            lr))
        writer.add_scalar("training/loss", trainer.state.output,
                          trainer.state.iteration)
        writer.add_scalar("training/learning_rate", lr,
                          trainer.state.iteration)

    @trainer.on(Events.ITERATION_COMPLETED(every=10))
    def step_lr(trainer):
        lr_scheduler.step(trainer.state.output)

    @trainer.on(Events.ITERATION_COMPLETED(every=1000))
    def log_training_results(trainer):
        evaluator.run(eval_dataloader)
        metrics = evaluator.state.metrics
        logging.info(
            "Eval Results - Epoch[{}]: {} - Avg edit distance: {:.4f}".format(
                trainer.state.epoch, trainer.state.iteration,
                metrics['edit_distance']))
        writer.add_scalar("evaluation/avg_edit_distance",
                          metrics['edit_distance'], trainer.state.iteration)

    @trainer.on(Events.ITERATION_COMPLETED(every=100))
    def read_lr_from_file(trainer):
        if path.exists('lr.txt'):
            with open('lr.txt', 'r', encoding='utf-8') as f:
                lr = float(f.read())
            for group in optimizer.param_groups:
                group['lr'] = lr

    trainer.run(train_dataloader, max_epochs=1)
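
EditDistanceMetric is referenced but not defined in this snippet; a rough sketch of a custom ignite Metric with that name (the scoring below is only a stand-in and assumes y_pred/y are already decoded to sequences of label ids) might be:

import difflib
from ignite.exceptions import NotComputableError
from ignite.metrics import Metric

class EditDistanceMetric(Metric):
    # Hypothetical custom metric: averages a per-sample similarity score in [0, 1]
    def reset(self):
        self._sum = 0.0
        self._num = 0

    def update(self, output):
        y_pred, y = output
        for pred, ref in zip(y_pred, y):
            # stand-in for a real normalized edit-distance score
            ratio = difflib.SequenceMatcher(None, list(map(int, pred)), list(map(int, ref))).ratio()
            self._sum += ratio
            self._num += 1

    def compute(self):
        if self._num == 0:
            raise NotComputableError("EditDistanceMetric must see at least one sample")
        return self._sum / self._num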
Example #5
def main(config):
    assert validate_config(config), "ERROR: Config file is invalid. Please see log for details."

    logger.info("INFO: {}".format(config.toDict()))

    # Seed the random number generators for torch; since we use its dataloaders, this keeps shuffling constant
    # Remember to seed custom datasets etc. with the same seed
    if config.seed > 0:
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(config.seed)
        torch.manual_seed(config.seed)
        random.seed(config.seed)
        np.random.seed(config.seed)

    if config.device == "cpu" and torch.cuda.is_available():
        logger.warning("WARNING: Not using the GPU")
    elif config.device == "cuda":
        config.device = f"cuda:{config.device_ids[0]}"

    config.nsave = config.nsave if "nsave" in config else 5

    logger.info("INFO: Creating datasets and dataloaders...")
    # Create the training dataset
    dset_train = create_dataset(config.datasets.train)

    # If the validation config has a parameter called split then we ask the training dset for the validation dataset.
    # Note that you shouldn't shuffle the dataset in the init of the train dataset in this case, as only once
    # get_validation_split is called do we know how to split the data (unless shuffling is deterministic).
    train_ids = None
    if 'validation' in config.datasets:
        # Ensure we have a full config for validation; this means we only need to specify the
        # differences from the training config
        config_val = config.datasets.train.copy()
        config_val.update(config.datasets.validation)

        dset_val = create_dataset(config_val)

        loader_val = get_data_loader(dset_val, config_val)
        print("Using validation dataset of {} samples or {} batches".format(len(dset_val), len(loader_val)))
    elif 'includes_validation' in config.datasets.train:
        # No separate validation config exists here, so split based on the training config
        train_ids, val_ids = dset_train.get_validation_split(config.datasets.train)
        loader_val = get_data_loader(dset_train, config.datasets.train, val_ids)
        print("Using validation dataset of {} samples or {} batches".format(len(val_ids), len(loader_val)))
    else:
        logger.warning("WARNING: No validation dataset was specified")
        dset_val = None
        loader_val = None

    loader_train = get_data_loader(dset_train, config.datasets.train, train_ids)
    dset_len = len(train_ids) if train_ids is not None else len(dset_train)
    print("Using training dataset of {} samples or {} batches".format(dset_len, len(loader_train)))

    cp_paths = {}  # empty dict so later 'name in cp_paths' checks work when no checkpoint is configured
    last_epoch = 0
    if 'checkpoint' in config:
        checkpoint_dir = config.checkpoint_dir if 'checkpoint_dir' in config else config.result_path
        cp_paths, last_epoch = config.get_checkpoints(path=checkpoint_dir, tag=config.checkpoint)
        print("Found checkpoint {} for Epoch {}".format(config.checkpoint, last_epoch))
        last_epoch = last_epoch if config.resume_from == -1 else config.resume_from
        # config.epochs = config.epochs - last_epoch if last_epoch else config.epochs

    models = {}
    for name, model in config.model.items():
        logger.info("INFO: Building the {} model".format(name))
        models[name] = build_model(model)

        # Load the checkpoint
        if name in cp_paths:
            models[name].load_state_dict( torch.load( cp_paths[name] ) )
            logger.info("INFO: Loaded model {} checkpoint {}".format(name, cp_paths[name]))

        if len(config.device_ids) > 1:
            models[name] = nn.DataParallel(models[name], device_ids=config.device_ids)

        models[name].to(config.device)
        print(models[name])

        if 'debug' in config and config.debug is True:
            print("*********** {} ************".format(name))
            for p_name, param in models[name].named_parameters():
                if param.requires_grad:
                    print(p_name, param.data)

    optimizers = {}
    for name, conf in config.optimizer.items():
        optim_conf = conf.copy()
        del optim_conf["models"]

        model_params = []
        for model_id in conf.models:
            model_params.extend( list(filter(lambda p: p.requires_grad, models[model_id].parameters())) )
        
        logger.info("INFO: Using {} Optimization for {}".format(list(optim_conf.keys())[0], name))
        optimizers[name] = get_optimizer(model_params, optim_conf)

        # Restoring the optimizer breaks because we do not include all parameters in the optimizer state. So if we aren't continuing training then just make a new optimizer
        if name in cp_paths and 'checkpoint_dir' not in config:
            optimizers[name].load_state_dict( torch.load( cp_paths[name] ) )
            logger.info("INFO: Loaded {} optimizer checkpoint {}".format(name, cp_paths[name]))

            for state in optimizers[name].state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(config.device)

    losses = {}
    for name, fcns in config.loss.items():
        losses[name] = []
        for l in fcns:
            losses[name].append( get_loss(l) )
            assert losses[name][-1], "Loss function {} for {} could not be found, please check your config".format(l, name)

    exp_logger = None
    if 'logger' in config:
        logger.info("INFO: Initialising the experiment logger")
        exp_logger = get_experiment_logger(config.result_path, config.logger)
        if last_epoch > 0:
            exp_logger.fast_forward(last_epoch, len(loader_train))

    logger.info("INFO: Creating training manager and configuring callbacks")
    trainer = get_trainer(models, optimizers, losses, exp_logger, config)

    trainer_engine = Engine(trainer.train)
    evaluator_engine = Engine(trainer.evaluate)

    trainer.attach("train_loader", loader_train)
    trainer.attach("validation_loader", loader_val)
    trainer.attach("evaluation_engine", evaluator_engine)
    trainer.attach("train_engine", trainer_engine)

    for phase in config.metrics.keys():
        if phase == "train": engine = trainer_engine
        if phase == "validation": engine = evaluator_engine

        for name, metric in config.metrics[phase].items():
            metric = get_metric(metric)
            if metric is not None:
                metric.attach(engine, name)
            else:
                logger.warning("WARNING: Metric {} could not be created for {} phase".format(name, phase))

    # Register default callbacks to run the validation stage
    if loader_val is not None:
        if len(loader_train) > 2000:
            # Validate 4 times an epoch
            num_batch = len(loader_train)//4

            def validate_run(engine):
                if engine.state.iteration % num_batch == 0:
                    evaluator_engine.run(loader_val)

            trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, validate_run)
        else:
            trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: evaluator_engine.run(loader_val))

    # Initialise the epoch from the checkpoint - a workaround, since Ignite does not restore it on its own
    if last_epoch > 0:
        def set_epoch(engine, last_epoch):
            engine.state.epoch = last_epoch

        trainer_engine.add_event_handler(Events.STARTED, set_epoch, last_epoch)


    schedulers = {"batch": {}, "epoch": {}}
    if 'scheduler' in config:
        for sched_name, sched in config.scheduler.items():
            if sched_name in optimizers:
                logger.info("INFO: Setting up LR scheduler for {}".format(sched_name))
                sched_fn, sched_scheme = get_lr_scheduler(optimizers[sched_name], sched)
                assert sched_fn, "Learning Rate scheduler for {} could not be found, please check your config".format(sched_name)
                assert sched_scheme in ["batch", "epoch"], "ERROR: Invalid scheduler scheme, must be either epoch or batch"

                schedulers[sched_scheme][sched_name] = sched_fn

        def epoch_scheduler(engine):
            for name, sched in schedulers["epoch"].items():
                sched.step()

        def batch_scheduler(engine):
            for name, sched in schedulers["batch"].items():
                sched.step()

        trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, batch_scheduler)
        trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, epoch_scheduler)

    if exp_logger is not None:
        trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, exp_logger.log_iteration, phase="train", models=models, optims=optimizers)
        trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, exp_logger.log_epoch, phase="train", models=models, optims=optimizers)
        evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED, exp_logger.log_iteration, phase="evaluate", models=models, optims=optimizers)
        evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, exp_logger.log_epoch, phase="evaluate", models=models, optims=optimizers)

    if "monitor" in config and config.monitor.early_stopping:
        logger.info("INFO: Enabling early stopping, monitoring {}".format(config.monitor.score))
        score_fn = lambda e: config.monitor.scale * e.state.metrics[config.monitor.score]
        es_handler = EarlyStopping(patience=config.monitor.patience, score_function=score_fn, trainer=trainer_engine)
        evaluator_engine.add_event_handler(Events.COMPLETED, es_handler)

    if "monitor" in config and config.monitor.save_score:
        logger.info("INFO: Saving best model based on {}".format(config.monitor.save_score))
        score_fn = lambda e: config.monitor.save_scale * e.state.metrics[config.monitor.save_score]
        ch_handler = ModelCheckpoint(config.result_path, 'best_checkpoint', score_function=score_fn, score_name=config.monitor.save_score, n_saved=1, require_empty=False, save_as_state_dict=True)
        to_save = dict(models, **optimizers)
        evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, ch_handler, to_save)

    if config.save_freq > 0:
        ch_handler = ModelCheckpoint(config.result_path, 'checkpoint', save_interval=config.save_freq, n_saved=config.nsave, require_empty=False, save_as_state_dict=True)
        to_save = dict(models, **optimizers)
        trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, ch_handler, to_save)

    # Register custom callbacks with the engines
    if check_if_implemented(trainer, "on_iteration_start"):
        trainer_engine.add_event_handler(Events.ITERATION_STARTED, trainer.on_iteration_start, phase="train")
        evaluator_engine.add_event_handler(Events.ITERATION_STARTED, trainer.on_iteration_start, phase="evaluate")
    if check_if_implemented(trainer, "on_iteration_end"):
        trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, trainer.on_iteration_end, phase="train")
        evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED, trainer.on_iteration_end, phase="evaluate")
    if check_if_implemented(trainer, "on_epoch_start"):
        trainer_engine.add_event_handler(Events.EPOCH_STARTED, trainer.on_epoch_start, phase="train")
        evaluator_engine.add_event_handler(Events.EPOCH_STARTED, trainer.on_epoch_start, phase="evaluate")
    if check_if_implemented(trainer, "on_epoch_end"):
        trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, trainer.on_epoch_end, phase="train")
        evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, trainer.on_epoch_end, phase="evaluate")

    # Save the config for this experiment to the results directory, once we know the params are good
    config.save()

    def signal_handler(sig, frame):
        print('You pressed Ctrl+C!')
        if exp_logger is not None:
            exp_logger.teardown()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    logger.info("INFO: Starting training...")
    trainer_engine.run(loader_train, max_epochs=config.epochs)

    if exp_logger is not None:
        exp_logger.teardown()
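
check_if_implemented is used in several of these examples to probe the trainer for optional hooks; a plausible helper (an assumption, not the project's original) would be:

def check_if_implemented(obj, method_name):
    # True if obj provides a callable attribute with the given name
    return callable(getattr(obj, method_name, None))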
Example #6
    def train(self, config, **kwargs):
        """Trains a model on the given configurations.
        :param config:str: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE
        :param **kwargs: parameters to overwrite yaml config
        """

        from pycocoevalcap.cider.cider import Cider

        config_parameters = train_util.parse_config_or_kwargs(config, **kwargs)
        config_parameters["seed"] = self.seed
        zh = config_parameters["zh"]
        outputdir = os.path.join(
            config_parameters["outputpath"], config_parameters["model"],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))

        # Early init because of creating dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            "run",
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=lambda engine: -engine.state.metrics["loss"],
            score_name="loss")

        logger = train_util.genlogger(os.path.join(outputdir, "train.log"))
        # print passed config parameters
        logger.info("Storing files in: {}".format(outputdir))
        train_util.pprint_dict(config_parameters, logger.info)

        vocabulary = torch.load(config_parameters["vocab_file"])
        trainloader, cvloader, info = self._get_dataloaders(config_parameters, vocabulary)
        config_parameters["inputdim"] = info["inputdim"]
        logger.info("<== Estimating Scaler ({}) ==>".format(info["scaler"].__class__.__name__))
        logger.info(
                "Stream: {} Input dimension: {} Vocab Size: {}".format(
                config_parameters["feature_stream"], info["inputdim"], len(vocabulary)))
        train_key2refs = info["train_key2refs"]
        # train_scorer = BatchCider(train_key2refs)
        cv_key2refs = info["cv_key2refs"]
        # cv_scorer = BatchCider(cv_key2refs)

        model = self._get_model(config_parameters, vocabulary)
        model = model.to(self.device)
        train_util.pprint_dict(model, logger.info, formatter="pretty")
        optimizer = getattr(
            torch.optim, config_parameters["optimizer"]
        )(model.parameters(), **config_parameters["optimizer_args"])
        train_util.pprint_dict(optimizer, logger.info, formatter="pretty")

        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            # optimizer, **config_parameters["scheduler_args"])
        crtrn_imprvd = train_util.criterion_improver(config_parameters["improvecriterion"])

        def _train_batch(engine, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                train_scorer = Cider(zh=zh)
                output = self._forward(model, batch, "train", train_mode="scst", 
                                       key2refs=train_key2refs, scorer=train_scorer)
                output["loss"].backward()
                optimizer.step()
                return output

        trainer = Engine(_train_batch)
        RunningAverage(output_transform=lambda x: x["loss"]).attach(trainer, "running_loss")
        pbar = ProgressBar(persist=False, ascii=True)
        pbar.attach(trainer, ["running_loss"])

        key2pred = {}

        def _inference(engine, batch):
            model.eval()
            keys = batch[2]
            with torch.no_grad():
                cv_scorer = Cider(zh=zh)
                output = self._forward(model, batch, "train", train_mode="scst",
                                       key2refs=cv_key2refs, scorer=cv_scorer)
                seqs = output["sampled_seqs"].cpu().numpy()
                for idx, seq in enumerate(seqs):
                    if keys[idx] in key2pred:
                        continue
                    candidate = self._convert_idx2sentence(seq, vocabulary, zh=zh)
                    key2pred[keys[idx]] = [candidate,]
                return output

        evaluator = Engine(_inference)
        RunningAverage(output_transform=lambda x: x["loss"]).attach(trainer, "running_loss")

        metrics = {
            "loss": Average(output_transform=lambda x: x["loss"]),
            "reward": Average(output_transform=lambda x: x["reward"].reshape(-1, 1)),
        }

        for name, metric in metrics.items():
            metric.attach(trainer, name)
            metric.attach(evaluator, name)

        RunningAverage(output_transform=lambda x: x["loss"]).attach(evaluator, "running_loss")
        pbar.attach(evaluator, ["running_loss"])

        # @trainer.on(Events.STARTED)
        # def log_initial_result(engine):
            # evaluator.run(cvloader, max_epochs=1)
            # logger.info("Initial Results - loss: {:<5.2f}\tscore: {:<5.2f}".format(evaluator.state.metrics["loss"], evaluator.state.metrics["score"].item()))


        trainer.add_event_handler(
              Events.EPOCH_COMPLETED, train_util.log_results, evaluator, cvloader,
              logger.info, metrics.keys(), ["loss", "reward", "score"])

        def eval_cv(engine, key2pred, key2refs, scorer):
            # if len(cv_key2refs) == 0:
                # for key, _ in key2pred.items():
                    # cv_key2refs[key] = key2refs[key]
            score, scores = scorer.compute_score(key2refs, key2pred)
            engine.state.metrics["score"] = score
            key2pred.clear()

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, eval_cv, key2pred, cv_key2refs, Cider(zh=zh))

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, train_util.save_model_on_improved, crtrn_imprvd,
            "score", {
                "model": model,
                "config": config_parameters,
                "scaler": info["scaler"]
            }, os.path.join(outputdir, "saved.pth"))

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler, {
                "model": model,
            }
        )

        trainer.run(trainloader, max_epochs=config_parameters["epochs"])
        return outputdir
Example #7
def main():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    args.o.mkdir(parents=True, exist_ok=True)

    train_dataset = MNIST(root=".",
                          download=True,
                          train=True,
                          transform=mnist_transform)

    test_dataset = MNIST(root=".",
                         download=True,
                         train=False,
                         transform=mnist_transform)

    train_loader = DataLoader(train_dataset, args.b, shuffle=True)
    test_loader = DataLoader(test_dataset, args.b)

    if args.m == "fc":
        net = FullyConnectedVAE(28 * 28, args.zdim).to(device)
    elif args.m == "cnn":
        net = CNNVAE(1, args.zdim).to(device)

    opt = torch.optim.Adam(net.parameters())
    trainer = Engine(VAETrainer(net, opt, device))
    attach_metrics(trainer)

    evaluator = create_evaluator(net, device)

    train_logger = StateLogger(trainer)
    test_logger = StateLogger(evaluator)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, train_logger)
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              EvaluationRunner(evaluator, test_loader))
    metric_keys = ("loss", "kl_div", "recon_loss")
    trainer.add_event_handler(Events.EPOCH_COMPLETED, test_logger)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        Plotter(train_logger, metric_keys, args.o / "train_loss.pdf"))
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        Plotter(test_logger, metric_keys, args.o / "test_loss.pdf"))
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              LogPrinter(train_logger, metric_keys))
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              LogPrinter(test_logger, metric_keys))
    trainer.add_event_handler(Events.COMPLETED,
                              ModelSaver(net, args.o / "model.pt"))

    trainer.run(train_loader, max_epochs=args.e)
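
StateLogger and EvaluationRunner above are project-specific handlers; hypothetical minimal versions consistent with how they are used here could be:

from ignite.engine import Engine

class StateLogger:
    # Hypothetical handler: snapshots the watched engine's metrics at every trigger
    def __init__(self, engine: Engine):
        self.engine = engine
        self.history = []

    def __call__(self, _engine):
        self.history.append(dict(self.engine.state.metrics))

class EvaluationRunner:
    # Hypothetical handler: runs the evaluator over a fixed loader when triggered
    def __init__(self, evaluator: Engine, loader):
        self.evaluator = evaluator
        self.loader = loader

    def __call__(self, _engine):
        self.evaluator.run(self.loader)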
Example #8
def main(config):
    assert validate_config(
        config), "ERROR: Config file is invalid. Please see log for details."

    logger.info("INFO: {}".format(config.toDict()))

    if config.device == "cpu" and torch.cuda.is_available():
        logger.warning("WARNING: Not using the GPU")

    assert 'test' in config.datasets, "ERROR: No test dataset is specified in the config. Don't know how to proceed."

    logger.info("INFO: Creating datasets and dataloaders...")

    config.datasets.test.update({
        'shuffle': False,
        'augment': False,
        'workers': 1
    })

    # Create the test dataset
    dset_test = create_dataset(config.datasets.test)

    config.datasets.test.update({'batch_size': 1})
    loader_test = get_data_loader(dset_test, config.datasets.test)

    logger.info("INFO: Running inference on {} samples".format(len(dset_test)))

    logger.info("INFO: Building the {} model".format(config.model.type))
    model = build_model(config.model)

    m_cp_path, _ = config.get_checkpoint_path()

    assert m_cp_path, "Could not find a checkpoint for this model, check your config and try again"
    model.load_state_dict(torch.load(m_cp_path))
    logger.info("INFO: Loaded model checkpoint {}".format(m_cp_path))

    model = model.to(config.device)

    if 'input_size' in config:
        summary(model,
                input_size=config.input_size,
                device=config.device,
                unpack_inputs=True)
    else:
        print(model)

    if 'logger' in config:
        logger.info("INFO: Initialising the experiment logger")
        exp_logger = get_experiment_logger(config)

    logger.info("INFO: Creating training manager and configuring callbacks")
    trainer = get_trainer(model, None, None, None, config)

    evaluator_engine = Engine(trainer.evaluate)

    trainer.attach("test_loader", loader_test)
    trainer.attach("evaluation_engine", evaluator_engine)

    if 'metrics' in config:
        for name, metric in config.metrics.items():
            metric = get_metric(metric)
            if metric is not None:
                metric.attach(evaluator_engine, name)
            else:
                logger.warning(
                    "WARNING: Metric {} could not be created".format(name))

    # Register custom callbacks with the engines
    if check_if_implemented(trainer, "on_iteration_start"):
        evaluator_engine.add_event_handler(Events.ITERATION_STARTED,
                                           trainer.on_iteration_start,
                                           phase="test")
    if check_if_implemented(trainer, "on_iteration_end"):
        evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                           trainer.on_iteration_end,
                                           phase="test")
    if check_if_implemented(trainer, "on_epoch_start"):
        evaluator_engine.add_event_handler(Events.EPOCH_STARTED,
                                           trainer.on_epoch_start,
                                           phase="test")
    if check_if_implemented(trainer, "on_epoch_end"):
        evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           trainer.on_epoch_end,
                                           phase="test")

    evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                       accumulate_predictions)
    # evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, save_predictions, base_path=os.path.splitext(m_cp_path)[0], state_dict=accumulate_predictions.state_dict)

    logger.info("INFO: Starting inference...")
    evaluator_engine.run(loader_test)

    save_predictions(evaluator_engine,
                     os.path.splitext(m_cp_path)[0],
                     accumulate_predictions.state_dict)
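
accumulate_predictions and save_predictions are imported from the surrounding project; judging from the usage above, the accumulator is probably a callable object that carries its own state. A hypothetical sketch:

class PredictionAccumulator:
    # Hypothetical accumulator: stores each evaluation output keyed by iteration
    def __init__(self):
        self.state_dict = {}

    def __call__(self, engine):
        self.state_dict[engine.state.iteration] = engine.state.output

# a module-level instance would then behave like accumulate_predictions above
accumulate_predictions = PredictionAccumulator()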
Example #9
    def train(self, config, **kwargs):
        """Trains a model on the given configurations.
        :param config:str: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE
        :param **kwargs: parameters to overwrite yaml config
        """

        from pycocoevalcap.cider.cider import Cider
        from pycocoevalcap.spider.spider import Spider

        conf = train_util.parse_config_or_kwargs(config, **kwargs)
        conf["seed"] = self.seed
        zh = conf["zh"]
        outputdir = os.path.join(
            conf["outputpath"], conf["modelwrapper"],
            # "{}_{}".format(
                # datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%m'),
                # uuid.uuid1().hex))
            conf["remark"], "seed_{}".format(self.seed)
        )

        # Early init because of creating dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            "run",
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=lambda engine: engine.state.metrics["score"],
            score_name="score")

        logger = train_util.genlogger(os.path.join(outputdir, "train.log"))
        # print passed config parameters
        logger.info("Storing files in: {}".format(outputdir))
        train_util.pprint_dict(conf, logger.info)

        vocabulary = torch.load(conf["vocab_file"])
        train_loader, val_loader, info = self._get_dataloaders(conf, vocabulary)
        conf["inputdim"] = info["inputdim"]
        logger.info("<== Estimating Scaler ({}) ==>".format(info["scaler"].__class__.__name__))
        logger.info(
                "Feature: {} Input dimension: {} Vocab Size: {}".format(
                conf["feature_file"], info["inputdim"], len(vocabulary)))
        train_key2refs = info["train_key2refs"]
        val_key2refs = info["val_key2refs"]

        model = self._get_model(conf, vocabulary)
        model = model.to(self.device)
        train_util.pprint_dict(model, logger.info, formatter="pretty")
        optimizer = getattr(
            torch.optim, conf["optimizer"]
        )(model.parameters(), **conf["optimizer_args"])
        train_util.pprint_dict(optimizer, logger.info, formatter="pretty")

        crtrn_imprvd = train_util.criterion_improver(conf["improvecriterion"])

        scorer_dict = {"cider": Cider(zh=zh), "spider": Spider()}
        if "train_scorer" not in conf:
            conf["train_scorer"] = "cider"
        train_scorer = scorer_dict[conf["train_scorer"]]
        def _train_batch(engine, batch):
            # import pdb; pdb.set_trace()
            # set num batch tracked?
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                # train_scorer = scorer_dict[conf["train_scorer"]]
                output = self._forward(model, batch, "train", 
                                       key2refs=train_key2refs, 
                                       scorer=train_scorer,
                                       vocabulary=vocabulary)
                output["loss"].backward()
                optimizer.step()
                return output

        trainer = Engine(_train_batch)
        RunningAverage(output_transform=lambda x: x["loss"]).attach(trainer, "running_loss")
        pbar = ProgressBar(persist=False, ascii=True)
        pbar.attach(trainer, ["running_loss"])

        key2pred = {}

        def _inference(engine, batch):
            model.eval()
            keys = batch[2]
            with torch.no_grad():
                # val_scorer = Cider(zh=zh)
                # output = self._forward(model, batch, "train", 
                                       # key2refs=val_key2refs, scorer=val_scorer)
                # seqs = output["greedy_seqs"].cpu().numpy()
                output = self._forward(model, batch, "validation")
                seqs = output["seqs"].cpu().numpy()
                for idx, seq in enumerate(seqs):
                    if keys[idx] in key2pred:
                        continue
                    candidate = self._convert_idx2sentence(seq, vocabulary, zh=zh)
                    key2pred[keys[idx]] = [candidate,]
                return output

        evaluator = Engine(_inference)
        RunningAverage(output_transform=lambda x: x["loss"]).attach(trainer, "running_loss")

        metrics = {
            "loss": Average(output_transform=lambda x: x["loss"]),
            "reward": Average(output_transform=lambda x: x["reward"].reshape(-1, 1)),
            # "score": Average(output_transform=lambda x: x["score"].reshape(-1, 1)),
        }

        for name, metric in metrics.items():
            metric.attach(trainer, name)
            # metric.attach(evaluator, name)

        # RunningAverage(output_transform=lambda x: x["loss"]).attach(evaluator, "running_loss")
        # pbar.attach(evaluator, ["running_loss"])
        pbar.attach(evaluator) 

        trainer.add_event_handler(
              Events.EPOCH_COMPLETED, train_util.log_results, evaluator, val_loader,
              logger.info, metrics.keys(), ["score"])

        def eval_val(engine, key2pred, key2refs, scorer):
            score, scores = scorer.compute_score(key2refs, key2pred)
            engine.state.metrics["score"] = score
            key2pred.clear()

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, eval_val, key2pred, val_key2refs, Cider(zh=zh))

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, train_util.save_model_on_improved, crtrn_imprvd,
            "score", {
                "model": model.state_dict(),
                "config": conf,
                "scaler": info["scaler"]
            }, os.path.join(outputdir, "saved.pth"))

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler, {
                "model": model,
            }
        )

        trainer.run(train_loader, max_epochs=conf["epochs"])
        return outputdir
Example #10
def main(config):
    assert validate_config(
        config), "ERROR: Config file is invalid. Please see log for details."

    logger.info("INFO: {}".format(config.toDict()))

    # Seed the random number generators for torch; since we use its dataloaders, this keeps shuffling constant
    # Remember to seed custom datasets etc. with the same seed
    if config.seed > 0:
        torch.cuda.manual_seed_all(config.seed)
        torch.manual_seed(config.seed)

    if config.device == "cpu" and torch.cuda.is_available():
        logger.warning("WARNING: Not using the GPU")

    logger.info("INFO: Creating datasets and dataloaders...")
    # Create the training dataset
    dset_train = create_dataset(config.datasets.train)
    # If the validation config has a parameter called split then we ask the training dset for the validation dataset.
    # Note that you shouldn't shuffle the dataset in the init of the train dataset in this case, as only once
    # get_validation_split is called do we know how to split the data (unless shuffling is deterministic).
    if 'validation' in config.datasets:
        # Ensure we have a full config for validation; this means we only need to specify the
        # differences from the training config
        config_val = config.datasets.train.copy()
        config_val.update(config.datasets.validation)

        if 'split' in config.datasets.validation:
            dset_val = dset_train.get_validation_split(config_val)
        else:
            dset_val = create_dataset(config_val)
    else:
        logger.warning("WARNING: No validation dataset was specified")
        dset_val = None
        loader_val = None

    loader_train = get_data_loader(dset_train, config.datasets.train)

    if dset_val is not None:
        loader_val = get_data_loader(dset_val, config_val)

    logger.info("INFO: Building the {} model".format(config.model.type))
    model = build_model(config.model)

    m_cp_path, o_cp_path = config.get_checkpoint_path()

    if config.resume_from >= 0:
        assert m_cp_path, "Could not find a checkpoint for this model, check your config and try again"
        model.load_state_dict(torch.load(m_cp_path))

    model = model.to(config.device)

    logger.info("INFO: Using {} Optimization".format(config.optimizer.type))
    optimizer = get_optimizer(model.parameters(), config.optimizer)

    if config.resume_from >= 0:
        assert o_cp_path, "Could not find a checkpoint for the optimizer, please check your results folder"
        optimizer.load_state_dict(torch.load(o_cp_path))

    loss_fn = get_loss(config.loss)
    assert loss_fn, "Loss function {} could not be found, please check your config".format(
        config.loss)

    scheduler = None
    if 'scheduler' in config:
        logger.info("INFO: Setting up LR scheduler {}".format(
            config.scheduler.type))
        scheduler = get_lr_scheduler(optimizer, config.scheduler)
        assert scheduler, "Learning Rate scheduler function {} could not be found, please check your config".format(
            config.scheduler.type)

    exp_logger = None
    if 'logger' in config:
        logger.info("INFO: Initialising the experiment logger")
        exp_logger = get_experiment_logger(config)

    logger.info("INFO: Creating training manager and configuring callbacks")
    trainer = get_trainer(model, optimizer, loss_fn, exp_logger, config)

    trainer_engine = Engine(trainer.train)
    evaluator_engine = Engine(trainer.evaluate)

    trainer.attach("train_loader", loader_train)
    trainer.attach("validation_loader", loader_val)
    trainer.attach("evaluation_engine", evaluator_engine)
    trainer.attach("train_engine", trainer_engine)

    if 'metrics' in config:
        for name, metric in config.metrics.items():
            metric = get_metric(metric)
            if metric is not None:
                metric.attach(evaluator_engine, name)
            else:
                logger.warning(
                    "WARNING: Metric {} could not be created".format(name))

    # Register default callbacks
    if exp_logger is not None:
        trainer_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                         exp_logger.log_iteration,
                                         phase="train",
                                         model=model)
        trainer_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                         exp_logger.log_epoch,
                                         phase="train",
                                         model=model)
        evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                           exp_logger.log_iteration,
                                           phase="evaluate",
                                           model=model)
        evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           exp_logger.log_epoch,
                                           phase="evaluate",
                                           model=model)

    if loader_val is not None:
        trainer_engine.add_event_handler(
            Events.EPOCH_COMPLETED,
            lambda engine: evaluator_engine.run(loader_val))

    if scheduler is not None:
        if config.scheduler.scheme == "batch":
            scheduler_event = Events.ITERATION_COMPLETED
        elif config.scheduler.scheme == "epoch":
            scheduler_event = Events.EPOCH_COMPLETED
        else:
            logger.error(
                "ERROR: Invalid scheduler scheme, must be either epoch or batch"
            )
            return 0

        trainer_engine.add_event_handler(scheduler_event,
                                         lambda engine: scheduler.step())

    if config.monitor.early_stopping:
        logger.info("INFO: Enabling early stopping, monitoring {}".format(
            config.monitor.score))
        score_fn = lambda e: config.monitor.scale * e.state.metrics[
            config.monitor.score]
        es_handler = EarlyStopping(patience=config.monitor.patience,
                                   score_function=score_fn,
                                   trainer=trainer_engine)
        evaluator_engine.add_event_handler(Events.COMPLETED, es_handler)

    if config.save_freq > 0:
        ch_path = config.result_path
        ch_handler = ModelCheckpoint(config.result_path,
                                     'checkpoint',
                                     save_interval=config.save_freq,
                                     n_saved=4,
                                     require_empty=False,
                                     save_as_state_dict=True)
        trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, ch_handler,
                                         {'model': model})

    # Register custom callbacks with the engines
    if check_if_implemented(trainer, "on_iteration_start"):
        trainer_engine.add_event_handler(Events.ITERATION_STARTED,
                                         trainer.on_iteration_start,
                                         phase="train")
        evaluator_engine.add_event_handler(Events.ITERATION_STARTED,
                                           trainer.on_iteration_start,
                                           phase="evaluate")
    if check_if_implemented(trainer, "on_iteration_end"):
        trainer_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                         trainer.on_iteration_end,
                                         phase="train")
        evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                           trainer.on_iteration_end,
                                           phase="evaluate")
    if check_if_implemented(trainer, "on_epoch_start"):
        trainer_engine.add_event_handler(Events.EPOCH_STARTED,
                                         trainer.on_epoch_start,
                                         phase="train")
        evaluator_engine.add_event_handler(Events.EPOCH_STARTED,
                                           trainer.on_epoch_start,
                                           phase="evaluate")
    if check_if_implemented(trainer, "on_epoch_end"):
        trainer_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                         trainer.on_epoch_end,
                                         phase="train")
        evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           trainer.on_epoch_end,
                                           phase="evaluate")

    # Save the config for this experiment to the results directory, once we know the params are good
    config.save()

    logger.info("INFO: Starting training...")
    trainer_engine.run(loader_train, max_epochs=config.epochs)
Example #11
class Trainer(abc.ABC):
    """
    Abstract class that manages the training, and evaluation of a network.
    """
    def __init__(self, ds_path, save_dir, view_radius, device):
        self.ds_path = ds_path
        self.device = device
        self.view_radius = view_radius
        self.net, self.net_name = self.define_net()
        self.net.to(device)
        self.optim = torch.optim.Adam(self.net.parameters())
        self.trainer = Engine(self.train_update_func)
        self.evaluator = Engine(self.val_update_func)
        self.saver = ModelCheckpoint(save_dir,
                                     self.net_name,
                                     save_interval=1,
                                     n_saved=20,
                                     require_empty=False)
        self._add_metrics()
        self._add_event_handlers()

    @abc.abstractmethod
    def define_net(self):
        """
        Implement this to return the network that will be used by the trainer.

        Returns:
            Tuple containing the network, and a string name for the network.
        """
        pass

    @abc.abstractmethod
    def get_loss(self, output, target):
        """
        Implement this to return the loss of the network.
        """
        pass

    @abc.abstractmethod
    def train_update_func(self, engine, batch):
        """
        The update step taken for each train iteration.
        """
        pass

    @abc.abstractmethod
    def val_update_func(self, engine, batch):
        """
        The step taken at each validation iteration.
        """
        pass

    @abc.abstractmethod
    def get_shelve_dataset(self, ds_path):
        """
        Return the torch dataset to use.
        """
        pass

    @abc.abstractmethod
    def get_test_action(self, network_output):
        """
        Returns the action to take for the ant given the network output.
        """
        pass

    def train(self, max_epochs, batch_size=32):
        self.net.train()
        ds_path = os.path.join(self.ds_path, 'train')
        with self.get_shelve_dataset(ds_path) as ds:
            dl = DataLoader(ds,
                            batch_size=batch_size,
                            shuffle=True,
                            collate_fn=netutils.collate_namespace)
            loader = GameSpaceToAntSpaceTransformer(dl, self.view_radius,
                                                    self.device)
            self.trainer.run(loader, max_epochs=max_epochs)

    def validate(self):
        self.net.eval()
        ds_path = os.path.join(self.ds_path, 'val')
        with self.get_shelve_dataset(ds_path) as ds, torch.no_grad():
            dl = DataLoader(ds,
                            batch_size=512,
                            shuffle=False,
                            collate_fn=netutils.collate_namespace)
            loader = GameSpaceToAntSpaceTransformer(dl, self.view_radius,
                                                    self.device)
            self.evaluator.run(loader, max_epochs=1)

    def test(self, map_file=None, opponent=None, opponent_plays_first=False):
        map_file = map_file or os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            '../../../../ants/maps/example/tutorial_p2_1.map'
            # '../../../../ants/maps/cell_maze/cell_maze_p02_04.map'
        )
        opponent = opponent or utils.enemybots.SampleBots.greedy_bot()
        if opponent_plays_first:
            agent_num, opponent_num = 1, 0
            player_names = [opponent.name, self.net_name]
        else:
            agent_num, opponent_num = 0, 1
            player_names = [self.net_name, opponent.name]

        opts = utils.antsgym.AntsEnvOptions()   \
            .set_map_file(map_file)
        reward_func = utils.reward.ScoreFunc()
        env = utils.antsgym.AntsEnv(opts, [], reward_func, player_names)

        opponent = utils.antsgym.SampleAgent(opponent, env, opponent_num)
        opponent.setup()

        self.net.eval()
        state = env.reset()
        while True:
            transformer = GameSpaceToAntSpaceTransformer(
                [SimpleNamespace(state=[state[[agent_num]]])],
                self.view_radius, self.device)
            batch, batch_locs = [], []
            for substate in transformer:
                batch.append(substate.state)
                batch_locs.append(substate.locs[0])
            batch = torch.cat(batch)

            with torch.no_grad():
                out = self.net(batch)
                sample = self.get_test_action(out)
            agent_acts = np.zeros((1, env.game.height, env.game.width))
            for i, (_, row, col) in enumerate(batch_locs):
                agent_acts[0, row, col] = sample[i]

            opponent.update_map(state[opponent_num])
            opponent_acts = opponent.get_moves()

            if agent_num == 0:
                actions = np.concatenate([agent_acts, opponent_acts], axis=0)
            else:
                actions = np.concatenate([opponent_acts, agent_acts], axis=0)
            state, reward, done, info = env.step(actions)
            if done:
                break
        return env

    def benchmark(self, map_files, num_trials):
        opponents = [
            lambda: utils.enemybots.CmdBot.xanthis_bot(),
            lambda: utils.enemybots.SampleBots.random_bot(),
            lambda: utils.enemybots.SampleBots.hunter_bot(),
            lambda: utils.enemybots.SampleBots.greedy_bot()
        ]
        result = {}
        for map_file in map_files:
            result[map_file] = {}
            for opponent in opponents:
                for trial in range(num_trials):
                    o = opponent()
                    env = self.test(map_file, o, False)
                    game_results = env.get_game_result()
                    if o.name not in result[map_file]:
                        result[map_file][o.name] = {
                            'wins': [0, 0],
                            'losses': [0, 0],
                            'turns': [0, 0]
                        }
                    if game_results['score'][0] > game_results['score'][1]:
                        result[map_file][o.name]['wins'][0] += 1
                    elif game_results['score'][0] < game_results['score'][1]:
                        result[map_file][o.name]['losses'][0] += 1
                    result[map_file][o.name]['turns'][0] += len(
                        game_results['replaydata']['scores'][0])
                for trial in range(num_trials):
                    o = opponent()
                    env = self.test(map_file, o, True)
                    game_results = env.get_game_result()
                    if game_results['score'][1] > game_results['score'][0]:
                        result[map_file][o.name]['wins'][1] += 1
                    elif game_results['score'][1] < game_results['score'][0]:
                        result[map_file][o.name]['losses'][1] += 1
                    result[map_file][o.name]['turns'][1] += len(
                        game_results['replaydata']['scores'][0])
                result[map_file][o.name]['turns'][0] /= num_trials
                result[map_file][o.name]['turns'][1] /= num_trials
        return result

    def restore_net(self, net_path, epoch_num):
        """
        Loads the network parameters from a file.

        Args:
            net_path (str): The path to the pickled network params.
            epoch_num (int): The epoch at which the net was saved (in order to continue from here when saving later).
        """
        self.net.load_state_dict(torch.load(net_path))
        self.saver._iteration = epoch_num
        return self

    def _add_metrics(self):
        train_loss = RunningAverage(Loss(self.get_loss))
        train_loss.attach(self.trainer, 'avg_train_loss')

        val_loss = Loss(self.get_loss)
        val_loss.attach(self.evaluator, 'val_loss')

    def _add_event_handlers(self):
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                       self._print_train_metrics)
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                       lambda e: self.validate())
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.saver,
                                       {'net': self.net})
        self.evaluator.add_event_handler(Events.EPOCH_COMPLETED,
                                         self._print_val_metrics)

    def _print_train_metrics(self, engine):
        if engine.state.iteration % 100 == 0:
            print(
                f'Epoch {engine.state.epoch}, iter {engine.state.iteration}. Loss: {engine.state.metrics["avg_train_loss"]}'
            )

    def _print_val_metrics(self, engine):
        print(f'Val loss: {engine.state.metrics["val_loss"]}.')
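
A concrete subclass has to fill in the abstract hooks; a schematic, hypothetical classifier-style implementation (input channels, batch attributes and the dataset helper are assumptions) might look like:

import torch.nn as nn
import torch.nn.functional as F

class SimpleAntTrainer(Trainer):
    # Hypothetical concrete Trainer: a small CNN scoring one of 5 moves per ant view
    def define_net(self):
        net = nn.Sequential(
            nn.Conv2d(7, 32, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(32, 5),
        )
        return net, 'simple_ant_net'

    def get_loss(self, output, target):
        return F.cross_entropy(output, target)

    def train_update_func(self, engine, batch):
        self.optim.zero_grad()
        out = self.net(batch.state)          # batch.state / batch.target are assumed fields
        loss = self.get_loss(out, batch.target)
        loss.backward()
        self.optim.step()
        return out, batch.target             # (y_pred, y) pair expected by the Loss metrics

    def val_update_func(self, engine, batch):
        out = self.net(batch.state)
        return out, batch.target

    def get_shelve_dataset(self, ds_path):
        return netutils.ShelveDataset(ds_path)  # assumed dataset helper

    def get_test_action(self, network_output):
        return network_output.argmax(dim=1).cpu().numpy()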
Example #12
def train(model,
          model_name,
          train_dataloader,
          test_dataloader,
          trainer_name='bb_detection'):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    def _prepare_batch(batch, device=None, non_blocking=False):
        """Prepare batch for training: pass to a device with options.
        """
        images, boxes = batch
        images = [image.to(device) for image in images]
        targets = [{
            'boxes': box.to(device),
            'labels': torch.ones((1), dtype=torch.int64).to(device)
        } for box in boxes]
        return images, targets

    writer = SummaryWriter(log_dir=path.join('logs', trainer_name, model_name))
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              factor=0.5,
                                                              patience=250)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = _prepare_batch(batch, device=device)

        loss_dict = model(x, y)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        losses.backward()
        optimizer.step()
        return loss_value

    trainer = Engine(_update)
    evaluator = create_supervised_evaluator(model,
                                            prepare_batch=_prepare_batch,
                                            metrics={'iou': IOUMetric()},
                                            device=device)

    if path.exists(f'{trainer_name}_{model_name}_checkpoint.pt'):
        checkpoint = torch.load(f'{trainer_name}_{model_name}_checkpoint.pt')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        trainer.load_state_dict(checkpoint['trainer'])

    def early_stop_score_function(engine):
        val_acc = engine.state.metrics['iou']
        return val_acc

    early_stop_handler = EarlyStopping(
        patience=20, score_function=early_stop_score_function, trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, early_stop_handler)

    checkpoint_handler = ModelCheckpoint(f'models/{trainer_name}/{model_name}',
                                         model_name,
                                         n_saved=20,
                                         create_dir=True)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=100),
                              checkpoint_handler, {
                                  'model': model,
                                  'optimizer': optimizer,
                                  'trainer': trainer
                              })

    @trainer.on(Events.ITERATION_COMPLETED(every=10))
    def log_training_loss(trainer):
        lr = optimizer.param_groups[0]['lr']
        print("Epoch[{}]: {} - Loss: {:.4f}, Lr: {}".format(
            trainer.state.epoch, trainer.state.iteration, trainer.state.output,
            lr))
        writer.add_scalar("training/loss", trainer.state.output,
                          trainer.state.iteration)

    @trainer.on(Events.ITERATION_COMPLETED(every=100))
    def log_training_results(trainer):
        evaluator.run(test_dataloader)
        metrics = evaluator.state.metrics
        print("Training Results - Epoch[{}]: {} - Avg IOU: {:.4f}".format(
            trainer.state.epoch, trainer.state.iteration, metrics['iou']))
        writer.add_scalar("training/avg_iou", metrics['iou'],
                          trainer.state.iteration)

        model.eval()
        test_data = iter(test_dataloader)
        x, y = _prepare_batch(next(test_data), device)
        y_pred = model(x)

        for image, output in zip(x, y_pred):
            writer.add_image_with_boxes("training/example_result", image,
                                        output['boxes'],
                                        trainer.state.iteration)
            break
        model.train()

    @trainer.on(Events.ITERATION_COMPLETED(every=10))
    def step_lr(trainer):
        lr_scheduler.step(trainer.state.output)

    @trainer.on(Events.ITERATION_COMPLETED(every=100))
    def read_lr_from_file(trainer):
        if path.exists('lr.txt'):
            with open('lr.txt', 'r', encoding='utf-8') as f:
                lr = float(f.read())
            for group in optimizer.param_groups:
                group['lr'] = lr

    trainer.run(train_dataloader, max_epochs=100)
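
# A hedged alternative to the manual resume block above: ignite's
# Checkpoint.load_objects restores several objects from a single checkpoint dict
# in one call. The helper below is a sketch; the checkpoint path is hypothetical.
import torch
from ignite.handlers import Checkpoint


def resume_if_available(checkpoint_path, model, optimizer, trainer):
    """Restore model/optimizer/trainer state from checkpoint_path if it exists."""
    try:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
    except FileNotFoundError:
        return False
    Checkpoint.load_objects(
        to_load={"model": model, "optimizer": optimizer, "trainer": trainer},
        checkpoint=checkpoint,
    )
    return True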
Beispiel #13
0
    def train(self, config, **kwargs):
        """Trains a model on the given configurations.
        :param config: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE
        :param **kwargs: parameters to overwrite yaml config
        """
        from pycocoevalcap.cider.cider import Cider

        conf = train_util.parse_config_or_kwargs(config, **kwargs)
        conf["seed"] = self.seed
        
        assert "distributed" in conf

        if conf["distributed"]:
            torch.distributed.init_process_group(backend="nccl")
            self.local_rank = torch.distributed.get_rank()
            self.world_size = torch.distributed.get_world_size()
            assert kwargs["local_rank"] == self.local_rank
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device("cuda", self.local_rank)
            # self.group = torch.distributed.new_group()

        if not conf["distributed"] or not self.local_rank:
            outputdir = str(
                Path(conf["outputpath"]) / 
                conf["model"] /
                # "{}_{}".format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%m"),
                               # uuid.uuid1().hex)
                conf["remark"] /
                "seed_{}".format(self.seed)
            )

            Path(outputdir).mkdir(parents=True, exist_ok=True)
            # # Early init because of creating dir
            # checkpoint_handler = ModelCheckpoint(
                # outputdir,
                # "run",
                # n_saved=1,
                # require_empty=False,
                # create_dir=False,
                # score_function=lambda engine: engine.state.metrics["score"],
                # score_name="score")

            logger = train_util.genlogger(str(Path(outputdir) / "train.log"))
            # print passed config parameters
            if "SLURM_JOB_ID" in os.environ:
                logger.info("Slurm job id: {}".format(os.environ["SLURM_JOB_ID"]))
            logger.info("Storing files in: {}".format(outputdir))
            train_util.pprint_dict(conf, logger.info)

        zh = conf["zh"]
        vocabulary = pickle.load(open(conf["vocab_file"], "rb"))
        dataloaders = self._get_dataloaders(conf, vocabulary)
        train_dataloader = dataloaders["train_dataloader"]
        val_dataloader = dataloaders["val_dataloader"]
        val_key2refs = dataloaders["val_key2refs"]
        data_dim = train_dataloader.dataset.data_dim
        conf["input_dim"] = data_dim
        if not conf["distributed"] or not self.local_rank:
            feature_data = conf["h5_csv"] if "h5_csv" in conf else conf["train_h5_csv"]
            logger.info(
                "Feature: {} Input dimension: {} Vocab Size: {}".format(
                    feature_data, data_dim, len(vocabulary)))

        model = self._get_model(conf, len(vocabulary))
        model = model.to(self.device)
        if conf["distributed"]:
            model = torch.nn.parallel.distributed.DistributedDataParallel(
                model, device_ids=[self.local_rank,], output_device=self.local_rank,
                find_unused_parameters=True)
        optimizer = getattr(
            torch.optim, conf["optimizer"]
        )(model.parameters(), **conf["optimizer_args"])

        if not conf["distributed"] or not self.local_rank:
            train_util.pprint_dict(model, logger.info, formatter="pretty")
            train_util.pprint_dict(optimizer, logger.info, formatter="pretty")

        if conf["label_smoothing"]:
            criterion = train_util.LabelSmoothingLoss(len(vocabulary), smoothing=conf["smoothing"])
        else:
            criterion = torch.nn.CrossEntropyLoss().to(self.device)
        crtrn_imprvd = train_util.criterion_improver(conf['improvecriterion'])

        def _train_batch(engine, batch):
            if conf["distributed"]:
                train_dataloader.sampler.set_epoch(engine.state.epoch)
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(
                    model, batch, "train",
                    ss_ratio=conf["ss_args"]["ss_ratio"]
                )
                loss = criterion(output["packed_logits"], output["targets"]).to(self.device)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), conf["max_grad_norm"])
                optimizer.step()
                output["loss"] = loss.item()
                return output

        trainer = Engine(_train_batch)
        RunningAverage(output_transform=lambda x: x["loss"]).attach(trainer, "running_loss")
        pbar = ProgressBar(persist=False, ascii=True, ncols=100)
        pbar.attach(trainer, ["running_loss"])

        key2pred = {}

        def _inference(engine, batch):
            model.eval()
            keys = batch[0]
            with torch.no_grad():
                output = self._forward(model, batch, "validation")
                seqs = output["seqs"].cpu().numpy()
                for (idx, seq) in enumerate(seqs):
                    candidate = self._convert_idx2sentence(seq, vocabulary, zh)
                    key2pred[keys[idx]] = [candidate,]
                return output

        metrics = {
            "loss": Loss(criterion, output_transform=lambda x: (x["packed_logits"], x["targets"])),
            "accuracy": Accuracy(output_transform=lambda x: (x["packed_logits"], x["targets"])),
        }
        for name, metric in metrics.items():
            metric.attach(trainer, name)

        evaluator = Engine(_inference)

        def eval_val(engine, key2pred, key2refs):
            scorer = Cider(zh=zh)
            score_output = self._eval_prediction(key2refs, key2pred, [scorer])
            engine.state.metrics["score"] = score_output["CIDEr"]
            key2pred.clear()

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, eval_val, key2pred, val_key2refs)

        pbar.attach(evaluator)

        # Learning rate scheduler
        if "scheduler" in conf:
            try:
                scheduler = getattr(torch.optim.lr_scheduler, conf["scheduler"])(
                    optimizer, **conf["scheduler_args"])
            except AttributeError:
                import utils.lr_scheduler
                if conf["scheduler"] == "ExponentialDecayScheduler":
                    conf["scheduler_args"]["total_iters"] = len(train_dataloader) * conf["epochs"]
                scheduler = getattr(utils.lr_scheduler, conf["scheduler"])(
                    optimizer, **conf["scheduler_args"])
            if scheduler.__class__.__name__ in ["StepLR", "ReduceLROnPlateau", "ExponentialLR", "MultiStepLR"]:
                evaluator.add_event_handler(
                    Events.EPOCH_COMPLETED, train_util.update_lr,
                    scheduler, "score")
            else:
                trainer.add_event_handler(
                    Events.ITERATION_COMPLETED, train_util.update_lr, scheduler, None)
        
        # Scheduled sampling
        if conf["ss"]:
            trainer.add_event_handler(
                Events.GET_BATCH_COMPLETED, train_util.update_ss_ratio, conf, len(train_dataloader))

        #########################
        # Events for main process: mostly logging and saving
        #########################
        if not conf["distributed"] or not self.local_rank:
            # logging training and validation loss and metrics
            trainer.add_event_handler(
                Events.EPOCH_COMPLETED, train_util.log_results, optimizer, evaluator, val_dataloader,
                logger.info, metrics.keys(), ["score"])
            # saving best model
            evaluator.add_event_handler(
                Events.EPOCH_COMPLETED, train_util.save_model_on_improved, crtrn_imprvd,
                "score", {
                    "model": model.state_dict() if not conf["distributed"] else model.module.state_dict(),
                    # "config": conf,
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": scheduler.state_dict()
                }, str(Path(outputdir) / "saved.pth")
            )
            # regular checkpoint
            checkpoint_handler = ModelCheckpoint(
                outputdir,
                "run",
                n_saved=1,
                require_empty=False,
                create_dir=False,
                score_function=lambda engine: engine.state.metrics["score"],
                score_name="score")
            evaluator.add_event_handler(
                Events.EPOCH_COMPLETED, checkpoint_handler, {
                    "model": model,
                }
            )
            # dump configuration
            train_util.store_yaml(conf, str(Path(outputdir) / "config.yaml"))

        #########################
        # Start training
        #########################
        trainer.run(train_dataloader, max_epochs=conf["epochs"])
        if not conf["distributed"] or not self.local_rank:
            return outputdir
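
# train_util.criterion_improver is not shown in this snippet; a plausible minimal
# sketch (an assumption, not the actual implementation) is a closure that remembers
# the best value seen so far and reports whether a new value improves on it.
def criterion_improver_sketch(mode="max"):
    best = {"value": None}

    def improved(value):
        current = best["value"]
        better = (current is None
                  or (value > current if mode == "max" else value < current))
        if better:
            best["value"] = value
        return better

    return improved


# usage sketch:
#   crtrn_imprvd = criterion_improver_sketch("max")
#   crtrn_imprvd(0.31)  # True, the first value always improves
#   crtrn_imprvd(0.28)  # False, 0.28 < 0.31 under mode="max"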
Beispiel #14
0
def run() -> None:
    transform = Compose(
        [
            Resize(gc.network.input_size),
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
        # [Resize(gc.network.input_size), ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
    )
    cudnn.benchmark = True

    gc.check_dataset_path(is_show=True)
    dataset = tutils.CreateDataset(gc)  # train, unknown, known

    train_loader, unknown_loader, known_loader = dataset.get_dataloader(
        transform)

    if gc.option.is_save_config:
        dataset.write_config()  # write config of model
    del dataset.all_list

    classes = list(dataset.classes.values())
    device = gc.network.device

    print(f"Building network by '{gc.network.net_.__name__}'...")
    net = gc.network.net_(input_size=gc.network.input_size,
                          classify_size=len(classes)).to(device)
    if issubclass(gc.network.optim_, optim.SGD):  # optim_ holds the optimizer class, so compare classes
        optimizer = gc.network.optim_(net.parameters(),
                                      lr=gc.network.lr,
                                      momentum=gc.network.momentum)
    else:
        optimizer = gc.network.optim_(net.parameters(), lr=gc.network.lr)

    model = tutils.Model(
        net,
        optimizer=optimizer,
        criterion=nn.CrossEntropyLoss(),
        device=device,
        scaler=amp.GradScaler(enabled=gc.network.amp),
    )
    del net, optimizer

    # logfile
    if gc.option.is_save_log:
        gc.logfile = utils.LogFile(gc.path.log.joinpath("log.txt"),
                                   stdout=False)
        gc.ratefile = utils.LogFile(gc.path.log.joinpath("rate.csv"),
                                    stdout=False)

        classes_ = ",".join(classes)
        gc.ratefile.writeline(
            f"epoch,known,{classes_},avg,,unknown,{classes_},avg")
        gc.ratefile.flush()

    # network definition
    impl.show_network_difinition(gc,
                                 model,
                                 dataset,
                                 stdout=gc.option.is_show_network_difinition)

    # grad cam
    gcam_schedule = (utils.create_schedule(gc.network.epoch, gc.gradcam.cycle)
                     if gc.gradcam.enabled else [False] * gc.network.epoch)
    gcam = ExecuteGradCAM(
        classes,
        input_size=gc.network.input_size,
        target_layer=gc.gradcam.layer,
        device=device,
        schedule=gcam_schedule,
        is_gradcam=gc.gradcam.enabled,
    )

    # mkdir for gradcam
    phases = ["known", "unknown"]
    mkdir_options = {"parents": True, "exist_ok": True}
    for i, flag in enumerate(gcam_schedule):
        if not flag:
            continue
        ep_str = f"epoch{i+1}"
        for phase, cls in product(phases, classes):
            gc.path.gradcam.joinpath(f"{phase}_mistaken", ep_str,
                                     cls).mkdir(**mkdir_options)
        if not gc.gradcam.only_mistaken:
            for phase, cls in product(phases, classes):
                gc.path.gradcam.joinpath(f"{phase}_correct", ep_str,
                                         cls).mkdir(**mkdir_options)

    # progress bar
    pbar = utils.MyProgressBar(persist=True,
                               logfile=gc.logfile,
                               disable=gc.option.is_show_batch_result)

    # no-op fallback functions selected by config flags
    exec_gcam_fn = fns.execute_gradcam if gc.gradcam.enabled else fns.dummy_execute_gradcam
    save_img_fn = (fns.save_mistaken_image if gc.option.is_save_mistaken_pred
                   else fns.dummy_save_mistaken_image)
    exec_softmax_fn = (fns.execute_softmax if gc.option.is_save_softmax else
                       fns.dummy_execute_softmax)

    def train_step(engine: Engine, batch: T._batch_path) -> float:
        return impl.train_step(
            engine,
            tutils.MiniBatch(batch),
            model,
            gc.network.subdivisions,
            gc,
            pbar,
            use_amp=gc.network.amp,
            save_img_fn=save_img_fn,
            non_blocking=True,
        )

    # trainer
    trainer = Engine(train_step)
    pbar.attach(trainer, metric_names="all")

    # tensorboard logger
    tb_logger = (SummaryWriter(
        log_dir=str(Path(gc.path.tb_log_dir, gc.filename_base)))
                 if gc.option.log_tensorboard else None)

    # schedule
    valid_schedule = utils.create_schedule(gc.network.epoch,
                                           gc.network.valid_cycle)
    save_schedule = utils.create_schedule(gc.network.epoch,
                                          gc.network.save_cycle)
    save_schedule[-1] = gc.network.is_save_final_model  # depends on config.

    save_cm_fn = fns.save_cm_image if gc.option.is_save_cm else fns.dummy_save_cm_image

    def validate_model(engine: Engine, collect_list: List[Tuple[DataLoader,
                                                                str]]) -> None:
        epoch = engine.state.epoch
        # do not validate.
        if not (valid_schedule[epoch - 1] or gcam_schedule[epoch - 1]):
            return

        impl.validate_model(
            engine,
            collect_list,
            gc,
            pbar,
            classes,
            gcam=gcam,
            model=model,
            valid_schedule=valid_schedule,
            tb_logger=tb_logger,
            exec_gcam_fn=exec_gcam_fn,
            exec_softmax_fn=exec_softmax_fn,
            save_cm_fn=save_cm_fn,
            non_blocking=True,
        )

    def save_model(engine: Engine) -> None:
        epoch = engine.state.epoch
        if save_schedule[epoch - 1]:
            impl.save_model(model, classes, gc, epoch)

    # validate / save
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        validate_model,
        [(known_loader, State.KNOWN), (unknown_loader, State.UNKNOWN)],  # args
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model)

    # kick everything off
    trainer.run(train_loader, max_epochs=gc.network.epoch)

    # close file
    if gc.option.log_tensorboard:
        tb_logger.close()
    if gc.option.is_save_log:
        gc.logfile.close()
        gc.ratefile.close()
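
# utils.create_schedule is not defined in this snippet; judging from how it is used
# above (one boolean per epoch, truthy on every `cycle`-th epoch), a plausible
# sketch of such a helper is:
def create_schedule_sketch(epochs: int, cycle: int) -> list:
    """Return one bool per epoch, True on every `cycle`-th epoch (1-based)."""
    return [(epoch + 1) % cycle == 0 for epoch in range(epochs)]


# e.g. create_schedule_sketch(6, 3) -> [False, False, True, False, False, True]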
Beispiel #15
0
def main(args):
    result_dir_path = Path(args.result_dir)
    result_dir_path.mkdir(parents=True, exist_ok=True)

    with Path(args.setting).open("r") as f:
        setting = json.load(f)
    pprint.pprint(setting)

    if args.g >= 0 and torch.cuda.is_available():
        device = torch.device(f"cuda:{args.g:d}")
        print(f"GPU mode: {args.g:d}")
    else:
        device = torch.device("cpu")
        print("CPU mode")

    mnist_neg = get_mnist_num(set(setting["label"]["neg"]))
    neg_loader = DataLoader(mnist_neg,
                            batch_size=setting["iterator"]["batch_size"])

    generator = get_generator().to(device)
    discriminator = get_discriminator().to(device)
    opt_g = torch.optim.Adam(
        generator.parameters(),
        lr=setting["optimizer"]["alpha"],
        betas=(setting["optimizer"]["beta1"], setting["optimizer"]["beta2"]),
        weight_decay=setting["regularization"]["weight_decay"])
    opt_d = torch.optim.Adam(
        discriminator.parameters(),
        lr=setting["optimizer"]["alpha"],
        betas=(setting["optimizer"]["beta1"], setting["optimizer"]["beta2"]),
        weight_decay=setting["regularization"]["weight_decay"])

    trainer = Engine(
        GANTrainer(generator,
                   discriminator,
                   opt_g,
                   opt_d,
                   device=device,
                   **setting["updater"]))

    # test data for evaluation
    test_neg = get_mnist_num(set(setting["label"]["neg"]), train=False)
    test_neg_loader = DataLoader(test_neg, setting["iterator"]["batch_size"])
    test_pos = get_mnist_num(set(setting["label"]["pos"]), train=False)
    test_pos_loader = DataLoader(test_pos, setting["iterator"]["batch_size"])
    detector = Detector(generator, discriminator,
                        setting["updater"]["noise_std"], device).to(device)

    log_dict = {}
    evaluator = evaluate_accuracy(log_dict, detector, test_neg_loader,
                                  test_pos_loader, device)
    plotter = plot_metrics(log_dict, ["accuracy", "precision", "recall", "f"],
                           "iteration", result_dir_path / "metrics.pdf")
    printer = print_logs(log_dict,
                         ["iteration", "accuracy", "precision", "recall", "f"])
    img_saver = save_img(generator, test_pos, test_neg,
                         result_dir_path / "images",
                         setting["updater"]["noise_std"], device)

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000),
                              evaluator)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), plotter)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), printer)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000),
                              img_saver)

    # terminate at the specified iteration
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(once=setting["iteration"]),
        lambda engine: engine.terminate())
    trainer.run(neg_loader, max_epochs=10**10)
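
# Self-contained sketch of the fixed-iteration stopping idiom used above: run with
# an effectively unbounded max_epochs and terminate from a once-filtered handler.
from ignite.engine import Engine, Events

toy = Engine(lambda engine, batch: None)
toy.add_event_handler(Events.ITERATION_COMPLETED(once=42),
                      lambda engine: engine.terminate())
toy.run(range(10), max_epochs=10**10)
print(toy.state.iteration)  # 42: the run stops at the requested iteration
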
class Trainer:
    """ Abstract Trainer class.

    Helper class to support an ignite training process. Call run to start training. The main tasks are:
        - init visdom
        - set seed
        - log model architecture and parameters to file or console
        - limit train / valid samples in debug mode
        - split train data into train and validation
        - load model if required
        - init train, validate and test ignite engines
        - set main metrics (both iter and epoch): loss and acc
        - add default events: model saving (each epoch), early stopping, log training progress
        - call the validate engine after each training epoch, which runs one epoch

    When extending this class, implement the following functions:
        - _train_function: executes a training step. It takes the ignite engine, this class and the current batch as
        arguments. Should return a dict with keys:
            - 'loss': metric of this class
            - 'acc': metric of this class
            - any key that is expected by the custom events of the child class
        - _valid_function: same as _train_function, but for the validation step
        - _test_function: same as _train_function, but for the test step

    Optionally extend:
        - _add_custom_events: function in which additional events can be added to the training process

    Args:
        model (_Net): model/network to be trained.
        loss (_Loss): loss of the model
        optimizer (Optimizer): optimizer used in gradient update
        data_train (Dataset): training data (torch Dataset); split internally into train and validation
        data_test (Dataset): test data (torch Dataset)
        conf (Namespace): configuration obtained using configurations.general_confs.get_conf
    """
    def __init__(self, model, loss, optimizer, data_train, data_test, conf):

        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.conf = conf
        self.device = get_device()

        self._log = get_logger(__name__)

        self.vis = self._init_visdom()

        # print number of parameters in model
        num_parameters = np.sum(
            [np.prod(list(p.shape)) for p in model.parameters()])
        self._log.info("Number of parameters model: {}".format(num_parameters))
        self._log.info("Model architecture: \n" + str(model))

        # init data sets
        cuda_kwargs = ({"pin_memory": True, "num_workers": 0}
                       if self.device == "cuda" else {})
        kwargs = {**cuda_kwargs}
        if conf.debug:
            kwargs["train_max"] = 4
            kwargs["valid_max"] = 4
            kwargs["num_workers"] = 1
        if conf.seed:
            kwargs["seed"] = conf.seed
        self.train_loader, self.val_loader = get_train_valid_data(
            data_train,
            valid_size=conf.valid_size,
            batch_size=conf.batch_size,
            drop_last=conf.drop_last,
            shuffle=conf.shuffle,
            **kwargs)

        test_debug_sampler = SequentialSampler(list(range(
            3 * conf.batch_size))) if conf.debug else None
        self.test_loader = torch.utils.data.DataLoader(
            data_test,
            batch_size=conf.batch_size,
            drop_last=conf.drop_last,
            sampler=test_debug_sampler,
            **cuda_kwargs)

        # model to cuda if device is gpu
        model.to(self.device)

        # optimize cuda
        torch.backends.cudnn.benchmark = conf.cudnn_benchmark

        # load model
        if conf.load_model:
            if os.path.isfile(conf.model_load_path):
                if torch.cuda.is_available():
                    model = torch.load(conf.model_load_path)
                else:
                    model = torch.load(
                        conf.model_load_path,
                        map_location=lambda storage, loc: storage)
                self._log.info(f"Succesfully loaded {conf.load_name}")
            else:
                raise FileNotFoundError(
                    f"Could not found {conf.model_load_path}. Fix path or set load_model to False."
                )

        # init an ignite engine for each data set
        self.train_engine = Engine(self._train_function)
        self.valid_engine = Engine(self._valid_function)
        self.test_engine = Engine(self._test_function)

        # add train metrics
        ValueIterMetric(lambda x: x["loss"]).attach(
            self.train_engine, "batch_loss")  # for plot and progress log
        ValueIterMetric(lambda x: x["acc"]).attach(
            self.train_engine, "batch_acc")  # for plot and progress log

        # add visdom plot for the training loss
        training_loss_plot = VisIterPlotter(self.vis, "batch_loss", "Loss",
                                            "Training Batch Loss",
                                            self.conf.model_name)
        self.train_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                            training_loss_plot)

        # add visdom plot for the training accuracy
        training_acc_plot = VisIterPlotter(self.vis, "batch_acc", "Acc",
                                           "Training Batch Acc",
                                           self.conf.model_name)
        self.train_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                            training_acc_plot)

        # add logs handlers, requires the batch_loss and batch_acc metrics
        self.train_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                            LogTrainProgressHandler())

        # add metrics
        ValueEpochMetric(lambda x: x["acc"]).attach(
            self.valid_engine, "acc")  # for plot and logging
        ValueEpochMetric(lambda x: x["loss"]).attach(
            self.valid_engine, "loss")  # for plot, logging and early stopping
        ValueEpochMetric(lambda x: x["acc"]).attach(self.test_engine,
                                                    "acc")  # for plot

        # add validation acc logger
        self.valid_engine.add_event_handler(
            Events.EPOCH_COMPLETED,
            LogEpochMetricHandler('Validation set: {:.4f}', "acc"))

        # print end of testing
        self.test_engine.add_event_handler(
            Events.EPOCH_COMPLETED, lambda _: self._log.info("Done testing"))

        # saves models
        if conf.save_trained:
            save_path = f"{conf.exp_path}/{conf.trained_model_path}"
            save_handler = ModelCheckpoint(
                save_path,
                conf.model_name,
                score_function=lambda engine: engine.state.metrics["acc"],
                n_saved=conf.n_saved,
                require_empty=False)
            self.valid_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                                save_handler, {'': model})

        # valid acc visdom plot
        acc_valid_plot = VisEpochPlotter(vis=self.vis,
                                         metric="acc",
                                         y_label="acc",
                                         title="Valid Accuracy",
                                         env_name=self.conf.model_name)
        self.valid_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                            acc_valid_plot)

        # test acc visdom plot
        acc_test_plot = VisEpochPlotter(vis=self.vis,
                                        metric="acc",
                                        y_label="acc",
                                        title="Test Accuracy",
                                        env_name=self.conf.model_name)
        self.test_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           acc_test_plot)

        # print ms per training example
        if self.conf.print_time:
            TimeMetric(lambda x: x["time"]).attach(self.train_engine, "time")
            self.train_engine.add_event_handler(
                Events.EPOCH_COMPLETED,
                LogEpochMetricHandler('Time per example: {:.6f} ms', "time"))

        # save test acc of the best validation epoch to file
        if self.conf.save_best:

            # Add score handler for the default inference: on valid and test the same sparsity as during training
            best_score_handler = SaveBestScore(
                score_valid_func=lambda engine: engine.state.metrics["acc"],
                score_test_func=lambda engine: engine.state.metrics["acc"],
                start_epoch=model.epoch,
                max_train_epochs=self.conf.epochs,
                model_name=self.conf.model_name,
                score_file_name=self.conf.score_file_name,
                root_path=self.conf.exp_path)
            self.valid_engine.add_event_handler(
                Events.EPOCH_COMPLETED, best_score_handler.update_valid)
            self.test_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                               best_score_handler.update_test)

        # add events custom events of the child class
        self._add_custom_events()

        # add early stopping, use total loss over epoch, stop if no improvement: higher score = better
        if conf.early_stop:
            early_stop_handler = EarlyStopping(
                patience=1,
                score_function=lambda engine: -engine.state.metrics["loss"],
                trainer=self.train_engine)
            self.valid_engine.add_event_handler(Events.COMPLETED,
                                                early_stop_handler)

        # set epoch in state of train_engine to model epoch at start to resume training for loaded model.
        # Note: new models have epoch = 0.
        @self.train_engine.on(Events.STARTED)
        def update_epoch(engine):
            engine.state.epoch = model.epoch

        # update the model's epoch so that it is correct when training is resumed
        @self.train_engine.on(Events.EPOCH_COMPLETED)
        def update_model_epoch(_):
            model.epoch += 1

        # make sure the valid_engine runs after each train epoch; this must come after all custom
        # train_engine EPOCH_COMPLETED events
        @self.train_engine.on(Events.EPOCH_COMPLETED)
        def call_valid(_):
            self.valid_engine.run(self.val_loader,
                                  self.train_engine.state.epoch)

        @self.train_engine.on(Events.ITERATION_COMPLETED)
        def check_nan(_):
            assert not any(torch.isnan(p).any() for p in model.parameters()), \
                "Parameters contain NaNs. Occurred in this iteration."

        # make sure the test_engine runs after each validation epoch; this must come after all custom
        # valid_engine EPOCH_COMPLETED events
        @self.valid_engine.on(Events.EPOCH_COMPLETED)
        def call_test(_):
            self.test_engine.run(self.test_loader,
                                 self.train_engine.state.epoch)

        # make sure the epoch reported by valid_engine and test_engine matches the train_engine epoch
        # of the current run, while each engine still runs only once
        @self.valid_engine.on(Events.STARTED)
        @self.test_engine.on(Events.STARTED)
        def set_train_epoch(engine):
            engine.state.epoch = self.train_engine.state.epoch - 1

        # Save the visdom environment
        @self.test_engine.on(Events.EPOCH_COMPLETED)
        def save_visdom_env(_):
            if isinstance(self.vis, visdom.Visdom):
                self.vis.save([self.conf.model_name])

    def _init_visdom(self):

        if self.conf.use_visdom:

            # start visdom if in conf
            if self.conf.start_visdom:

                # create the visdom environment path if it does not exist
                if not os.path.exists(self.conf.exp_path):
                    os.makedirs(self.conf.exp_path)

                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    port = 8097
                    while s.connect_ex(('localhost', port)) == 0:
                        port += 1
                        if port == 8999:
                            break

                proc = Popen([
                    f"{sys.executable}", "-m", "visdom.server", "-env_path",
                    self.conf.exp_path, "-port",
                    str(port), "-logging_level", "50"
                ])
                time.sleep(1)

                vis = visdom.Visdom()

                retries = 0
                while (not vis.check_connection()) and retries < 10:
                    retries += 1
                    time.sleep(1)

                if not vis.check_connection():
                    raise RuntimeError("Could not start Visdom")

            # if use existing connection
            else:
                vis = visdom.Visdom()

                if vis.check_connection():
                    self._log.info("Use existing Visdom connection")

                # if no connection and not start
                else:
                    raise RuntimeError(
                        "Start visdom manually or set start_visdom to True")
        else:
            vis = None

        return vis

    def run(self):
        """ Start the training process. """
        self.train_engine.run(self.train_loader, max_epochs=self.conf.epochs)

    def _add_custom_events(self):
        pass

    def _train_function(self, engine, batch):
        raise NotImplementedError(
            "Please implement abstract function _train_function.")

    def _valid_function(self, engine, batch):
        raise NotImplementedError(
            "Please implement abstract function _valid_function.")

    def _test_function(self, engine, batch):
        raise NotImplementedError(
            "Please implement abstract function _test_function.")
Beispiel #17
0
def main(args):
    if args.g >= 0 and torch.cuda.is_available():
        device = torch.device(f"cuda:{args.g:d}")
        print(f"GPU mode: {args.g:d}")
    else:
        device = torch.device("cpu")
        print("CPU mode")

    result_dir = Path(args.result_dir)

    # load the MNIST training data
    mnist_train = MNIST(root=".",
                        download=True,
                        train=True,
                        transform=lambda x: np.expand_dims(
                            np.asarray(x, dtype=np.float32), 0) / 255)
    mnist_loader = DataLoader(mnist_train, args.batchsize)
    mnist_loader = InfiniteDataLoader(mnist_loader)

    generator = get_generator(Z_DIM).to(device)
    critic = get_critic().to(device)

    opt_g = Adam(generator.parameters(), args.alpha, (args.beta1, args.beta2))
    opt_c = Adam(critic.parameters(), args.alpha, (args.beta1, args.beta2))

    trainer = Engine(
        WGANTrainer(mnist_loader, generator, critic, opt_g, opt_c, args.n_cri,
                    args.gp_lam, device))

    log_dict = {}
    accumulator = MetricsAccumulator(["generator_loss", "critic_loss"])
    trainer.add_event_handler(Events.ITERATION_COMPLETED, accumulator)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=500),
                              record_metrics(log_dict, accumulator))
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=500),
                              print_metrics(log_dict, accumulator.keys))
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=500),
        plot_metrics(log_dict, "iteration", accumulator.keys,
                     result_dir / "metrics.pdf"))
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=500),
        save_img(generator, result_dir / "generated_samples", device))

    # terminate after the specified number of iterations
    trainer.add_event_handler(Events.ITERATION_COMPLETED(once=args.iteration),
                              lambda engine: engine.terminate())

    trainer.run(mnist_loader, max_epochs=10**10)
Beispiel #18
0
    def train(self, config, **kwargs):
        """Trains a model on the given configurations.
        :param config: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE
        :param **kwargs: parameters to overwrite yaml config
        """
        from pycocoevalcap.cider.cider import Cider

        config_parameters = train_util.parse_config_or_kwargs(config, **kwargs)
        config_parameters["seed"] = self.seed
        outputdir = os.path.join(
            config_parameters["outputpath"], config_parameters["model"],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))

        # Early init because of creating dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            "run",
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=lambda engine: engine.state.metrics["score"],
            score_name="score")

        logger = train_util.genlogger(os.path.join(outputdir, "train.log"))
        # print passed config parameters
        logger.info("Storing files in: {}".format(outputdir))
        train_util.pprint_dict(config_parameters, logger.info)

        zh = config_parameters["zh"]
        vocabulary = torch.load(config_parameters["vocab_file"])
        train_loader, cv_loader, info = self._get_dataloaders(
            config_parameters, vocabulary)
        config_parameters["inputdim"] = info["inputdim"]
        cv_key2refs = info["cv_key2refs"]
        logger.info("<== Estimating Scaler ({}) ==>".format(
            info["scaler"].__class__.__name__))
        logger.info("Feature: {} Input dimension: {} Vocab Size: {}".format(
            config_parameters["feature_file"], info["inputdim"],
            len(vocabulary)))

        model = self._get_model(config_parameters, len(vocabulary))
        if "pretrained_word_embedding" in config_parameters:
            embeddings = np.load(
                config_parameters["pretrained_word_embedding"])
            model.load_word_embeddings(
                embeddings,
                tune=config_parameters["tune_word_embedding"],
                projection=True)
        model = model.to(self.device)
        train_util.pprint_dict(model, logger.info, formatter="pretty")
        optimizer = getattr(torch.optim, config_parameters["optimizer"])(
            model.parameters(), **config_parameters["optimizer_args"])
        train_util.pprint_dict(optimizer, logger.info, formatter="pretty")

        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        crtrn_imprvd = train_util.criterion_improver(
            config_parameters['improvecriterion'])

        def _train_batch(engine, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(model, batch, "train")
                loss = criterion(output["packed_logits"],
                                 output["targets"]).to(self.device)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
                output["loss"] = loss.item()
                return output

        trainer = Engine(_train_batch)
        RunningAverage(output_transform=lambda x: x["loss"]).attach(
            trainer, "running_loss")
        pbar = ProgressBar(persist=False, ascii=True, ncols=100)
        pbar.attach(trainer, ["running_loss"])

        key2pred = {}

        def _inference(engine, batch):
            model.eval()
            keys = batch[2]
            with torch.no_grad():
                output = self._forward(model, batch, "validation")
                seqs = output["seqs"].cpu().numpy()
                for (idx, seq) in enumerate(seqs):
                    if keys[idx] in key2pred:
                        continue
                    candidate = self._convert_idx2sentence(seq, vocabulary, zh)
                    key2pred[keys[idx]] = [
                        candidate,
                    ]
                return output

        metrics = {
            "loss":
            Loss(criterion,
                 output_transform=lambda x: (x["packed_logits"], x["targets"]))
        }

        evaluator = Engine(_inference)

        def eval_cv(engine, key2pred, key2refs):
            scorer = Cider(zh=zh)
            score, scores = scorer.compute_score(key2refs, key2pred)
            engine.state.metrics["score"] = score
            key2pred.clear()

        evaluator.add_event_handler(Events.EPOCH_COMPLETED, eval_cv, key2pred,
                                    cv_key2refs)

        for name, metric in metrics.items():
            metric.attach(evaluator, name)

        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  train_util.log_results, evaluator, cv_loader,
                                  logger.info, ["loss", "score"])

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, train_util.save_model_on_improved,
            crtrn_imprvd, "score", {
                "model": model.state_dict(),
                "config": config_parameters,
                "scaler": info["scaler"]
            }, os.path.join(outputdir, "saved.pth"))

        scheduler = getattr(torch.optim.lr_scheduler,
                            config_parameters["scheduler"])(
                                optimizer,
                                **config_parameters["scheduler_args"])
        evaluator.add_event_handler(Events.EPOCH_COMPLETED,
                                    train_util.update_lr, scheduler, "score")

        evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                    {
                                        "model": model,
                                    })

        trainer.run(train_loader, max_epochs=config_parameters["epochs"])
        return outputdir
def get_handlers(
    config: Any,
    model: Module,
    trainer: Engine,
    evaluator: Engine,
    metric_name: str,
    es_metric_name: str,
    train_sampler: Optional[DistributedSampler] = None,
    to_save: Optional[Mapping] = None,
    lr_scheduler: Optional[LRScheduler] = None,
    output_names: Optional[Iterable[str]] = None,
    **kwargs: Any,
) -> Union[Tuple[Checkpoint, EarlyStopping, Timer], Tuple[None, None, None]]:
    """Get best model, earlystopping, timer handlers.

    Parameters
    ----------
    config
        Config object for setting up handlers

    `config` has to contain
    - `output_dir`: output path to indicate where to_save objects are stored
    - `save_every_iters`: saving iteration interval
    - `n_saved`: number of best models to store
    - `log_every_iters`: logging interval for iteration progress bar and `GpuInfo` if true
    - `with_pbars`: show two progress bars
    - `with_pbar_on_iters`: show iteration-wise progress bar
    - `stop_on_nan`: Stop the training if engine output contains NaN/inf values
    - `clear_cuda_cache`: clear cuda cache every end of epoch
    - `with_gpu_stats`: show GPU information: used memory percentage, gpu utilization percentage values
    - `patience`: number of events to wait if no improvement and then stop the training
    - `limit_sec`: maximum time before training terminates in seconds

    model
        best model to save
    trainer
        the engine used for training
    evaluator
        the engine used for evaluation
    metric_name
        evaluation metric to save the best model
    es_metric_name
        evaluation metric to early stop the model
    train_sampler
        distributed training sampler to call `set_epoch`
    to_save
        objects to save during training
    lr_scheduler
        learning rate scheduler as native torch LRScheduler or ignite's parameter scheduler
    output_names
        list of names associated with `trainer`'s process_function output dictionary
    kwargs
        keyword arguments passed to Checkpoint handler

    Returns
    -------
    best_model_handler, es_handler, timer_handler
    """

    best_model_handler, es_handler, timer_handler = None, None, None

    # https://pytorch.org/ignite/contrib/engines.html#ignite.contrib.engines.common.setup_common_training_handlers
    # kwargs can be passed to save the model based on training stats
    # like score_name, score_function
    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        output_path=config.output_dir / 'checkpoints',
        save_every_iters=config.save_every_iters,
        n_saved=config.n_saved,
        log_every_iters=config.log_every_iters,
        with_pbars=config.with_pbars,
        with_pbar_on_iters=config.with_pbar_on_iters,
        stop_on_nan=config.stop_on_nan,
        clear_cuda_cache=config.clear_cuda_cache,
        with_gpu_stats=config.with_gpu_stats,
        **kwargs,
    )
    {% if save_best_model_by_val_score %}

    # https://pytorch.org/ignite/contrib/engines.html#ignite.contrib.engines.common.save_best_model_by_val_score
    best_model_handler = common.save_best_model_by_val_score(
        output_path=config.output_dir / 'checkpoints',
        evaluator=evaluator,
        model=model,
        metric_name=metric_name,
        n_saved=config.n_saved,
        trainer=trainer,
        tag='eval',
    )
    {% endif %}
    {% if add_early_stopping_by_val_score %}

    # https://pytorch.org/ignite/contrib/engines.html#ignite.contrib.engines.common.add_early_stopping_by_val_score
    es_handler = common.add_early_stopping_by_val_score(
        patience=config.patience,
        evaluator=evaluator,
        trainer=trainer,
        metric_name=es_metric_name,
    )
    {% endif %}
    {% if setup_timer %}

    # https://pytorch.org/ignite/handlers.html#ignite.handlers.Timer
    # measure the average time to process a single batch of samples
    # Events for that are - ITERATION_STARTED and ITERATION_COMPLETED
    # you can replace with the events you want to measure
    timer_handler = Timer(average=True)
    timer_handler.attach(
        engine=trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )
    {% endif %}
    {% if setup_timelimit %}

    # training will terminate if training time exceed `limit_sec`.
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED, TimeLimit(limit_sec=config.limit_sec)
    )
    {% endif %}
    return best_model_handler, es_handler, timer_handler
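
# Hedged usage sketch for get_handlers: the field names on `config` follow the
# docstring above; `model`, `trainer` and `evaluator` are assumed to be built
# elsewhere, so this is an illustration rather than runnable end-to-end code.
from pathlib import Path
from types import SimpleNamespace

config = SimpleNamespace(
    output_dir=Path("./runs/exp1"),
    save_every_iters=1000,
    n_saved=2,
    log_every_iters=100,
    with_pbars=True,
    with_pbar_on_iters=True,
    stop_on_nan=True,
    clear_cuda_cache=True,
    with_gpu_stats=False,
    patience=5,
    limit_sec=3600,
)

best_model_handler, es_handler, timer_handler = get_handlers(
    config=config,
    model=model,          # assumed: an nn.Module built elsewhere
    trainer=trainer,      # assumed: the training Engine
    evaluator=evaluator,  # assumed: the evaluation Engine
    metric_name="accuracy",
    es_metric_name="accuracy",
)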