示例#1
0
def save_checkpoint(outpath, epoch, model_feature, model_source_classifier,
                    model_target_classifier, optimizer, lr_scheduler,
                    val_best_acc):
    """Write a training checkpoint to ``<outpath>/check_point/checkpoint.pth``.

    The saved object is a dict with these keys:

    * ``"model_feature"``, ``"model_source_classifier"``,
      ``"model_target_classifier"`` -- the ``state_dict()`` of each network.
      A network wrapped in ``nn.DataParallel`` is unwrapped through
      ``.module`` first.
    * ``"val_best_acc"``, ``"optimizer"``, ``"lr_scheduler"``, ``"epoch"`` --
      the corresponding arguments, stored exactly as passed in.

    NOTE(review): ``optimizer`` and ``lr_scheduler`` are stored as whole
    objects, not as their ``state_dict()``; whatever loads this file has to
    expect that.

    The filename is fixed, so every call targets the same path.
    """
    networks = (
        ("model_feature", model_feature),
        ("model_source_classifier", model_source_classifier),
        ("model_target_classifier", model_target_classifier),
    )

    snapshot = {}
    for key, net in networks:
        # Take the weights from the wrapped network when DataParallel is used.
        core = net.module if isinstance(net, nn.DataParallel) else net
        snapshot[key] = core.state_dict()

    snapshot["val_best_acc"] = val_best_acc
    snapshot["optimizer"] = optimizer
    snapshot["lr_scheduler"] = lr_scheduler
    snapshot["epoch"] = epoch

    checkpoint_dir = os.path.join(outpath, "check_point")
    # ensure_folder is defined outside this block; presumably it creates the
    # directory when it is missing -- TODO confirm.
    ensure_folder(checkpoint_dir)
    torch.save(snapshot, os.path.join(checkpoint_dir, "checkpoint.pth"))
示例#2
0
def save_model(outpath, epoch, model_feature, model_target_classifier,
               val_best_acc, logger):
    """Save feature-extractor and target-classifier weights to one file.

    The file is written to
    ``<outpath>/check_point/model_<epoch>_acc<val_best_acc>.pth`` where the
    epoch is zero-padded to three digits and the accuracy is formatted with
    four decimals. It holds a dict with two entries:

    * ``"model"`` -- ``state_dict()`` of ``model_feature``
    * ``"fc"``    -- ``state_dict()`` of ``model_target_classifier``

    A network wrapped in ``nn.DataParallel`` is unwrapped through ``.module``
    before its weights are read.

    ``logger`` is accepted but not used by this function.
    """
    backbone = (model_feature.module
                if isinstance(model_feature, nn.DataParallel)
                else model_feature)
    backbone_weights = backbone.state_dict()

    head = (model_target_classifier.module
            if isinstance(model_target_classifier, nn.DataParallel)
            else model_target_classifier)
    head_weights = head.state_dict()

    payload = {"model": backbone_weights, "fc": head_weights}

    target_dir = os.path.join(outpath, "check_point")
    # ensure_folder is defined outside this block; presumably it creates the
    # directory when it is missing -- TODO confirm.
    ensure_folder(target_dir)
    target_name = 'model_{:03d}_acc{:.4f}.pth'.format(epoch, val_best_acc)
    torch.save(payload, os.path.join(target_dir, target_name))
示例#3
0
def save_model(model, filename, model_folder):
    """Save ``model.state_dict()`` to ``<model_folder>/<filename>``.

    ``ensure_folder(model_folder)`` is called before the file is written.
    That helper is defined outside this block; presumably it creates the
    folder when it is missing -- TODO confirm.
    """
    ensure_folder(model_folder)
    destination = os.path.join(model_folder, filename)
    weights = model.state_dict()
    torch.save(weights, destination)