Example #1
0
                                     sampling_strategy=None)

    valid_evaluator = Evaluator(
        dis_model,
        valid_studies,
        '../data/lumbar_train51_annotation.json',
        num_rep=5,
        max_dist=6,
    )

    step_per_batch = len(train_dataloader)
    optimizer = torch.optim.AdamW(dis_model.parameters(), lr=1e-5)
    max_step = 40 * step_per_batch
    fit_result = torch_utils.fit(
        dis_model,
        train_data=train_dataloader,
        valid_data=None,
        optimizer=optimizer,
        max_step=max_step,
        loss=NullLoss(),
        metrics=[valid_evaluator.metric],
        is_higher_better=True,
        evaluate_per_steps=step_per_batch,
        evaluate_fn=valid_evaluator,
    )

    dis_model.kp_model = None
    torch.save(dis_model.cpu().state_dict(),
               '../models/pretrained.dis_model_v3')
    print('task completed, {} seconds used'.format(time.time() - start_time))
Example #2
0
                                          prob_rotate=1,
                                          max_angel=180,
                                          num_rep=20,
                                          size=[512, 512],
                                          pin_memory=False)
    valid_dataloader = KeyPointDataLoader(valid_images,
                                          valid_spacings,
                                          valid_annotation,
                                          batch_size=1,
                                          num_workers=5,
                                          num_rep=20,
                                          size=[512, 512],
                                          pin_memory=False)

    optimizer = torch.optim.AdamW(kp_model_v2.parameters(), lr=1e-5)
    max_step = 50 * len(train_dataloader)
    result_v2 = torch_utils.fit(
        kp_model_v2,
        train_dataloader,
        valid_dataloader,
        optimizer,
        max_step,
        NullLoss(),
        [KeyPointAcc(6)],
        is_higher_better=True,
        evaluate_per_steps=len(train_dataloader),
        # checkpoint_dir='models',
    )
    torch.save(kp_model_v2.cpu().state_dict(), 'models/pretrained.kp_model_v2')
    print('task completed, {} seconds used'.format(time.time() - start_time))