# Run the full training sweep, then (on a single process only) perform the
# optional weight averaging and post-training validation.
train(training_model, train_data, run_opts, lr_scheduler,
      range(1, run_opts.epoch + 1), optimizer, training_validation_func)

# Validation and weight averaging runs on single process
if not run_opts.use_popdist or run_opts.popdist_rank == 0:
    if run_opts.weight_avg_strategy != 'none':
        # Fold the saved checkpoints into averaged weights using the
        # configured averaging strategy.
        average_fn = weight_avg.create_average_fn(run_opts)
        weight_avg.average_model_weights(run_opts.checkpoint_path, average_fn,
                                         run_opts.weight_avg_N)
    if run_opts.validation_mode == "after":
        # Both branches release the IPU training session first, so the
        # destroy() call is hoisted out of the conditional.
        training_model.destroy()
        if run_opts.checkpoint_path == "":
            # No checkpoints were written: validate the in-memory model.
            val_func = get_validation_function(run_opts, model)
            acc = val_func()
            result_dict = {
                "validation_epoch": run_opts.epoch,
                "validation_iteration": run_opts.logs_per_epoch * run_opts.epoch,
                "validation_accuracy": acc,
            }
            utils.Logger.log_validate_results(result_dict)
        else:
            # Validate every saved ".pt" checkpoint in the checkpoint folder.
            checkpoint_files = [
                os.path.join(run_opts.checkpoint_path, file_name)
                for file_name in os.listdir(run_opts.checkpoint_path)
                if file_name.endswith(".pt")
            ]
            validate_checkpoints(checkpoint_files)
logging.basicConfig(
    format='%(asctime)s %(module)s - %(funcName)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
logging.getLogger().setLevel(logging.INFO)


if __name__ == '__main__':
    # Restore a training run from a saved checkpoint and continue training
    # from the epoch following the checkpointed one.
    parser = argparse.ArgumentParser(description='Restoring training run from a given checkpoint')
    parser.add_argument('--checkpoint-path', help="The path of the checkpoint file", required=True)
    args = parser.parse_args()

    # NOTE(review): torch.load unpickles arbitrary objects — only restore
    # checkpoints from trusted sources.
    checkpoint = torch.load(args.checkpoint_path)
    opts = checkpoint['opts']

    logging.info("Loading the data")
    model_opts = create_model_opts(opts)
    train_data, test_data = get_data(opts, model_opts)

    logging.info("Restore the {0} model to epoch {1} on {2} dataset(Loss:{3}, train accuracy:{4})".format(
        opts.model, checkpoint["epoch"], opts.data,
        checkpoint["loss"], checkpoint["train_accuracy"]))
    model = models.get_model(opts, datasets_info[opts.data], pretrained=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.train()

    optimizer, lr_scheduler = get_optimizer(opts, model)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    training_model = convert_to_ipu_model(model, opts, optimizer)
    # Resume from the epoch after the one stored in the checkpoint.
    train(training_model, train_data, opts, lr_scheduler,
          range(checkpoint["epoch"] + 1, opts.epoch + 1), optimizer)

    # Validate the saved checkpoints. Filter to ".pt" files for consistency
    # with the training script's post-training validation, which also only
    # considers ".pt" files — the folder may contain non-checkpoint files.
    checkpoint_folder = os.path.dirname(os.path.realpath(args.checkpoint_path))
    checkpoint_files = [
        os.path.join(checkpoint_folder, file_name)
        for file_name in os.listdir(checkpoint_folder)
        if file_name.endswith(".pt")
    ]
    validate_checkpoints(checkpoint_files, test_data=test_data)