def create_model_miniimagenet(args, config, architecture, ENCODER_CONFIG=None):
    """Build the classification model selected by ``args.model`` (miniImageNet).

    Args:
        args: Namespace-like object; only ``args.model`` is read here.
            Supported values: "MAML", "MetaSGD", "LEO", "MLwM".
        config: Hyper-parameter mapping. "MAML"/"MetaSGD" read
            ``'is_meta_sgd'``, ``'update_lr'`` and ``'update_step'``;
            "MLwM" reads ``'update_lr'`` and ``'update_step'``;
            "LEO" receives the whole mapping.
        architecture: Network architecture forwarded to ``Meta``/``MetaSGD``/``MLwM``.
        ENCODER_CONFIG: Encoder configuration forwarded to ``MLwM`` (first argument).

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``args.model`` is not one of the supported names.
    """
    if args.model in ("MAML", "MetaSGD"):
        # The concrete class is chosen by config['is_meta_sgd'], not by which
        # of the two names appears in args.model.
        if config['is_meta_sgd']:
            model = MetaSGD(architecture, config['update_lr'], config['update_step'],
                            is_regression=False)
        else:
            model = Meta(architecture, config['update_lr'], config['update_step'],
                         is_regression=False)
    elif args.model == "LEO":
        model = LEO(config)
    elif args.model == "MLwM":
        model = MLwM(ENCODER_CONFIG, architecture, config['update_lr'], config['update_step'],
                     is_regression=False)
    else:
        # Actually raise: merely naming the exception would let an unknown
        # model reach ``return model`` with ``model`` unbound.
        raise NotImplementedError(
            "Unsupported model for miniImageNet: {!r}".format(args.model))
    return model
def create_model_embed_miniimagenet(args, config, architecture, ENCODER_CONFIG=None):
    """Build the model selected by ``args.model`` for pre-embedded miniImageNet.

    Meta/MetaSGD are constructed with ``is_image_feature=False``, unlike the
    raw-image builder.

    Args:
        args: Namespace-like object; ``args.model`` is always read, and
            ``args.n_way`` is read for "SIB" and "Prototypes_embedded".
            Supported values: "MAML", "MetaSGD", "LEO", "SIB",
            "Prototypes_embedded", "MLwM".
        config: Hyper-parameter mapping. "MAML"/"MetaSGD" read
            ``'is_meta_sgd'``, ``'update_lr'`` and ``'update_step'``;
            "MLwM" reads ``'update_lr'`` and ``'update_step'``;
            "LEO", "SIB" and "Prototypes_embedded" receive the whole mapping.
        architecture: Network architecture forwarded to ``Meta``/``MetaSGD``/``MLwM``.
        ENCODER_CONFIG: Encoder configuration forwarded to ``MLwM`` (first argument).

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``args.model`` is not one of the supported names.
    """
    if args.model in ("MAML", "MetaSGD"):
        # The concrete class is chosen by config['is_meta_sgd'], not by which
        # of the two names appears in args.model.
        if config['is_meta_sgd']:
            model = MetaSGD(architecture, config['update_lr'], config['update_step'],
                            is_regression=False, is_image_feature=False)
        else:
            model = Meta(architecture, config['update_lr'], config['update_step'],
                         is_regression=False, is_image_feature=False)
    elif args.model == "LEO":
        model = LEO(config)
    elif args.model == "SIB":
        model = SIB(args.n_way, config)
    elif args.model == "Prototypes_embedded":
        model = PrototypeNet_embedded(args.n_way, config)
    elif args.model == "MLwM":
        # MLwM wrapping MAML or MetaSGD
        model = MLwM(ENCODER_CONFIG, architecture, config['update_lr'], config['update_step'],
                     is_regression=False)
    else:
        # Actually raise: merely naming the exception would let an unknown
        # model reach ``return model`` with ``model`` unbound.
        raise NotImplementedError(
            "Unsupported model for embedded miniImageNet: {!r}".format(args.model))
    return model
def create_model_poseregression(args, config, architecture, ENCODER_CONFIG=None):
    """Build the regression model selected by ``args.model`` (pose regression).

    Args:
        args: Namespace-like object; only ``args.model`` is read here.
            Supported values: "MAML", "MetaSGD", "MLwM", "CNP".
        config: Hyper-parameter mapping. "MAML"/"MetaSGD" read
            ``'is_meta_sgd'``, ``'update_lr'`` and ``'update_step'``;
            "MLwM" reads ``'update_lr'`` and ``'update_step'`` and also
            receives the whole mapping as its first argument; "CNP" receives
            the whole mapping.
        architecture: Network architecture forwarded to ``Meta``/``MetaSGD``/``MLwM``.
        ENCODER_CONFIG: Accepted for signature parity with the other builders;
            not used by this function.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``args.model`` is not one of the supported names.
    """
    if args.model in ("MAML", "MetaSGD"):
        # The concrete class is chosen by config['is_meta_sgd'], not by which
        # of the two names appears in args.model.
        if config['is_meta_sgd']:
            model = MetaSGD(architecture, config['update_lr'], config['update_step'],
                            is_regression=True)
        else:
            model = Meta(architecture, config['update_lr'], config['update_step'],
                         is_regression=True)
    elif args.model == "MLwM":
        # NOTE(review): the sibling builders pass ENCODER_CONFIG as MLwM's first
        # argument; here `config` is passed and ENCODER_CONFIG is ignored.
        # Kept as-is for caller compatibility — confirm which is intended.
        model = MLwM(config, architecture, config['update_lr'], config['update_step'],
                     is_regression=True)
    elif args.model == "CNP":
        model = CNP(config, is_regression=True)
    else:
        # Actually raise: merely naming the exception would let an unknown
        # model reach ``return model`` with ``model`` unbound.
        raise NotImplementedError(
            "Unsupported model for pose regression: {!r}".format(args.model))
    return model
encoded_img_size, is_regression=True) else: architecture = set_config(config['CONFIG_CONV_4_MAML'], args.n_way, config['img_size'], is_regression=True) # Create Model if args.model == "MAML": model = Meta(architecture, config['update_lr'], config['update_step'], is_regression=True) elif args.model == "MLwM": model = MLwM(ENCODER_CONFIG, architecture, config['update_lr'], config['update_step'],\ is_regression=True, is_kl_loss=True, beta_kl=config['beta_kl']) else: NotImplementedError # Train train(model, config, save_model_path) # load model path if args.model_save_root_dir == args.model_load_dir: load_model_path = latest_load_model_filepath(args) else: load_model_path = args.model_load_dir # Test test(model, load_model_path, save_model_path)