Example #1

import os
import sys
import shutil
from os.path import join as pjoin

# Project-specific helpers (utils, datasets, Timer, Trainer, tasks, views,
# stopping_criteria, the model classes and the factories) are assumed to be
# imported from the surrounding codebase; their imports are omitted in this excerpt.


def main():
    parser = buildArgsParser()
    args = parser.parse_args()

    # Extract experiments hyperparameters
    hyperparams = dict(vars(args))

    # Remove hyperparams that should not be part of the hash
    del hyperparams['max_epoch']
    del hyperparams['keep']
    del hyperparams['force']
    del hyperparams['name']

    # Get/generate experiment name
    experiment_name = args.name
    if experiment_name is None:
        experiment_name = utils.generate_uid_from_string(repr(hyperparams))

    # Create experiment folder
    experiment_path = pjoin(".", "experiments", experiment_name)
    resuming = False
    if os.path.isdir(experiment_path) and not args.force:
        resuming = True
        print("### Resuming experiment ({0}). ###\n".format(experiment_name))
        # Check if provided hyperparams match those in the experiment folder
        hyperparams_loaded = utils.load_dict_from_json_file(pjoin(experiment_path, "hyperparams.json"))
        if hyperparams != hyperparams_loaded:
            print("{\n" + "\n".join(["{}: {}".format(k, hyperparams[k]) for k in sorted(hyperparams.keys())]) + "\n}")
            print("{\n" + "\n".join(["{}: {}".format(k, hyperparams_loaded[k]) for k in sorted(hyperparams_loaded.keys())]) + "\n}")
            print("The arguments provided are different than the one saved. Use --force if you are certain.\nQuitting.")
            sys.exit(1)
    else:
        if os.path.isdir(experiment_path):
            shutil.rmtree(experiment_path)

        os.makedirs(experiment_path)
        utils.save_dict_to_json_file(pjoin(experiment_path, "hyperparams.json"), hyperparams)

    with Timer("Loading dataset"):
        trainset, validset, testset = datasets.load(args.dataset)

        image_shape = (28, 28)
        nb_channels = 2 if args.use_mask_as_input else 1  # one extra channel when the mask is fed as input

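        # Each minibatch comes with a randomly drawn autoregressive mask over the
        # pixels, as used for order-agnostic training of convolutional NADE models.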
        batch_scheduler = MiniBatchSchedulerWithAutoregressiveMask(trainset, args.batch_size,
                                                                   use_mask_as_input=args.use_mask_as_input,
                                                                   seed=args.ordering_seed)
        print("{} updates per epoch.".format(len(batch_scheduler)))

    with Timer("Building model"):
        if args.use_lasagne:
            if args.with_residual:
                model = DeepConvNadeWithResidualUsingLasagne(image_shape=image_shape,
                                                             nb_channels=nb_channels,
                                                             convnet_blueprint=args.convnet_blueprint,
                                                             fullnet_blueprint=args.fullnet_blueprint,
                                                             hidden_activation=args.hidden_activation,
                                                             use_mask_as_input=args.use_mask_as_input)
            else:
                model = DeepConvNadeUsingLasagne(image_shape=image_shape,
                                                 nb_channels=nb_channels,
                                                 convnet_blueprint=args.convnet_blueprint,
                                                 fullnet_blueprint=args.fullnet_blueprint,
                                                 hidden_activation=args.hidden_activation,
                                                 use_mask_as_input=args.use_mask_as_input,
                                                 use_batch_norm=args.batch_norm)

        elif args.with_residual:
            model = DeepConvNADEWithResidual(image_shape=image_shape,
                                             nb_channels=nb_channels,
                                             convnet_blueprint=args.convnet_blueprint,
                                             fullnet_blueprint=args.fullnet_blueprint,
                                             hidden_activation=args.hidden_activation,
                                             use_mask_as_input=args.use_mask_as_input)

        else:
            builder = DeepConvNADEBuilder(image_shape=image_shape,
                                          nb_channels=nb_channels,
                                          hidden_activation=args.hidden_activation,
                                          use_mask_as_input=args.use_mask_as_input)

            if args.blueprints_seed is not None:
                convnet_blueprint, fullnet_blueprint = generate_blueprints(args.blueprints_seed, image_shape[0])
                builder.build_convnet_from_blueprint(convnet_blueprint)
                builder.build_fullnet_from_blueprint(fullnet_blueprint)
            else:
                if args.convnet_blueprint is not None:
                    builder.build_convnet_from_blueprint(args.convnet_blueprint)

                if args.fullnet_blueprint is not None:
                    builder.build_fullnet_from_blueprint(args.fullnet_blueprint)

            model = builder.build()
            # print(str(model.convnet))
            # print(str(model.fullnet))

        model.initialize(weigths_initializer_factory(args.weights_initialization,
                                                     seed=args.initialization_seed))
        print(str(model))

    with Timer("Building optimizer"):
        loss = BinaryCrossEntropyEstimateWithAutoRegressiveMask(model, trainset)
        optimizer = optimizer_factory(hyperparams, loss)

    with Timer("Building trainer"):
        trainer = Trainer(optimizer, batch_scheduler)

        if args.max_epoch is not None:
            trainer.append_task(stopping_criteria.MaxEpochStopping(args.max_epoch))

        # Print time for one epoch
        trainer.append_task(tasks.PrintEpochDuration())
        trainer.append_task(tasks.PrintTrainingDuration())

        # Log training error
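        # (accumulate per-update losses and log the per-epoch average)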
        loss_monitor = views.MonitorVariable(loss.loss)
        avg_loss = tasks.AveragePerEpoch(loss_monitor)
        accum = tasks.Accumulator(loss_monitor)
        logger = tasks.Logger(loss_monitor, avg_loss)
        trainer.append_task(logger, avg_loss, accum)

        # Print average training loss.
        trainer.append_task(tasks.Print("Avg. training loss:     : {}", avg_loss))

        # Print NLL mean/stderror.
        model.deterministic = True  # For batch normalization, see https://github.com/Lasagne/Lasagne/blob/master/lasagne/layers/normalization.py#L198
        nll = views.LossView(loss=BinaryCrossEntropyEstimateWithAutoRegressiveMask(model, validset),
                             batch_scheduler=MiniBatchSchedulerWithAutoregressiveMask(validset, batch_size=int(0.1 * len(validset)),
                                                                                      use_mask_as_input=args.use_mask_as_input,
                                                                                      keep_mask=True,
                                                                                      seed=args.ordering_seed+1))
        # trainer.append_task(tasks.Print("Validset - NLL          : {0:.2f} ± {1:.2f}", nll.mean, nll.stderror, each_k_update=100))
        trainer.append_task(tasks.Print("Validset - NLL          : {0:.2f} ± {1:.2f}", nll.mean, nll.stderror))

        # direction_norm = views.MonitorVariable(T.sqrt(sum(map(lambda d: T.sqr(d).sum(), loss.gradients.values()))))
        # trainer.append_task(tasks.Print("||d|| : {0:.4f}", direction_norm, each_k_update=50))

        # Save training progression
        def save_model(*_):
            trainer.save(experiment_path)

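        # Early stopping on validation NLL: save the model whenever it improves;
        # stop after `lookahead` epochs without an improvement larger than `lookahead_eps`.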
        trainer.append_task(stopping_criteria.EarlyStopping(nll.mean, lookahead=args.lookahead, eps=args.lookahead_eps, callback=save_model))

        trainer.build_theano_graph()

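    # When resuming, restore the model parameters and the trainer's state.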
    if resuming:
        with Timer("Loading"):
            trainer.load(experiment_path)

    with Timer("Training"):
        trainer.train()

    trainer.save(experiment_path)
    model.save(experiment_path)
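

# Hypothetical entry point (not part of the original excerpt) so the example
# can be run directly as a script.
if __name__ == "__main__":
    main()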