def main(args):
    """Train the CVAE and the regressor on a fixed, shuffled subset of the data.

    Builds a pretrained MNASNet (put in eval mode), a CVAE and a Regressor on
    ``device``, optionally restores the CVAE and regressor from checkpoints,
    builds a training ``WBCDataset``/loader plus one RAdam optimizer per
    model, and hands everything to ``train``.

    Args:
        args: Parsed command-line namespace. Fields read here: ``n_class``,
            ``cvae_resume_model``, ``regressor_resume_model``, ``data_root``,
            ``metadata_file_name`` (a format string filled with ``subset``),
            ``subset`` and ``batch_size``.
    """
    mnasnet = models.mnasnet1_0(pretrained=True).to(device).eval()
    cvae = CVAE(1000, 128, args.n_class * 2, args.n_class).to(device)
    # NOTE(review): only the encoder is switched to eval mode here; whether
    # ``train`` changes module modes again is not visible from this function.
    cvae.encoder.eval()
    regressor = Regressor().to(device)

    # map_location lets checkpoints saved on a different device (e.g. a GPU)
    # be restored onto whatever ``device`` is in use now.
    if Path(args.cvae_resume_model).exists():
        print("load cvae model:", args.cvae_resume_model)
        cvae.load_state_dict(
            torch.load(args.cvae_resume_model, map_location=device))

    if Path(args.regressor_resume_model).exists():
        print("load regressor model:", args.regressor_resume_model)
        regressor.load_state_dict(
            torch.load(args.regressor_resume_model, map_location=device))

    # Deterministic shuffle (fixed seed), then keep the first n_train rows.
    n_train = 250
    metadata_path = Path(args.data_root,
                         args.metadata_file_name.format(args.subset))
    # .copy() detaches the slice so the column assignment below modifies an
    # independent frame (avoids pandas' SettingWithCopyWarning).
    image_label = pandas.read_csv(metadata_path).sample(
        frac=1, random_state=551)[:n_train].copy()
    # Shift class labels down by one (presumably 1-based in the CSV -> 0-based).
    image_label["class"] = image_label["class"] - 1

    dataset = WBCDataset(args.n_class,
                         image_label.values,
                         args.data_root,
                         subset=args.subset,
                         train=True)
    # NOTE(review): third argument presumably enables shuffling — confirm
    # against ``loader``'s definition.
    data_loader = loader(dataset, args.batch_size, True)
    cvae_optimizer = RAdam(cvae.parameters(), weight_decay=1e-3)
    regressor_optimizer = RAdam(regressor.parameters(), weight_decay=1e-3)
    train(args, mnasnet, cvae, regressor, cvae_optimizer, regressor_optimizer,
          data_loader)
def main(args):
    """Evaluate a trained CVAE on the shuffled metadata, one sample per batch.

    Builds a pretrained MNASNet and a CVAE (both in eval mode) on ``device``,
    optionally restores the CVAE weights from ``args.resume_model``, builds a
    ``WBCDataset`` over every row of the shuffled metadata, and runs ``test``
    with a batch size of 1.

    Args:
        args: Parsed command-line namespace. Fields read here: ``n_class``,
            ``resume_model``, ``data_root``, ``metadata_file_name`` (a format
            string filled with ``subset``) and ``subset``.
    """
    mnasnet = models.mnasnet1_0(pretrained=True).to(device).eval()
    model = CVAE(1000, 128, 128, args.n_class, 128).to(device).eval()
    if Path(args.resume_model).exists():
        # NOTE(review): the message says "regressor" although the weights are
        # loaded into the CVAE ``model`` — confirm which wording is intended.
        print("load regressor model:", args.resume_model)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored onto the current ``device``.
        model.load_state_dict(
            torch.load(args.resume_model, map_location=device))

    metadata_path = Path(args.data_root,
                         args.metadata_file_name.format(args.subset))
    # Deterministic shuffle with a fixed seed; every row is evaluated.
    # NOTE(review): a ``[250:]`` slice used to follow here (commented out),
    # presumably to skip the first 250 shuffled rows if they were used for
    # training — confirm whether evaluating on all rows is intended.
    image_label = pandas.read_csv(metadata_path).sample(frac=1,
                                                        random_state=551)
    # Shift class labels down by one (presumably 1-based in the CSV -> 0-based).
    image_label["class"] = image_label["class"] - 1
    dataset = WBCDataset(image_label.values, args.data_root, subset=args.subset)
    # NOTE(review): third argument presumably disables shuffling — confirm
    # against ``loader``'s definition.
    data_loader = loader(dataset, 1, False)
    test(args, mnasnet, model, data_loader)