Example 1
    def logit_y1x_src(self, x: tc.Tensor, n_mc_q: int = 0, repar: bool = True):
        # Logits of p(y|x) on the source domain: evaluate log p(y|s) for every
        # candidate label y under the inference model for s given x.
        dim_y = 2 if self.dim_y == 1 else self.dim_y
        y_eval = ds.expand_front(tc.arange(dim_y, device=x.device),
                                 ds.tcsize_div(x.shape, self.shape_x))
        x_eval = ds.expand_middle(x, (dim_y,), -len(self.shape_x))
        obs_xy = ds.edic({'x': x_eval, 'y': y_eval})
        if self.q_s1x is not None:
            # Source-domain posterior q(s|x): direct expectation, or a
            # log-mean-exp Monte Carlo estimate with `n_mc_q` samples.
            logits = (
                self.q_s1x.expect(lambda dc: self.p_y1s.logp(dc, dc),
                                  obs_xy, 0, repar)
            ) if n_mc_q == 0 else (
                self.q_s1x.expect(lambda dc: self.p_y1s.logp(dc, dc),
                                  obs_xy, n_mc_q, repar,
                                  reducefn=tc.logsumexp) - math.log(n_mc_q))
        else:
            # Only the target-domain posterior qt(s|x) is available: reweight
            # log p(y|s) by the prior ratio p(s) / pt(s).
            vwei_p_y1s_logp = lambda dc: (self.p_s.logp(dc, dc)
                                          - self.pt_s.logp(dc, dc)
                                          + self.p_y1s.logp(dc, dc))
            logits = (
                self.qt_s1x.expect(vwei_p_y1s_logp, obs_xy, 0, repar)
            ) if n_mc_q == 0 else (
                self.qt_s1x.expect(vwei_p_y1s_logp, obs_xy, n_mc_q, repar,
                                   reducefn=tc.logsumexp) - math.log(n_mc_q))
        # Binary case: return the scalar log-odds; otherwise return all logits.
        return ((logits[..., 1] - logits[..., 0]).squeeze(-1)
                if self.dim_y == 1 else logits)
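A minimal usage sketch (not from the source): assuming `frame` is a SemVar or SupVAE instance built as in Example 3 and `val_src_loader` yields source-domain batches, the value returned above is the binary log-odds when dim_y == 1 and the per-class logits otherwise:

x, y = next(iter(val_src_loader))  # hypothetical source batch
with tc.no_grad():
    out = frame.logit_y1x_src(x.to(device), n_mc_q=32)  # 32 MC samples is a placeholder
y_pred = (out > 0).long() if frame.dim_y == 1 else out.argmax(dim=-1)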
Example 2
def get_frame(discr, gen, dc_vars, device=None, discr_src=None):
    if type(dc_vars) is not edic: dc_vars = edic(dc_vars)
    # Shapes and standard deviations: prefer what the discriminator exposes,
    # otherwise fall back to the corresponding entries in `dc_vars`.
    shape_x = dc_vars['shape_x'] if 'shape_x' in dc_vars else (dc_vars['dim_x'],)
    shape_s = discr.shape_s if hasattr(discr, "shape_s") else (dc_vars['dim_s'],)
    shape_v = discr.shape_v if hasattr(discr, "shape_v") else (dc_vars['dim_v'],)
    std_v1x = discr.std_v1x if hasattr(discr, "std_v1x") else dc_vars['qstd_v']
    std_s1vx = discr.std_s1vx if hasattr(discr, "std_s1vx") else dc_vars['qstd_s']
    std_s1x = discr.std_s1x if hasattr(discr, "std_s1x") else dc_vars['qstd_s']
    mode = dc_vars['mode']

    # Inference networks from the current discriminator: (v|x) and (s|v,x) for
    # SemVar ('svgm*'), or just (s|x) for SupVAE ('svae*').
    if mode.startswith("svgm"):
        q_args_stem = (discr.v1x, std_v1x, discr.s1vx, std_s1vx)
    elif mode.startswith("svae"):
        q_args_stem = (discr.s1x, std_s1x)
    else:
        return None
    if mode == "svgm-da2" and discr_src is not None:
        q_args = (
            discr_src.v1x,
            discr_src.std_v1x
            if hasattr(discr_src, "std_v1x") else dc_vars['qstd_v'],
            discr_src.s1vx,
            discr_src.std_s1vx
            if hasattr(discr_src, "std_s1vx") else dc_vars['qstd_s'],
        ) + q_args_stem
    elif mode == "svae-da2" and discr_src is not None:
        q_args = (
            discr_src.s1x,
            discr_src.std_s1x
            if hasattr(discr_src, "std_s1x") else dc_vars['qstd_s'],
        ) + q_args_stem
    elif mode in MODES_TWIST:  # svgm-ind, svae-da, svgm-da
        q_args = (None, ) * len(q_args_stem) + q_args_stem
    else:  # svae, svgm
        q_args = q_args_stem + (None, ) * len(q_args_stem)

    if mode.startswith("svgm"):
        frame = SemVar(
            shape_s, shape_v, shape_x, dc_vars['dim_y'], gen.x1sv,
            dc_vars['pstd_x'], discr.y1s, *q_args,
            *dc_vars.sublist(['mu_s', 'sig_s', 'mu_v', 'sig_v',
                              'corr_sv']), mode in MODES_DA,
            *dc_vars.sublist(['src_mvn_prior', 'tgt_mvn_prior']), device)
    elif mode.startswith("svae"):
        frame = SupVAE(shape_s, shape_x, dc_vars['dim_y'], gen.x1s,
                       dc_vars['pstd_x'], discr.y1s, *q_args,
                       *dc_vars.sublist(['mu_s', 'sig_s']), mode in MODES_DA,
                       *dc_vars.sublist(['src_mvn_prior',
                                         'tgt_mvn_prior']), device)
    return frame
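A hedged sketch of the configuration dictionary that get_frame reads; the key names are exactly those looked up in the function, while the values are placeholders rather than defaults from the source (discr and gen are assumed to come from get_discr/get_gen as in Example 3):

dc_vars = edic({
    'mode': 'svgm-da',             # 'svgm*' selects SemVar, 'svae*' selects SupVAE
    'shape_x': (3, 32, 32),        # falls back to (dim_x,) when absent
    'dim_s': 64, 'dim_v': 64, 'dim_y': 10,
    'qstd_v': -1., 'qstd_s': -1.,  # used only if discr lacks the std_* attributes
    'pstd_x': 3e-2,
    'mu_s': 0., 'sig_s': 1., 'mu_v': 0., 'sig_v': 1., 'corr_sv': .5,
    'src_mvn_prior': False, 'tgt_mvn_prior': False,
})
frame = get_frame(discr, gen, dc_vars, device)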
Example 3
def get_models(archtype, dc_vars, ckpt=None, device=None):
    # Build the discriminator and, for generative modes, the generator and the
    # SemVar/SupVAE frame; optionally restore their parameters from `ckpt`.
    if type(dc_vars) is not edic: dc_vars = edic(dc_vars)
    discr = get_discr(archtype, dc_vars)
    if ckpt is not None: auto_load(locals(), 'discr', ckpt)
    discr.to(device)
    if dc_vars['mode'] in MODES_GEN:
        gen = get_gen(archtype, dc_vars, discr)
        if ckpt is not None: auto_load(locals(), 'gen', ckpt)
        gen.to(device)
        if dc_vars['mode'].endswith("-da2"):
            discr_src = get_discr(archtype, dc_vars)
            if ckpt is not None: auto_load(locals(), 'discr_src', ckpt)
            discr_src.to(device)
            frame = get_frame(discr, gen, dc_vars, device, discr_src)
            if ckpt is not None: auto_load(locals(), 'frame', ckpt)
            return discr, gen, frame, discr_src
        else:
            frame = get_frame(discr, gen, dc_vars, device)
            if ckpt is not None: auto_load(locals(), 'frame', ckpt)
            return discr, gen, frame
    else:
        return discr, None, None
Example 4
def main_stem(
        ag,
        ckpt,
        archtype,
        shape_x,
        dim_y,
        tr_src_loader,
        val_src_loader,
        ls_ts_tgt_loader=None,  # for ood
        tr_tgt_loader=None,
        ts_tgt_loader=None,
        testdom=None  # for da
):
    print(ag)
    print_infrstru_info()
    IS_OOD = is_ood(ag.mode)
    device = tc.device("cuda:" +
                       str(ag.gpu) if tc.cuda.is_available() else "cpu")

    # Datasets
    dim_x = tc.tensor(shape_x).prod().item()
    if IS_OOD: n_per_epk = len(tr_src_loader)
    else: n_per_epk = max(len(tr_src_loader), len(tr_tgt_loader))

    # Models
    res = get_models(archtype, edic(locals()) | vars(ag), ckpt, device)
    if ag.mode.endswith("-da2"):
        discr, gen, frame, discr_src = res
        discr_src.train()
    else:
        discr, gen, frame = res
    discr.train()
    if gen is not None: gen.train()

    # Methods and Losses
    if IS_OOD:
        lossfn = ood_methods(
            discr, frame, ag, dim_y, cnbb_actv="Sigmoid"
        )  # Actually the activation is ReLU, but there is no `is_treat` rule for ReLU in CNBB.
        domdisc = None
    else:
        lossfn, domdisc, dalossobj = da_methods(
            discr, frame, ag, dim_x, dim_y, device, ckpt,
            discr_src if ag.mode.endswith("-da2") else None)

    # Optimizer
    pgc = ParamGroupsCollector(ag.lr)
    pgc.collect_params(discr)
    if ag.mode.endswith("-da2"): pgc.collect_params(discr_src)
    if gen is not None: pgc.collect_params(gen, frame)
    if domdisc is not None: pgc.collect_params(domdisc)
    if ag.optim == "SGD":
        opt = getattr(tc.optim, ag.optim)(pgc.param_groups,
                                          weight_decay=ag.wl2,
                                          momentum=ag.momentum,
                                          nesterov=ag.nesterov)
        shrink_opt = ShrinkRatio(w_iter=ag.lr_wdatum * ag.n_bat,
                                 decay_rate=ag.lr_expo)
        lrsched = tc.optim.lr_scheduler.LambdaLR(opt, shrink_opt)
        auto_load(locals(), 'lrsched', ckpt)
    else:
        opt = getattr(tc.optim, ag.optim)(pgc.param_groups,
                                          weight_decay=ag.wl2)
    auto_load(locals(), 'opt', ckpt)

    # Training
    epk0 = 1
    i_bat0 = 1
    if ckpt is not None:
        epk0 = ckpt['epochs'][-1] + 1 if ckpt['epochs'] else 1
        i_bat0 = ckpt['i_bat']
    res = ResultsContainer(
        len(ag.testdoms) if IS_OOD else None, frame, ag, dim_y == 1, device,
        ckpt)
    print(f"Run in mode '{ag.mode}' for {ag.n_epk:3d} epochs:")
    try:
        for epk in range(epk0, ag.n_epk + 1):
            pbar = tqdm.tqdm(total=n_per_epk,
                             desc=f"Train epoch = {epk:3d}",
                             ncols=80,
                             leave=False)
            for i_bat, data_bat in enumerate(
                    tr_src_loader if IS_OOD else zip_longer(
                        tr_src_loader, tr_tgt_loader),
                    start=1):
                if i_bat < i_bat0: continue
                if IS_OOD:
                    x, y = data_bat
                    data_args = (x.to(device), y.to(device))
                else:
                    (x, y), (xt, yt) = data_bat
                    data_args = (x.to(device), y.to(device), xt.to(device))
                opt.zero_grad()
                if ag.mode in MODES_GEN:
                    n_iter_tot = (epk - 1) * n_per_epk + i_bat - 1
                    loss = lossfn(*data_args, n_iter_tot)
                else:
                    loss = lossfn(*data_args)
                loss.backward()
                opt.step()
                if ag.optim == "SGD": lrsched.step()
                pbar.update(1)
            # end for
            pbar.close()
            i_bat = 1
            i_bat0 = 1

            if epk % ag.eval_interval == 0:
                res.update(epk=epk, loss=loss.item())
                print(f"Mode '{ag.mode}': Epoch {epk:.1f}, Loss = {loss:.3e},")
                discr.eval()
                if ag.mode.endswith("-da2"):
                    discr_src.eval()
                    true_discr = discr_src
                elif ag.mode in MODES_TWIST and ag.true_sup_val:
                    true_discr = partial(frame.logit_y1x_src, n_mc_q=ag.n_mc_q)
                else:
                    true_discr = discr
                res.evaluate(true_discr, "val " + str(ag.traindom), 'val',
                             val_src_loader, 'src')
                if IS_OOD:
                    for i, (testdom, ts_tgt_loader) in enumerate(
                            zip(ag.testdoms, ls_ts_tgt_loader)):
                        res.evaluate(discr, "test " + str(testdom), 'ts',
                                     ts_tgt_loader, 'tgt', i)
                else:
                    res.evaluate(discr, "test " + str(testdom), 'ts',
                                 ts_tgt_loader, 'tgt')
                print()
                discr.train()
                if ag.mode.endswith("-da2"): discr_src.train()
        # end for
    except (KeyboardInterrupt, SystemExit):
        pass
    res.summary("val " + str(ag.traindom), 'val')
    if IS_OOD:
        for i, testdom in enumerate(ag.testdoms):
            res.summary("test " + str(testdom), 'ts', i)
    else:
        res.summary("test " + str(testdom), 'ts')

    if not ag.no_save:
        dirname = "ckpt_" + ag.mode + "/"
        os.makedirs(dirname, exist_ok=True)
        for i, testdom in enumerate(ag.testdoms if IS_OOD else [testdom]):
            filename = unique_filename(
                dirname + ("ood" if IS_OOD else "da"), ".pt",
                n_digits=3) if ckpt is None else ckpt['filename']
            dc_vars = edic(locals()).sub([
                'dirname', 'filename', 'testdom', 'shape_x', 'dim_x', 'dim_y',
                'i_bat'
            ]) | (edic(vars(ag)) - {'testdoms'}) | dc_state_dict(
                locals(), "discr", "opt")
            if ag.mode.endswith("-da2"):
                dc_vars.update(dc_state_dict(locals(), "discr_src"))
            if ag.optim == "SGD":
                dc_vars.update(dc_state_dict(locals(), "lrsched"))
            if ag.mode in MODES_GEN:
                dc_vars.update(dc_state_dict(locals(), "gen", "frame"))
            elif ag.mode in {"dann", "cdan", "dan", "mdd"}:
                dc_vars.update(dc_state_dict(locals(), "domdisc", "dalossobj"))
            else:
                pass
            if IS_OOD:
                dc_vars.update(
                    edic({
                        k: v
                        for k, v in res.dc.items() if not k.startswith('ls_')
                    }) | {
                        k[3:]: v[i]
                        for k, v in res.dc.items() if k.startswith('ls_')
                    })
            else:
                dc_vars.update(res.dc)
            tc.save(dc_vars, filename)
            print(f"checkpoint saved to '{filename}'.")
Example 5
def get_visual(
        ag,
        ckpt,
        archtype,
        shape_x,
        dim_y,
        tr_src_loader,
        val_src_loader,
        ls_ts_tgt_loader=None,  # for ood
        tr_tgt_loader=None,
        ts_tgt_loader=None,
        testdom=None  # for da
):
    print(ag)
    IS_OOD = is_ood(ag.mode)
    device = tc.device("cuda:" +
                       str(ag.gpu) if tc.cuda.is_available() else "cpu")

    # Datasets
    dim_x = tc.tensor(shape_x).prod().item()
    if IS_OOD: n_per_epk = len(tr_src_loader)
    else: n_per_epk = max(len(tr_src_loader), len(tr_tgt_loader))

    # Models
    res = get_models(archtype, edic(locals()) | vars(ag), ckpt, device)
    if ag.mode.endswith("-da2"):
        discr, gen, frame, discr_src = res
        discr_src.train()
    else:
        discr, gen, frame = res

    # Evaluation mode for visualization
    discr.eval()
    if gen is not None: gen.eval()

    # Methods and Losses
    if IS_OOD:
        lossfn = ood_methods(
            discr, frame, ag, dim_y, cnbb_actv="Sigmoid"
        )  # Actually the activation is ReLU, but there is no `is_treat` rule for ReLU in CNBB.
        domdisc = None
    else:
        lossfn, domdisc, dalossobj = da_methods(
            discr, frame, ag, dim_x, dim_y, device, ckpt,
            discr_src if ag.mode.endswith("-da2") else None)

    epk0 = 1
    i_bat0 = 1
    if ckpt is not None:
        epk0 = ckpt['epochs'][-1] + 1 if ckpt['epochs'] else 1
        i_bat0 = ckpt['i_bat']
    res = ResultsContainer(
        len(ag.testdoms) if IS_OOD else None, frame, ag, dim_y == 1, device,
        ckpt)
    print(f"Run in mode '{ag.mode}' for {ag.n_epk:3d} epochs:")
    try:
        if ag.mode.endswith("-da2"):
            discr_src.eval()
            true_discr = discr_src
        elif ag.mode in MODES_TWIST and ag.true_sup_val:
            true_discr = partial(frame.logit_y1x_src, n_mc_q=ag.n_mc_q)
        else:
            true_discr = discr
        res.evaluate(true_discr, "val " + str(ag.traindom), 'val',
                     val_src_loader, 'src')

        if IS_OOD:
            for i, (testdom, ts_tgt_loader) in enumerate(
                    zip(ag.testdoms, ls_ts_tgt_loader)):
                res.evaluate(discr, "test " + str(testdom), 'ts',
                             ts_tgt_loader, 'tgt', i)
        else:
            res.evaluate(discr, "test " + str(testdom), 'ts', ts_tgt_loader,
                         'tgt')
            print()

        def batch_predict(images):
            import torch.nn.functional as F
            if tc.tensor(images[0]).size()[-1] == 3:
                images = [
                    tc.tensor(pic, dtype=tc.float).permute(2, 0, 1)
                    for pic in images
                ]
            batch = tc.stack(tuple(i for i in images), dim=0)
            batch = batch.to(device)

            logits = discr(batch)
            probs = F.softmax(logits, dim=1)

            return probs.detach().cpu().numpy()

        if IS_OOD:
            test_loader = ls_ts_tgt_loader[0]
        else:
            test_loader = ts_tgt_loader

        iter_tr, iter_ts = iter(tr_src_loader), iter(test_loader)
        train_batch, train_label = next(iter_tr)
        test_batch, test_label = next(iter_ts)

        os.makedirs(ag.mode, exist_ok=True)

        # Imports for the LIME explanations generated below (hoisted out of the loop).
        from lime import lime_image
        from skimage.segmentation import mark_boundaries
        import matplotlib.pyplot as plt
        import numpy as np

        # Search for the first correctly predicted test and train images:
        cursor_train, cursor_test = 0, 0
        for i in range(400):
            cursor_test += 1
            if cursor_test >= test_batch.size()[0]:
                cursor_test = 0
                test_batch, test_label = next(iter_ts)
            while True:
                x_test = test_batch[cursor_test]
                test_pred = batch_predict([x_test])
                if cursor_test < test_batch.size()[0] and test_label[
                        cursor_test] == test_pred.squeeze().argmax():
                    break
                else:
                    cursor_test = cursor_test + 1
                    if cursor_test >= test_batch.size()[0]:
                        cursor_test = 0
                        test_batch, test_label = next(iter_ts)

            selected_pic, selected_label = test_batch[cursor_test], test_label[
                cursor_test]

            cursor_train += 1
            if cursor_train >= train_batch.size()[0]:
                cursor_train = 0
                train_batch, train_label = next(iter_tr)
            while True:
                x_train = train_batch[cursor_train]
                test_pred = batch_predict([x_train])
                if cursor_train < train_batch.size()[0] and train_label[
                        cursor_train] == test_pred.squeeze().argmax():
                    break
                else:
                    cursor_train = cursor_train + 1
                    if cursor_train >= train_batch.size()[0]:
                        cursor_train = 0
                        train_batch, train_label = next(iter_tr)

            selected_train_pic = train_batch[cursor_train]

            explainer = lime_image.LimeImageExplainer()
            explanation = explainer.explain_instance(
                np.array(selected_pic.permute(1, 2, 0), dtype=np.double),
                batch_predict,  # classification function
                top_labels=5,
                hide_color=0,
                num_samples=1000
            )  # number of images that will be sent to classification function
            test_pic, mask_test_pic = explanation.get_image_and_mask(
                explanation.top_labels[0],
                positive_only=True,
                num_features=5,
                hide_rest=False)

            explanation_train = explainer.explain_instance(
                np.array(selected_train_pic.permute(1, 2, 0), dtype=np.double),
                batch_predict,  # classification function
                top_labels=5,
                hide_color=0,
                num_samples=1000
            )  # number of images that will be sent to classification function
            train_pic, mask_train_pic = explanation_train.get_image_and_mask(
                explanation_train.top_labels[0],
                positive_only=True,
                num_features=5,
                hide_rest=False)

            def vis_pic_trans(pic):
                pic = tc.tensor(pic).permute(2, 0, 1)
                invTrans = transforms.Compose([
                    transforms.Normalize(mean=[0., 0., 0.],
                                         std=[1 / 0.229, 1 / 0.224,
                                              1 / 0.225]),
                    transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                         std=[1., 1., 1.]),
                ])
                pic = invTrans(pic.unsqueeze(0)).squeeze()
                return pic.permute(1, 2, 0).numpy()

            test_pic = mark_boundaries(vis_pic_trans(test_pic), mask_test_pic)
            train_pic = mark_boundaries(vis_pic_trans(train_pic),
                                        mask_train_pic)

            plt.imshow(train_pic)
            plt.savefig(ag.mode + "/train-" + str(i) + ".png")
            plt.imshow(test_pic)
            plt.savefig(ag.mode + "/test-" + str(i) + ".png")

    except (KeyboardInterrupt, SystemExit):
        pass
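The two chained Normalize calls in vis_pic_trans above undo the standard ImageNet normalization (mean 0.485/0.456/0.406, std 0.229/0.224/0.225): the first divides by the std, the second adds the mean back. A small self-contained check of that composition (the input tensor is an arbitrary placeholder):

import torch as tc
from torchvision import transforms

mean = tc.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = tc.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
x = tc.rand(3, 8, 8)          # arbitrary image-shaped tensor
x_norm = (x - mean) / std     # forward ImageNet normalization
invTrans = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
    transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]),
])
assert tc.allclose(invTrans(x_norm), x, atol=1e-5)  # recovers the original tensor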