Example #1
    train(training_model, train_data, run_opts, lr_scheduler,
          range(1, run_opts.epoch + 1), optimizer, training_validation_func)

    # Validation and weight averaging run on a single process
    if not run_opts.use_popdist or run_opts.popdist_rank == 0:
        if run_opts.weight_avg_strategy != 'none':
            average_fn = weight_avg.create_average_fn(run_opts)
            weight_avg.average_model_weights(run_opts.checkpoint_path,
                                             average_fn, run_opts.weight_avg_N)

        if run_opts.validation_mode == "after":
            if run_opts.checkpoint_path == "":
                training_model.destroy()
                val_func = get_validation_function(run_opts, model)
                acc = val_func()
                result_dict = {
                    "validation_epoch": run_opts.epoch,
                    "validation_iteration":
                    run_opts.logs_per_epoch * run_opts.epoch,
                    "validation_accuracy": acc
                }
                utils.Logger.log_validate_results(result_dict)
            else:
                training_model.destroy()
                checkpoint_files = [
                    os.path.join(run_opts.checkpoint_path, file_name)
                    for file_name in os.listdir(run_opts.checkpoint_path)
                    if file_name.endswith(".pt")
                ]
                validate_checkpoints(checkpoint_files)
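
The weight averaging above is handled by the project's weight_avg helpers. As a rough, hypothetical sketch of the idea (assuming each ".pt" checkpoint stores a "model_state_dict" entry, as in Example #2), uniform averaging over the last N checkpoints could look like this:

import os

import torch


def average_state_dicts(checkpoint_dir, last_n):
    # Hypothetical sketch: uniformly average the model weights stored in
    # the most recent `last_n` ".pt" checkpoints. Assumes each checkpoint
    # is a dict with a "model_state_dict" key, as in Example #2.
    files = sorted(
        os.path.join(checkpoint_dir, name)
        for name in os.listdir(checkpoint_dir)
        if name.endswith(".pt")
    )[-last_n:]

    averaged = None
    for path in files:
        state = torch.load(path, map_location="cpu")["model_state_dict"]
        if averaged is None:
            averaged = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                averaged[k] += v.float()

    for k in averaged:
        averaged[k] /= len(files)
    return averaged

The real create_average_fn presumably picks a strategy based on run_opts.weight_avg_strategy; an exponential variant would simply weight recent checkpoints more heavily.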
Example #2
import argparse
import logging
import os

import torch

# Project-specific helpers (create_model_opts, get_data, models, datasets_info,
# get_optimizer, convert_to_ipu_model, train, validate_checkpoints) are assumed
# to be in scope from the surrounding training package.

logging.basicConfig(format='%(asctime)s %(module)s - %(funcName)s: %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
logging.getLogger().setLevel(logging.INFO)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Restoring training run from a given checkpoint')
    parser.add_argument('--checkpoint-path', help="The path of the checkpoint file", required=True)
    args = parser.parse_args()
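    # Load the checkpoint and recover the options the run was started with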
    checkpoint = torch.load(args.checkpoint_path)
    opts = checkpoint['opts']

    logging.info("Loading the data")
    model_opts = create_model_opts(opts)
    train_data, test_data = get_data(opts, model_opts)

    logging.info("Restore the {0} model to epoch {1} on {2} dataset(Loss:{3}, train accuracy:{4})".format(opts.model, checkpoint["epoch"], opts.data, checkpoint["loss"], checkpoint["train_accuracy"]))
    model = models.get_model(opts, datasets_info[opts.data], pretrained=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.train()

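    # Rebuild the optimizer and LR scheduler, then restore the saved optimizer state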
    optimizer, lr_scheduler = get_optimizer(opts, model)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    training_model = convert_to_ipu_model(model, opts, optimizer)

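    # Resume training from the epoch after the one stored in the checkpoint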
    train(training_model, train_data, opts, lr_scheduler,
          range(checkpoint["epoch"] + 1, opts.epoch + 1), optimizer)

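    # Validate every file saved in the checkpoint folder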
    checkpoint_folder = os.path.dirname(os.path.realpath(args.checkpoint_path))
    checkpoint_files = [os.path.join(checkpoint_folder, file_name)
                        for file_name in os.listdir(checkpoint_folder)]
    validate_checkpoints(checkpoint_files, test_data=test_data)
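
validate_checkpoints is another helper from the surrounding training package. A minimal, hypothetical sketch of what it might do, reusing get_validation_function from Example #1 and the checkpoint layout loaded above (how test_data is threaded into the validation run is omitted here):

import logging

import torch


def validate_checkpoints(checkpoint_files, test_data=None):
    # Hypothetical sketch: load each saved checkpoint and report its
    # validation accuracy. Assumes the checkpoint layout used above
    # ('opts', 'epoch', 'model_state_dict') and reuses the project's
    # get_validation_function helper from Example #1.
    for path in sorted(checkpoint_files):
        checkpoint = torch.load(path, map_location="cpu")
        opts = checkpoint['opts']
        model = models.get_model(opts, datasets_info[opts.data],
                                 pretrained=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        val_func = get_validation_function(opts, model)
        acc = val_func()
        logging.info("Checkpoint %s (epoch %d): validation accuracy %s",
                     path, checkpoint["epoch"], acc)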