Пример #1
0
def train(model,
          config,
          save_model_path,
          initializer=torch.nn.init.xavier_normal_,
          result_dir="/home/mgyukim/workspaces/result_MLwM"):
    """Meta-train ``model`` with MAML on the pose-regression meta-dataset.

    Run settings (device, episode sizes, data path, epochs, ...) come from
    ``parse_args()``; optimisation/model hyperparameters come from ``config``.
    After training, the final weights are written to ``save_model_path`` and
    the run directory is handed to ``remove_temp_files_and_move_directory``
    together with ``result_dir``.

    Args:
        model: The network to train. Its parameters are optimised by Adam
            and its ``state_dict()`` is saved after training.
        config: Dict-like hyperparameters; this function reads
            ``'meta_lr'``, ``'encoder_type'`` and ``'beta_kl'``.
        save_model_path: Directory the final ``Model_<datatypes>.pt`` state
            dict is written to; also passed to ``MAML_operator``.
        initializer: Not used by this function; kept so existing callers
            that pass it keep working.
        result_dir: Root directory passed to
            ``remove_temp_files_and_move_directory``. Defaults to the path
            that used to be hard-coded here.
    """
    # Parser
    args = parse_args()

    # Device. os.environ only accepts str values, so coerce in case the
    # CLI parses --device as a number.
    # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not
    # been initialised yet in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)
    device = torch.device('cuda')

    # Datasets: meta-train and meta-validation splits of the same data.
    pose_training_set = meta_pose_regression_dataset(
        args.n_way, args.k_shot_support, args.k_shot_query,
        args.data_path, types=args.datatypes)

    pose_valid_set = meta_pose_regression_dataset(
        args.n_way, args.k_shot_support, args.k_shot_query,
        args.data_path, mode='val', types=args.datatypes)

    # Data loaders: each batch holds `task_size` tasks.
    train_loader = DataLoader(pose_training_set,
                              batch_size=args.task_size,
                              shuffle=True)
    valid_loader = DataLoader(pose_valid_set,
                              batch_size=args.task_size,
                              shuffle=True)

    # Number of batches per epoch.
    print("length of episode : ", len(train_loader))

    if DEBUG:
        # Each dataset item is a (support_x, support_y, query_x, query_y) task.
        support_x, support_y, query_x, query_y = pose_training_set[0]

        print("=" * 25, "DEBUG", "=" * 25)
        print("support_x shape : ", support_x.shape)
        print("support_y shape : ", support_y.shape)
        print("query_x shape : ", query_x.shape)
        print("query_y shape : ", query_y.shape)

    # Outer-loop (meta) optimizer.
    optimizer = optim.Adam(model.parameters(), lr=config['meta_lr'])

    # Run MAML training (and validation on valid_loader).
    maml_operator = MAML_operator(model, device, train_loader, optimizer,
                                  args.epochs, save_model_path, valid_loader)
    maml_operator.train()

    # Save the final model weights.
    torch.save(
        model.state_dict(),
        os.path.join(save_model_path, "Model_{}.pt".format(args.datatypes)))
    print("=" * 20, "Save the model (After training)", "=" * 20)

    # Move saved files to the result folder.
    remove_temp_files_and_move_directory(
        save_model_path, result_dir, args.model,
        config['encoder_type'], config['beta_kl'], "poseregression",
        args.datatypes, args.description)
Пример #2
0
def test(model,
         config,
         load_model_path,
         save_model_path,
         initializer=torch.nn.init.xavier_normal_):
    """Load a saved state dict into ``model`` and run MAML meta-testing.

    Run settings (device, episode sizes, data path, ...) come from
    ``parse_args()``. The meta-test split of ``meta_miniImagenet_dataset`` is
    used for evaluation.

    Args:
        model: Network whose weights are replaced by the checkpoint at
            ``load_model_path`` before evaluation.
        config: Dict-like hyperparameters; this function reads
            ``'img_size'``.
        load_model_path: Path to a file produced by ``torch.save`` of a
            ``state_dict``.
        save_model_path: Passed to ``MAML_operator`` as ``savedir``.
        initializer: Not used by this function; kept so existing callers
            that pass it keep working.
    """
    # Parser
    args = parse_args()

    # Device. os.environ only accepts str values, so coerce in case the
    # CLI parses --device as a number.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)
    device = torch.device('cuda')

    # Meta-test dataset and loader (each batch holds `task_size` tasks).
    miniimagenet_test_set = meta_miniImagenet_dataset(
        args.n_way, args.k_shot_support, args.k_shot_query,
        args.data_path, config['img_size'], mode='test', types=args.datatypes)

    test_loader = DataLoader(miniimagenet_test_set,
                             batch_size=args.task_size,
                             shuffle=True)

    if DEBUG:
        # Each dataset item is a (support_x, support_y, query_x, query_y) task.
        support_x, support_y, query_x, query_y = miniimagenet_test_set[0]

        print("=" * 25, "DEBUG", "=" * 25)
        print("support_x shape : ", support_x.shape)
        print("support_y shape : ", support_y.shape)
        print("query_x shape : ", query_x.shape)
        print("query_y shape : ", query_y.shape)

    # Load the model weights. Deserialize onto the CPU so loading does not
    # depend on which GPU the checkpoint was saved from; load_state_dict
    # then copies the values into the model's own parameters.
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints.
    checkpoint = torch.load(load_model_path, map_location='cpu')
    model.load_state_dict(checkpoint)
    print("=" * 20, "Load the model : {}".format(load_model_path), "=" * 20)

    # Run MAML evaluation.
    maml_operator = MAML_operator(model,
                                  device,
                                  test_loader,
                                  savedir=save_model_path)
    maml_operator.test()