def train(model, config, save_model_path, initializer=torch.nn.init.xavier_normal_,
          result_dir="/home/mgyukim/workspaces/result_MLwM"):
    """Meta-train ``model`` with MAML on the pose-regression dataset and archive the run.

    Command-line options are read through ``parse_args()``. They provide:

    * the GPU selection;
    * the episode configuration (``n_way``, ``k_shot_support``,
      ``k_shot_query``, ``task_size``);
    * ``data_path``, ``datatypes``, ``epochs``, ``model`` and ``description``.

    Args:
        model: Network to meta-train. Its ``state_dict`` is saved once training
            finishes.
        config: Dict-like configuration. This function reads ``'meta_lr'``
            (Adam learning rate), ``'encoder_type'`` and ``'beta_kl'``. The
            latter two are forwarded to ``remove_temp_files_and_move_directory``.
        save_model_path: Directory that receives ``Model_<datatypes>.pt``. It is
            also passed to ``MAML_operator`` and then to
            ``remove_temp_files_and_move_directory``.
        initializer: Accepted for interface compatibility; not used here.
        result_dir: Destination root handed to
            ``remove_temp_files_and_move_directory``. Defaults to the path that
            was previously hard-coded.
    """
    args = parse_args()

    # NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if CUDA has not been
    # initialised yet in this process — confirm callers have not touched the GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    device = torch.device('cuda')

    # Episodic datasets; each item unpacks to (support_x, support_y, query_x, query_y).
    pose_training_set = meta_pose_regression_dataset(
        args.n_way, args.k_shot_support, args.k_shot_query,
        args.data_path, types=args.datatypes)
    pose_valid_set = meta_pose_regression_dataset(
        args.n_way, args.k_shot_support, args.k_shot_query,
        args.data_path, mode='val', types=args.datatypes)

    # Each batch holds `task_size` tasks.
    train_loader = DataLoader(pose_training_set, batch_size=args.task_size, shuffle=True)
    valid_loader = DataLoader(pose_valid_set, batch_size=args.task_size, shuffle=True)

    print("length of episode : ", len(train_loader))

    if DEBUG:
        support_x, support_y, query_x, query_y = pose_training_set[0]
        print("=" * 25, "DEBUG", "=" * 25)
        print("support_x shape : ", support_x.shape)
        print("support_y shape : ", support_y.shape)
        print("query_x shape : ", query_x.shape)
        print("query_y shape : ", query_y.shape)

    # Outer-loop (meta) optimizer.
    optimizer = optim.Adam(model.parameters(), lr=config['meta_lr'])

    maml_operator = MAML_operator(model, device, train_loader, optimizer,
                                  args.epochs, save_model_path, valid_loader)
    maml_operator.train()

    # Persist the final weights.
    torch.save(
        model.state_dict(),
        os.path.join(save_model_path, "Model_{}.pt".format(args.datatypes)))
    print("=" * 20, "Save the model (After training)", "=" * 20)

    # Move the saved files into the results folder.
    remove_temp_files_and_move_directory(
        save_model_path, result_dir, args.model,
        config['encoder_type'], config['beta_kl'], "poseregression",
        args.datatypes, args.description)
def test(model, config, load_model_path, save_model_path, initializer=torch.nn.init.xavier_normal_):
    """Restore a checkpoint into ``model`` and run ``MAML_operator.test`` on it.

    Options come from ``parse_args()``: the GPU selection, the episode
    configuration, ``data_path`` and ``datatypes``.

    Args:
        model: Network whose weights are replaced by the checkpoint at
            ``load_model_path``.
        config: Dict-like configuration; only ``'img_size'`` is read here.
        load_model_path: Path of the ``state_dict`` file passed to ``torch.load``.
        save_model_path: Forwarded to ``MAML_operator`` as ``savedir``.
        initializer: Accepted for interface compatibility; not used here.

    NOTE(review): this evaluates on ``meta_miniImagenet_dataset``, while
    ``train`` in this file trains on ``meta_pose_regression_dataset``. It looks
    like a copy-paste leftover; confirm which test set is intended.
    """
    args = parse_args()

    # Restrict the visible GPUs before any CUDA device is requested.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    device = torch.device('cuda')

    test_set = meta_miniImagenet_dataset(
        args.n_way,
        args.k_shot_support,
        args.k_shot_query,
        args.data_path,
        config['img_size'],
        mode='test',
        types=args.datatypes,
    )
    loader = DataLoader(test_set, batch_size=args.task_size, shuffle=True)

    if DEBUG:
        # Inspect the first task to sanity-check tensor shapes.
        support_x, support_y, query_x, query_y = test_set[0]
        banner = "=" * 25
        print(banner, "DEBUG", banner)
        for caption, tensor in (
            ("support_x shape : ", support_x),
            ("support_y shape : ", support_y),
            ("query_x shape : ", query_x),
            ("query_y shape : ", query_y),
        ):
            print(caption, tensor.shape)

    state_dict = torch.load(load_model_path)
    model.load_state_dict(state_dict)
    print("=" * 20, "Load the model : {}".format(load_model_path), "=" * 20)

    runner = MAML_operator(model, device, loader, savedir=save_model_path)
    runner.test()