def load_model(checkpoint_file):
    checkpoint = torch.load(checkpoint_file)
    args = checkpoint['args']
    model = models.get_model(args, datasets.datasets_info[args.data],
                             pretrained=False)
    models.load_model_state_dict(model, checkpoint['model_state_dict'])
    # Promote the weights to double precision (e.g. to keep subsequent
    # checkpoint averaging numerically stable).
    model.double()
    return model
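# A minimal usage sketch (the checkpoint path below is hypothetical; it assumes
# a file written by the training loop that contains 'args' and
# 'model_state_dict' entries, as load_model expects):
#
#   model = load_model("checkpoints/resnet50_imagenet_20.pt")
#   model.eval()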
def load_checkpoint_weights(inference_model, file_path):
    checkpoint = torch.load(file_path)
    args = checkpoint['args']
    logging.info(
        f"Restored the {args.model} model to epoch {checkpoint['epoch']} on the "
        f"{args.data} dataset (train loss: {checkpoint['loss']}, "
        f"train accuracy: {checkpoint['train_accuracy']}%)")
    models.load_model_state_dict(inference_model, checkpoint['model_state_dict'])
    # Push the freshly loaded host-side weights to the IPU without recompiling.
    inference_model.copyWeightsToDevice()
def validate_checkpoints(checkpoint_list, test_data=None):
    checkpoint = torch.load(checkpoint_list[0])
    args = checkpoint['args']
    utils.Logger.setup_logging_folder(args)

    # Sort the checkpoints by their trailing epoch number so the
    # validation order is ascending.
    def ckpt_key(ckpt):
        return int(ckpt.split('_')[-1].split('.')[0])

    try:
        checkpoint_list = sorted(checkpoint_list, key=ckpt_key)
    except ValueError:
        logging.warning(
            "Could not parse epoch numbers from the checkpoint file names; "
            "the evaluation order may be inconsistent.")

    # Validate in a single instance.
    opts = create_validation_opts(args, use_popdist=False)
    args.use_popdist = False
    args.popdist_size = 1

    if test_data is None:
        test_data = datasets.get_data(args, opts, train=False,
                                      async_dataloader=True,
                                      return_remaining=True)
    model = models.get_model(args, datasets.datasets_info[args.data],
                             pretrained=False)
    model.eval()

    # Load the weights of the first checkpoint for the model.
    models.load_model_state_dict(model, checkpoint['model_state_dict'])
    inference_model = poptorch.inferenceModel(model, opts)

    for checkpoint in checkpoint_list:
        # The first checkpoint's weights are already in the model; once the
        # executable is compiled, swap in each later checkpoint's weights.
        if inference_model.isCompiled():
            load_checkpoint_weights(inference_model, checkpoint)
        val_accuracy = test(inference_model, test_data)
        epoch_nr = torch.load(checkpoint)["epoch"]
        log_data = {
            "validation_epoch": epoch_nr,
            "validation_iteration": epoch_nr * len(test_data),
            "validation_accuracy": val_accuracy,
        }
        utils.Logger.log_validate_results(log_data)
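# A sketch of how the checkpoint sweep might be driven (the glob pattern is
# hypothetical; validate_checkpoints sorts the matches by their trailing epoch
# number and reuses a single compiled inference executable for all of them):
#
#   import glob
#   ckpts = glob.glob("checkpoints/resnet50_imagenet_*.pt")
#   validate_checkpoints(ckpts)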
def fine_tune(args):
    logging.info("Fine-tuning the model after half-resolution training")
    # Override the training options for the fine-tuning phase.
    args.half_res_training = False
    args.mixup_enabled = False
    args.cutmix_enabled = False
    args.optimizer = 'sgd'
    args.momentum = 0.0
    args.warmup_epoch = 0
    args.lr = args.fine_tune_lr
    args.lr_schedule = 'cosine'
    args.lr_scheduler_freq = 0
    args.batch_size = args.fine_tune_batch_size
    args.gradient_accumulation = args.fine_tune_gradient_accumulation

    opts = create_training_opts(args)
    train_data = datasets.get_data(args, opts, train=True, fine_tuning=True,
                                   async_dataloader=True)
    model_fine_tune = models.get_model(args, datasets.datasets_info[args.data],
                                       pretrained=False,
                                       use_mixup=args.mixup_enabled,
                                       use_cutmix=args.cutmix_enabled)

    # Load the averaged checkpoint on the root instance, then broadcast the
    # parameters to the other instances when running with popdist.
    if not args.use_popdist or args.popdist_rank == 0:
        avg_checkpoint_file = os.path.join(
            args.checkpoint_path,
            f"{args.model}_{args.data}_{args.epoch}_averaged.pt")
        avg_checkpoint = torch.load(avg_checkpoint_file)
        models.load_model_state_dict(model_fine_tune,
                                     avg_checkpoint['model_state_dict'])
    if args.use_popdist:
        hvd.broadcast_parameters(models.get_model_state_dict(model_fine_tune),
                                 root_rank=0)

    model_fine_tune.train()
    nested_model = models.get_nested_model(model_fine_tune)

    # Freeze the parameters before the first trainable layer.
    for param_name, param in nested_model.named_parameters():
        param_name = param_name.replace('.', '/')
        if param_name.startswith(args.fine_tune_first_trainable_layer):
            break
        logging.info(f"Freezing parameter {param_name}")
        param.requires_grad = False

    # Put the frozen dropout and batch norm layers into eval mode.
    for module_name, module in nested_model.named_modules():
        module_name = module_name.replace('.', '/')
        if module_name.startswith(args.fine_tune_first_trainable_layer):
            break
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) or \
                isinstance(module, torch.nn.modules.dropout._DropoutNd):
            logging.info(f"Setting module {module_name} to eval mode")
            module.eval()

    optimizer = get_optimizer(args, model_fine_tune)
    lr_scheduler = get_lr_scheduler(args, optimizer, len(train_data))
    training_model = convert_to_ipu_model(model_fine_tune, args, optimizer)
    train(training_model, train_data, args, lr_scheduler,
          range(args.epoch + 1, args.epoch + 1 + args.fine_tune_epoch),
          optimizer)
    train_data.terminate()
    return model_fine_tune, training_model
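# Note on the freezing convention above: parameter and module names are matched
# with '.' replaced by '/', so fine_tune_first_trainable_layer takes a
# slash-separated name prefix. A hypothetical example for a torchvision-style
# ResNet:
#
#   args.fine_tune_first_trainable_layer = "layer4"
#
# would freeze every parameter iterated before the first "layer4" match
# (e.g. "conv1/weight", "layer3/0/conv1/weight") and leave "layer4/..." and the
# classifier trainable.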
opts = create_training_opts(args)
train_data = datasets.get_data(args, opts, train=True, async_dataloader=True)
logging.info(
    f"Restored the {args.model} model to epoch {checkpoint['epoch']} on the "
    f"{args.data} dataset (loss: {checkpoint['loss']}, "
    f"train accuracy: {checkpoint['train_accuracy']})")
model = models.get_model(args, datasets.datasets_info[args.data],
                         pretrained=False,
                         use_mixup=args.mixup_enabled,
                         use_cutmix=args.cutmix_enabled)
models.load_model_state_dict(model, checkpoint['model_state_dict'])
model.train()
optimizer = get_optimizer(args, model)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
lr_scheduler = get_lr_scheduler(args, optimizer, len(train_data),
                                start_epoch=checkpoint["epoch"])
training_model = convert_to_ipu_model(model, args, optimizer)
if args.validation_mode == "during":
    training_validation_func = get_validation_function(args, model).func
else:
    training_validation_func = None
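# For reference, a sketch of the checkpoint layout this resume path expects,
# based on the keys read throughout this file (the file name is illustrative):
#
#   torch.save({
#       'epoch': epoch,
#       'loss': loss,
#       'train_accuracy': train_accuracy,
#       'args': args,
#       'model_state_dict': models.get_model_state_dict(model),
#       'optimizer_state_dict': optimizer.state_dict(),
#   }, "checkpoints/resnet50_imagenet_20.pt")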