Example #1
0
def main():
    """Drive the WDGRL experiment with SVHN and MNIST.

    Steps, in order: seed the RNGs, build train/test loaders for both
    datasets, wrap encoder / critic / classifier with ``model_init``, train
    the critic and then the encoder + classifier (each step is skipped when
    the relevant models report ``pretrained`` and ``params.model_trained`` is
    set), and finally evaluate on the test split of each dataset.
    """
    init_random_seed(params.manual_seed)

    # Data loaders: SVHN first, then MNIST (train split, then test split).
    svhn_train = get_svhn(split='train', download=True)
    svhn_test = get_svhn(split='test', download=True)
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)

    # WDGRL models; the critic's layer sizes come from params.
    encoder = model_init(Encoder(), params.encoder_wdgrl_path)
    domain_critic = model_init(
        Discriminator(in_dims=params.d_in_dims,
                      h_dims=params.d_h_dims,
                      out_dims=params.d_out_dims),
        params.disc_wdgrl_path)
    classifier = model_init(Classifier(), params.clf_wdgrl_path)

    # Stage 1: critic.
    print("====== Training critic ======")
    critic_ready = domain_critic.pretrained and params.model_trained
    if not critic_ready:
        domain_critic = train_critic_wdgrl(encoder, domain_critic,
                                           svhn_train, mnist_train)

    # Stage 2: shared encoder and classifier.
    print("====== Training encoder for both SVHN and MNIST domains ======")
    encoder_ready = (encoder.pretrained and classifier.pretrained
                     and params.model_trained)
    if not encoder_ready:
        encoder, classifier = train_tgt_wdgrl(encoder, classifier,
                                              domain_critic,
                                              svhn_train, mnist_train,
                                              robust=False)

    # Stage 3: report results on each test split, one banner per domain.
    print("====== Evaluating classifier for encoded SVHN and MNIST domains ======")
    reports = (
        ("-------- SVHN domain --------", svhn_test),
        ("-------- MNIST adaption --------", mnist_test),
    )
    for banner, test_loader in reports:
        print(banner)
        eval_tgt(encoder, classifier, test_loader)
def main():
    """Drive the RevGrad experiment with MNIST and USPS.

    Seeds the RNGs, builds train/test loaders for both datasets, wraps the
    encoder, critic and classifier with ``model_init``, trains all three via
    ``train_revgrad`` unless every model reports ``pretrained`` and
    ``params.model_trained`` is set, then evaluates on both test splits.
    """
    init_random_seed(params.manual_seed)

    # Data loaders: MNIST first, then USPS.
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)
    usps_train = get_usps(train=True, download=True)
    usps_test = get_usps(train=False, download=True)

    # RevGrad models (note: the critic path attribute is spelled
    # ``disc_revgard_path`` in params).
    encoder = model_init(Encoder(), params.tgt_encoder_revgrad_path)
    domain_critic = model_init(Discriminator(), params.disc_revgard_path)
    classifier = model_init(Classifier(), params.clf_revgrad_path)

    # Joint training, skipped only when everything was already loaded.
    print("====== Training source encoder and classifier in MNIST and USPS domains ======")
    all_ready = (encoder.pretrained and classifier.pretrained
                 and domain_critic.pretrained and params.model_trained)
    if not all_ready:
        encoder, classifier, domain_critic = train_revgrad(
            encoder, classifier, domain_critic,
            mnist_train, usps_train, robust=False)

    # Report results on each test split, one banner per domain.
    print("====== Evaluating classifier for encoded MNIST and USPS domain ======")
    reports = (
        ("-------- MNIST domain --------", mnist_test),
        ("-------- USPS adaption --------", usps_test),
    )
    for banner, test_loader in reports:
        print(banner)
        eval_tgt(encoder, classifier, test_loader)
def main():
    """Drive the robust WDGRL experiment with MNIST and USPS.

    Seeds the RNGs, builds train/test loaders for both datasets, wraps the
    encoder, critic and classifier with ``model_init`` using the ``*_rb_path``
    checkpoints, runs ``train_tgt_wdgrl`` with ``robust=True`` unless the
    encoder and classifier report ``pretrained`` and ``params.model_trained``
    is set, then evaluates both test splits with ``eval_tgt_robust``.
    """
    init_random_seed(params.manual_seed)

    # Data loaders: MNIST first, then USPS.
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)
    usps_train = get_usps(train=True, download=True)
    usps_test = get_usps(train=False, download=True)

    # WDGRL models backed by the robust checkpoints.
    encoder = model_init(Encoder(), params.encoder_wdgrl_rb_path)
    domain_critic = model_init(Discriminator(), params.disc_wdgrl_rb_path)
    classifier = model_init(Classifier(), params.clf_wdgrl_rb_path)

    # Robust training; the USPS test loader is handed to the trainer as well.
    print("====== Robust Training encoder for both MNIST and USPS domains ======")
    already_trained = (encoder.pretrained and classifier.pretrained
                       and params.model_trained)
    if not already_trained:
        encoder, classifier = train_tgt_wdgrl(
            encoder, classifier, domain_critic,
            mnist_train, usps_train, usps_test, robust=True)

    # Report results on each test split, one banner per domain.
    print("====== Evaluating classifier for encoded MNIST and USPS domains ======")
    reports = (
        ("-------- MNIST domain --------", mnist_test),
        ("-------- USPS adaption --------", usps_test),
    )
    for banner, test_loader in reports:
        print(banner)
        eval_tgt_robust(encoder, classifier, test_loader)
def main():
    """Drive the robust ADDA experiment with MNIST as source and USPS as target.

    Steps, in order:

    1. Seed the RNGs and build train/test loaders for MNIST and USPS.
    2. Wrap source encoder, target encoder, critic and classifier with
       ``model_init`` using the ``*_adda_rb_path`` checkpoints.
    3. Train the source encoder + classifier with ``train_src_robust`` unless
       both report ``pretrained`` and ``params.model_trained`` is set, then
       evaluate them on the MNIST test split.
    4. Copy the source encoder's weights into the target encoder when the
       latter was not loaded from a checkpoint, and adapt it with
       ``train_tgt_adda(..., robust=True)`` unless already trained.
    5. Evaluate both encoders on the USPS test split.
    """
    # init random seed
    init_random_seed(params.manual_seed)

    # Load dataset.
    # The MNIST loaders must come from get_mnist: building them with
    # get_usps would silently make source and target the same dataset while
    # every banner below still reports "MNIST".
    mnist_data_loader = get_mnist(train=True, download=True)
    mnist_data_loader_eval = get_mnist(train=False, download=True)
    usps_data_loader = get_usps(train=True, download=True)
    usps_data_loader_eval = get_usps(train=False, download=True)

    # Model init ADDA
    src_encoder = model_init(Encoder(), params.src_encoder_adda_rb_path)
    tgt_encoder = model_init(Encoder(), params.tgt_encoder_adda_rb_path)
    critic = model_init(Discriminator(), params.disc_adda_rb_path)
    clf = model_init(Classifier(), params.clf_adda_rb_path)

    # Train source model for adda
    print(
        "====== Robust training source encoder and classifier in MNIST domain ======"
    )
    if not (src_encoder.pretrained and clf.pretrained
            and params.model_trained):
        src_encoder, clf = train_src_robust(src_encoder, clf,
                                            mnist_data_loader)

    # Eval source model
    print("====== Evaluating classifier for MNIST domain ======")
    eval_tgt(src_encoder, clf, mnist_data_loader_eval)

    # Train target encoder
    print("====== Robust training encoder for USPS domain ======")
    # Initialize target encoder's weights with those of the source encoder
    if not tgt_encoder.pretrained:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.pretrained and critic.pretrained
            and params.model_trained):
        tgt_encoder = train_tgt_adda(src_encoder,
                                     tgt_encoder,
                                     clf,
                                     critic,
                                     mnist_data_loader,
                                     usps_data_loader,
                                     usps_data_loader_eval,
                                     robust=True)

    # Eval target encoder on test set of target dataset
    print("====== Ealuating classifier for encoded USPS domain ======")
    print("-------- Source only --------")
    eval_tgt(src_encoder, clf, usps_data_loader_eval)
    print("-------- Domain adaption --------")
    eval_tgt(tgt_encoder, clf, usps_data_loader_eval)
Example #5
0
def main():
    """Drive the robust DANN experiment with SVHN and MNIST.

    Seeds the RNGs, builds train/test loaders for both datasets, wraps the
    encoder, critic and classifier with ``model_init`` using the
    ``*_dann_rb_path`` checkpoints, runs ``train_dann`` with ``robust=True``
    unless every model reports ``pretrained`` and ``params.model_trained`` is
    set, then evaluates both test splits with ``eval_tgt_robust``.
    """
    init_random_seed(params.manual_seed)

    # Data loaders: SVHN first, then MNIST.
    svhn_train = get_svhn(split='train', download=True)
    svhn_test = get_svhn(split='test', download=True)
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)

    # DANN models; the critic's layer sizes come from params.
    encoder = model_init(Encoder(), params.tgt_encoder_dann_rb_path)
    domain_critic = model_init(
        Discriminator(in_dims=params.d_in_dims,
                      h_dims=params.d_h_dims,
                      out_dims=params.d_out_dims),
        params.disc_dann_rb_path)
    classifier = model_init(Classifier(), params.clf_dann_rb_path)

    # Joint training; the MNIST test loader is handed to the trainer as well.
    print(
        "====== Training source encoder and classifier in SVHN and MNIST domains ======"
    )
    all_ready = (encoder.pretrained and classifier.pretrained
                 and domain_critic.pretrained and params.model_trained)
    if not all_ready:
        encoder, classifier, domain_critic = train_dann(
            encoder, classifier, domain_critic,
            svhn_train, mnist_train, mnist_test,
            robust=True)

    # Report results on each test split, one banner per domain.
    print(
        "====== Evaluating classifier for encoded SVHN and MNIST domains ======"
    )
    reports = (
        ("-------- SVHN domain --------", svhn_test),
        ("-------- MNIST adaption --------", mnist_test),
    )
    for banner, test_loader in reports:
        print(banner)
        eval_tgt_robust(encoder, classifier, test_loader)
def _eval():
    """Run tumor-bed prediction on the raw validation slides with a saved model.

    Builds the segmentation model named by ``args.model_name``, attaches the
    classifier and regressor heads, restores weights from
    ``args.eval_model_pth`` via ``networktools.continue_train`` and passes the
    model plus a ``ds.Dataset_wsis`` iterator to ``val.predict_tumorbed``.

    Side effect: overwrites ``args.val_save_pth`` and ``args.raw_val_pth``
    with the hard-coded paths below.
    """
    # NOTE(review): machine-specific absolute paths; consider making these
    # configurable.
    args.val_save_pth = '/home/ozan/remoteDir/Tumor Bed Detection Results/Ynet_segmentation_ozan'
    args.raw_val_pth = '/home/ozan/remoteDir/'

    # model setup
    def activation(x):
        """Identity activation: return the network output unchanged."""
        # The previous body evaluated `x` without returning it, so any caller
        # applying this activation received None instead of the tensor.
        return x

    # Resolve the architecture by attribute lookup instead of eval(), so the
    # configured name can never be executed as arbitrary code. Dotted names
    # are still followed one attribute at a time.
    model_cls = smp
    for attr_name in args.model_name.split('.'):
        model_cls = getattr(model_cls, attr_name)

    model = model_cls(
        args.arch_encoder,
        encoder_weights='imagenet',
        classes=args.num_classes,
        activation=activation,
    )
    model.classifier = Classifier(model.encoder.out_shapes[0], args.num_classes)
    model.regressor = Regressor(model.encoder.out_shapes[0], 1)

    # Only the restored model is needed here; the other two returned values
    # are discarded. True is passed in the slot where train() passes
    # args.continue_train.
    model, _, _ = networktools.continue_train(
        model,
        optimizers.optimfn(args.optim, model),
        args.eval_model_pth,
        True
    )

    # datasets
    validation_params = {
        'ph': args.tile_h * args.scan_resize,  # patch height (y)
        'pw': args.tile_w * args.scan_resize,  # patch width (x)
        'sh': args.tile_stride_h,     # slide step (dy)
        'sw': args.tile_stride_w,     # slide step (dx)
    }
    iterator_test = ds.Dataset_wsis(args.raw_val_pth, validation_params)

    model = model.cuda()

    # Third argument: train() passes the epoch number to the sibling
    # val.predict_wsis; presumably 0 plays the same role here — TODO confirm.
    val.predict_tumorbed(model, iterator_test, 0)
Example #7
0
import segmentation_models_pytorch as smp
from models import optimizers
from myargs import args
from models.models import Classifier, Regressor

# Evaluation script: build the segmentation model, restore its weights from
# args.eval_model_pth and run the BreastPathQ prediction routine.
#
# NOTE(review): `networktools` and `val` are used below but are not among the
# imports visible at the top of this script — confirm they are imported.

# model setup


def activation(x):
    """Identity activation: return the network output unchanged."""
    # The previous body evaluated `x` without returning it, so any caller
    # applying this activation received None instead of the tensor.
    return x


# Resolve the architecture by attribute lookup instead of eval(), so the
# configured name can never be executed as arbitrary code. Dotted names are
# still followed one attribute at a time.
_model_cls = smp
for _attr_name in args.model_name.split('.'):
    _model_cls = getattr(_model_cls, _attr_name)

model = _model_cls(
    args.arch_encoder,
    encoder_weights='imagenet',
    classes=2,
    activation=activation,
)
model.classifier = Classifier(model.encoder.out_shapes[0], args.num_classes)
model.regressor = Regressor(model.encoder.out_shapes[0], 1)
optimizer = optimizers.optimfn(args.optim, model)

# Only the restored model is kept; the other two returned values are dropped.
model, _, _ = networktools.continue_train(model, optimizer,
                                          args.eval_model_pth, True)

model = model.cuda()

# NOTE(review): machine-specific absolute paths; consider making these
# configurable.
dataset_path = '/home/ozan/Downloads/breastpathq-test/test_patches'
label_csv_path = '/home/ozan/Downloads/breastpathq-test/Results.csv'
val.predict_breastpathq(model, 391, dataset_path, label_csv_path)
def train():
    """Train the joint classification / segmentation network.

    Builds the segmentation model named by ``args.model_name``, attaches a
    classifier head, optionally resumes from ``args.train_model_pth``, and
    runs epochs from the resumed ``start_epoch`` through ``args.num_epoch``.
    Each batch is split by the ``is_cls`` mask: flagged samples go through the
    classifier head with the classification loss, the rest go through the
    decoder with the segmentation loss. Validation and checkpointing run
    every ``args.validate_model`` / ``args.save_models`` epochs when those
    settings are positive.

    Side effect: overwrites ``args.val_save_pth``.
    """
    args.val_save_pth = 'data/val/out2'

    # model setup

    def activation(x):
        """Identity activation: return the network output unchanged."""
        # The previous body evaluated `x` without returning it, so any caller
        # applying this activation received None instead of the tensor.
        return x

    # Resolve the architecture by attribute lookup instead of eval(), so the
    # configured name can never be executed as arbitrary code. Dotted names
    # are still followed one attribute at a time.
    model_cls = smp
    for attr_name in args.model_name.split('.'):
        model_cls = getattr(model_cls, attr_name)

    model = model_cls(
        args.arch_encoder,
        encoder_weights='imagenet',
        classes=args.num_classes,
        activation=activation,
    )
    model.classifier = Classifier(model.encoder.out_shapes[0],
                                  args.num_classes)
    optimizer = optimizers.optimfn(args.optim, model)

    model, optimizer, start_epoch = networktools.continue_train(
        model, optimizer, args.train_model_pth, args.continue_train)

    # losses: one per task, each with its own class weights
    cls_weights_cls, cls_weights_seg = preprocessing.cls_weights(
        args.train_image_pth)

    params = {
        'reduction': 'mean',
        'alpha': torch.Tensor(cls_weights_cls),
        'xent_ignore': -1,
    }
    lossfn_cls = losses.lossfn(args.loss, params).cuda()

    params = {
        'reduction': 'mean',
        'alpha': torch.Tensor(cls_weights_seg),
        'xent_ignore': -1,
    }
    lossfn_seg = losses.lossfn(args.loss, params).cuda()

    # datasets
    validation_params = {
        'ph': args.tile_h * args.scan_resize,  # patch height (y)
        'pw': args.tile_w * args.scan_resize,  # patch width (x)
        'sh': args.tile_stride_h,  # slide step (dy)
        'sw': args.tile_stride_w,  # slide step (dx)
    }
    iterator_train = ds.GenerateIterator(args.train_image_pth,
                                         duplicate_dataset=1)
    iterator_val = ds.Dataset_wsis(args.raw_val_pth, validation_params)

    model = model.cuda()

    # current run train parameters
    print(args)

    for epoch in range(start_epoch, 1 + args.num_epoch):

        sum_loss = 0
        progress_bar = tqdm(iterator_train, disable=False)

        for batch_it, (image, label, is_cls,
                       cls_code) in enumerate(progress_bar):

            image = image.cuda()
            label = label.cuda()
            is_cls = is_cls.type(torch.bool).cuda()
            is_seg = ~is_cls
            cls_code = cls_code.cuda()

            # Shared encoder pass; each head then sees only its own samples.
            encoding = model.encoder(image)

            loss = 0

            if is_cls.any():
                # Classifier head consumes the first encoder feature map.
                pred_cls = model.classifier(encoding[0][is_cls, ...])
                loss = loss + lossfn_cls(pred_cls, cls_code[is_cls])

            if is_seg.any():
                # Decoder consumes every encoder feature map.
                pred_seg = model.decoder([x[is_seg, ...] for x in encoding])
                loss = loss + lossfn_seg(pred_seg, label[is_seg])

            sum_loss = sum_loss + loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running mean over the batch_it + 1 batches seen so far.
            # Dividing by batch_it + epsilon was off by one and produced a
            # huge value on the very first batch.
            progress_bar.set_description('ep. {}, cls loss: {:.3f}'.format(
                epoch, sum_loss / (batch_it + 1)))

        # test model accuracy
        if args.validate_model > 0 and epoch % args.validate_model == 0:
            val.predict_wsis(model, iterator_val, epoch)

        # Periodic checkpoint: weights, optimizer state and run config.
        if args.save_models > 0 and epoch % args.save_models == 0:
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': args
            }
            torch.save(
                state, '{}/model_{}_{}.pt'.format(args.model_save_pth,
                                                  args.arch_encoder, epoch))
def main():
    """Drive the DANN + pseudo-labelling experiment with MNIST and USPS.

    Phases, in order:

    1. Seed the RNGs, build loaders, wrap the DANN models with ``model_init``
       and train them with ``train_dann`` unless every model reports
       ``pretrained`` and ``params.model_trained`` is set; evaluate both test
       splits with ``eval_tgt_robust``.
    2. Call ``pseudo_label`` on the USPS training loader and load the
       resulting dataset via ``get_usps(..., get_pseudo=True)``.
    3. Train a fresh encoder/classifier on the pseudo-labelled data with
       ``train_src_adda`` and evaluate on the real USPS test split.
    4. Repeat phase 3 with ``train_src_robust`` and the ``*_rb_path``
       checkpoints.
    """
    init_random_seed(params.manual_seed)

    # Data loaders: MNIST first, then USPS.
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)
    usps_train = get_usps(train=True, download=True)
    usps_test = get_usps(train=False, download=True)

    # --- Phase 1: DANN -----------------------------------------------------
    encoder = model_init(Encoder(), params.tgt_encoder_dann_rb_path)
    domain_critic = model_init(Discriminator(), params.disc_dann_rb_path)
    classifier = model_init(Classifier(), params.clf_dann_rb_path)

    # NOTE(review): the banner announces robust training but robust=False is
    # passed below — confirm which one is intended.
    print(
        "====== Robust Training source encoder and classifier in MNIST and USPS domains ======"
    )
    dann_ready = (encoder.pretrained and classifier.pretrained
                  and domain_critic.pretrained and params.model_trained)
    if not dann_ready:
        encoder, classifier, domain_critic = train_dann(
            encoder, classifier, domain_critic,
            mnist_train, usps_train, usps_test,
            robust=False)

    print(
        "====== Evaluating classifier for encoded MNIST and USPS domains ======"
    )
    reports = (
        ("-------- MNIST domain --------", mnist_test),
        ("-------- USPS adaption --------", usps_test),
    )
    for banner, test_loader in reports:
        print(banner)
        eval_tgt_robust(encoder, classifier, domain_critic, test_loader)

    # --- Phase 2: pseudo-label the USPS training data ------------------------
    print("====== Pseudo labeling on USPS domain ======")
    pseudo_label(encoder, classifier, "usps_train_pseudo", usps_train)

    # Fresh models for phase 3 are created before the pseudo-labelled loader.
    encoder = model_init(Encoder(), params.tgt_encoder_path)
    classifier = model_init(Classifier(), params.clf_path)

    usps_pseudo_train = get_usps(train=True, download=True, get_pseudo=True)

    # --- Phase 3: standard training on pseudo labels -------------------------
    # NOTE(review): "standard" training is invoked with mode='ADV' — confirm
    # that this mode is intended here.
    print("====== Standard training on USPS domain with pseudo labels ======")
    standard_ready = encoder.pretrained and classifier.pretrained
    if not standard_ready:
        train_src_adda(encoder, classifier, usps_pseudo_train, mode='ADV')
    print("====== Evaluating on USPS domain with real labels ======")
    eval_tgt(encoder, classifier, usps_test)

    # --- Phase 4: robust training on pseudo labels ---------------------------
    encoder = model_init(Encoder(), params.tgt_encoder_rb_path)
    classifier = model_init(Classifier(), params.clf_rb_path)
    print("====== Robust training on USPS domain with pseudo labels ======")
    robust_ready = encoder.pretrained and classifier.pretrained
    if not robust_ready:
        train_src_robust(encoder, classifier, usps_pseudo_train, mode='ADV')
    print("====== Evaluating on USPS domain with real labels ======")
    eval_tgt(encoder, classifier, usps_test)