def logit_y1x_src(self, x: tc.Tensor, n_mc_q: int = 0, repar: bool = True):
    """Estimate the source-domain logits of p(y|x) by marginalizing s out.

    Args:
        x: input batch; trailing dims are assumed to match `self.shape_x`
            (leading batch dims are recovered via `ds.tcsize_div`) — TODO confirm.
        n_mc_q: number of Monte-Carlo samples from the q distribution.
            0 means use the (deterministic) expectation without sampling.
        repar: whether to use the reparameterization trick when sampling.

    Returns:
        For multi-class (`self.dim_y > 1`): a tensor of per-class logits.
        For binary (`self.dim_y == 1`): the scalar logit difference
        logit(y=1) - logit(y=0), with the class dim squeezed away.
    """
    # Binary labels are still evaluated over the two classes {0, 1}.
    dim_y = 2 if self.dim_y == 1 else self.dim_y
    # Build an "observation" dict that pairs every input x with every
    # candidate label y, so logp can be evaluated for all classes at once.
    y_eval = ds.expand_front(tc.arange(dim_y, device=x.device),
                             ds.tcsize_div(x.shape, self.shape_x))
    x_eval = ds.expand_middle(x, (dim_y, ), -len(self.shape_x))
    obs_xy = ds.edic({'x': x_eval, 'y': y_eval})
    if self.q_s1x is not None:
        # Direct variational posterior q(s|x) is available: take
        # E_q[log p(y|s)].  With n_mc_q > 0 use log-mean-exp over samples
        # (logsumexp - log n) for a consistent log-probability estimate.
        logits = (
            self.q_s1x.expect(lambda dc: self.p_y1s.logp(dc, dc), obs_xy, 0,
                              repar)  #, reducefn=tc.logsumexp)
        ) if n_mc_q == 0 else (
            self.q_s1x.expect(lambda dc: self.p_y1s.logp(dc, dc), obs_xy,
                              n_mc_q, repar, reducefn=tc.logsumexp) -
            math.log(n_mc_q))
    else:
        # No source-domain posterior: importance-reweight samples from the
        # target posterior qt(s|x) by p(s)/pt(s) — presumably the
        # source/target priors over s (TODO confirm against class defn).
        vwei_p_y1s_logp = lambda dc: self.p_s.logp(
            dc, dc) - self.pt_s.logp(dc, dc) + self.p_y1s.logp(dc, dc)
        logits = (
            self.qt_s1x.expect(vwei_p_y1s_logp, obs_xy, 0,
                               repar)  #, reducefn=tc.logsumexp)
        ) if n_mc_q == 0 else (self.qt_s1x.expect(
            vwei_p_y1s_logp, obs_xy, n_mc_q, repar, reducefn=tc.logsumexp) -
                               math.log(n_mc_q))
    # Binary case: collapse the two class logits into one decision logit.
    return (logits[..., 1] -
            logits[..., 0]).squeeze(-1) if self.dim_y == 1 else logits
def get_frame(discr, gen, dc_vars, device=None, discr_src=None):
    """Assemble the variational frame that matches `dc_vars['mode']`.

    Returns a `SemVar` for "svgm*" modes, a `SupVAE` for "svae*" modes, and
    None for anything else.  `discr_src` supplies the source-domain
    inference networks required by the "*-da2" modes.
    """
    if type(dc_vars) is not edic: dc_vars = edic(dc_vars)

    def pick(net, attr, key):
        # Prefer the attribute on the network; only consult the config dict
        # when the attribute is absent (keeps the dict lookup lazy, so a
        # missing key never raises while the attribute exists).
        return getattr(net, attr) if hasattr(net, attr) else dc_vars[key]

    shape_x = dc_vars['shape_x'] if 'shape_x' in dc_vars else (
        dc_vars['dim_x'], )
    shape_s = discr.shape_s if hasattr(discr, "shape_s") else (
        dc_vars['dim_s'], )
    shape_v = discr.shape_v if hasattr(discr, "shape_v") else (
        dc_vars['dim_v'], )
    std_v1x = pick(discr, "std_v1x", 'qstd_v')
    std_s1vx = pick(discr, "std_s1vx", 'qstd_s')
    std_s1x = pick(discr, "std_s1x", 'qstd_s')

    mode = dc_vars['mode']
    if mode.startswith("svgm"):
        q_args_stem = (discr.v1x, std_v1x, discr.s1vx, std_s1vx)
    elif mode.startswith("svae"):
        q_args_stem = (discr.s1x, std_s1x)
    else:
        return None

    n_stem = len(q_args_stem)
    if mode == "svgm-da2" and discr_src is not None:
        # Source-domain q nets come first, then the (target) stem.
        q_args = (discr_src.v1x, pick(discr_src, "std_v1x", 'qstd_v'),
                  discr_src.s1vx, pick(discr_src, "std_s1vx", 'qstd_s'),
                  ) + q_args_stem
    elif mode == "svae-da2" and discr_src is not None:
        q_args = (discr_src.s1x,
                  pick(discr_src, "std_s1x", 'qstd_s')) + q_args_stem
    elif mode in MODES_TWIST:  # svgm-ind, svae-da, svgm-da
        q_args = (None, ) * n_stem + q_args_stem
    else:  # svae, svgm
        q_args = q_args_stem + (None, ) * n_stem

    if mode.startswith("svgm"):
        return SemVar(
            shape_s, shape_v, shape_x, dc_vars['dim_y'], gen.x1sv,
            dc_vars['pstd_x'], discr.y1s, *q_args,
            *dc_vars.sublist(['mu_s', 'sig_s', 'mu_v', 'sig_v', 'corr_sv']),
            mode in MODES_DA,
            *dc_vars.sublist(['src_mvn_prior', 'tgt_mvn_prior']), device)
    # Only "svae*" modes can reach this point (others returned above).
    return SupVAE(shape_s, shape_x, dc_vars['dim_y'], gen.x1s,
                  dc_vars['pstd_x'], discr.y1s, *q_args,
                  *dc_vars.sublist(['mu_s', 'sig_s']), mode in MODES_DA,
                  *dc_vars.sublist(['src_mvn_prior', 'tgt_mvn_prior']),
                  device)
def get_models(archtype, dc_vars, ckpt=None, device=None):
    """Construct the discriminator (and, mode permitting, generator/frame).

    Args:
        archtype: architecture identifier forwarded to get_discr/get_gen.
        dc_vars: configuration mapping; coerced to `edic` if needed.
        ckpt: optional checkpoint dict; when given, `auto_load` restores each
            model's state (presumably by name from the checkpoint — it is
            handed `locals()` so it can find/replace the local variable).
        device: torch device the models are moved to.

    Returns:
        (discr, gen, frame, discr_src) for "*-da2" generative modes,
        (discr, gen, frame) for other generative modes,
        (discr, None, None) for non-generative modes.
    """
    if type(dc_vars) is not edic: dc_vars = edic(dc_vars)
    discr = get_discr(archtype, dc_vars)
    if ckpt is not None: auto_load(locals(), 'discr', ckpt)
    discr.to(device)
    if dc_vars['mode'] in MODES_GEN:
        # Generative modes additionally need the generator and the frame.
        gen = get_gen(archtype, dc_vars, discr)
        if ckpt is not None: auto_load(locals(), 'gen', ckpt)
        gen.to(device)
        if dc_vars['mode'].endswith("-da2"):
            # "-da2" keeps a second, source-domain discriminator.
            discr_src = get_discr(archtype, dc_vars)
            if ckpt is not None: auto_load(locals(), 'discr_src', ckpt)
            discr_src.to(device)
            frame = get_frame(discr, gen, dc_vars, device, discr_src)
            if ckpt is not None: auto_load(locals(), 'frame', ckpt)
            return discr, gen, frame, discr_src
        else:
            frame = get_frame(discr, gen, dc_vars, device)
            if ckpt is not None: auto_load(locals(), 'frame', ckpt)
            return discr, gen, frame
    else:
        return discr, None, None
def main_stem(
        ag, ckpt, archtype, shape_x, dim_y,
        tr_src_loader, val_src_loader,
        ls_ts_tgt_loader=None,  # for ood
        tr_tgt_loader=None, ts_tgt_loader=None, testdom=None  # for da
):
    """Train the selected model for `ag.n_epk` epochs and save checkpoints.

    Args:
        ag: argparse-style namespace of hyperparameters (mode, lr, optim, ...).
        ckpt: optional checkpoint dict to resume from (epochs, i_bat, states).
        archtype: architecture identifier forwarded to `get_models`.
        shape_x: input sample shape (excluding batch dim).
        dim_y: number of classes (1 means binary).
        tr_src_loader / val_src_loader: source-domain train/val loaders.
        ls_ts_tgt_loader: list of target test loaders (OOD modes only).
        tr_tgt_loader / ts_tgt_loader / testdom: target loaders and test
            domain name (DA modes only).

    Side effects: trains in place, prints progress, and (unless
    `ag.no_save`) writes one checkpoint per test domain under
    "ckpt_<mode>/".
    """
    print(ag)
    print_infrstru_info()
    IS_OOD = is_ood(ag.mode)
    device = tc.device("cuda:" +
                       str(ag.gpu) if tc.cuda.is_available() else "cpu")
    # Datasets
    dim_x = tc.tensor(shape_x).prod().item()
    if IS_OOD:
        n_per_epk = len(tr_src_loader)
    else:
        # DA pairs src/tgt batches with zip_longer, so an epoch spans the
        # longer of the two loaders.
        n_per_epk = max(len(tr_src_loader), len(tr_tgt_loader))
    # Models
    res = get_models(archtype, edic(locals()) | vars(ag), ckpt, device)
    if ag.mode.endswith("-da2"):
        discr, gen, frame, discr_src = res
        discr_src.train()
    else:
        discr, gen, frame = res
    discr.train()
    if gen is not None: gen.train()
    # Methods and Losses
    if IS_OOD:
        lossfn = ood_methods(
            discr, frame, ag, dim_y, cnbb_actv="Sigmoid"
        )  # Actually the activation is ReLU, but there is no `is_treat` rule for ReLU in CNBB.
        domdisc = None
    else:
        lossfn, domdisc, dalossobj = da_methods(
            discr, frame, ag, dim_x, dim_y, device, ckpt,
            discr_src if ag.mode.endswith("-da2") else None)
    # Optimizer: one collector gathers param groups from every trainable part.
    pgc = ParamGroupsCollector(ag.lr)
    pgc.collect_params(discr)
    if ag.mode.endswith("-da2"): pgc.collect_params(discr_src)
    if gen is not None: pgc.collect_params(gen, frame)
    if domdisc is not None: pgc.collect_params(domdisc)
    if ag.optim == "SGD":
        opt = getattr(tc.optim, ag.optim)(pgc.param_groups,
                                          weight_decay=ag.wl2,
                                          momentum=ag.momentum,
                                          nesterov=ag.nesterov)
        # LR decays per optimizer step via the ShrinkRatio schedule.
        shrink_opt = ShrinkRatio(w_iter=ag.lr_wdatum * ag.n_bat,
                                 decay_rate=ag.lr_expo)
        lrsched = tc.optim.lr_scheduler.LambdaLR(opt, shrink_opt)
        auto_load(locals(), 'lrsched', ckpt)
    else:
        opt = getattr(tc.optim, ag.optim)(pgc.param_groups,
                                          weight_decay=ag.wl2)
    auto_load(locals(), 'opt', ckpt)
    # Training: resume epoch/batch counters from the checkpoint if present.
    epk0 = 1
    i_bat0 = 1
    if ckpt is not None:
        epk0 = ckpt['epochs'][-1] + 1 if ckpt['epochs'] else 1
        i_bat0 = ckpt['i_bat']
    res = ResultsContainer(
        len(ag.testdoms) if IS_OOD else None, frame, ag, dim_y == 1, device,
        ckpt)
    print(f"Run in mode '{ag.mode}' for {ag.n_epk:3d} epochs:")
    try:
        for epk in range(epk0, ag.n_epk + 1):
            pbar = tqdm.tqdm(total=n_per_epk,
                             desc=f"Train epoch = {epk:3d}",
                             ncols=80,
                             leave=False)
            for i_bat, data_bat in enumerate(
                    tr_src_loader if IS_OOD else zip_longer(
                        tr_src_loader, tr_tgt_loader),
                    start=1):
                # Skip batches already consumed before the checkpoint.
                if i_bat < i_bat0: continue
                if IS_OOD:
                    x, y = data_bat
                    data_args = (x.to(device), y.to(device))
                else:
                    # DA: target labels yt are intentionally unused.
                    (x, y), (xt, yt) = data_bat
                    data_args = (x.to(device), y.to(device), xt.to(device))
                opt.zero_grad()
                if ag.mode in MODES_GEN:
                    # Generative losses take the global iteration count
                    # (e.g. for annealing schedules).
                    n_iter_tot = (epk - 1) * n_per_epk + i_bat - 1
                    loss = lossfn(*data_args, n_iter_tot)
                else:
                    loss = lossfn(*data_args)
                loss.backward()
                opt.step()
                if ag.optim == "SGD": lrsched.step()
                pbar.update(1)
            # end for
            pbar.close()
            i_bat = 1
            i_bat0 = 1  # resume-skipping applies to the first epoch only
            if epk % ag.eval_interval == 0:
                # NOTE(review): `loss` is the last batch's loss; undefined if
                # the epoch had no batches.
                res.update(epk=epk, loss=loss.item())
                print(f"Mode '{ag.mode}': Epoch {epk:.1f}, Loss = {loss:.3e},")
                discr.eval()
                if ag.mode.endswith("-da2"):
                    discr_src.eval()
                    true_discr = discr_src
                elif ag.mode in MODES_TWIST and ag.true_sup_val:
                    # Validate with the frame's exact source logits instead
                    # of the twisted discriminator.
                    true_discr = partial(frame.logit_y1x_src, n_mc_q=ag.n_mc_q)
                else:
                    true_discr = discr
                res.evaluate(true_discr, "val " + str(ag.traindom), 'val',
                             val_src_loader, 'src')
                if IS_OOD:
                    for i, (testdom, ts_tgt_loader) in enumerate(
                            zip(ag.testdoms, ls_ts_tgt_loader)):
                        res.evaluate(discr, "test " + str(testdom), 'ts',
                                     ts_tgt_loader, 'tgt', i)
                else:
                    res.evaluate(discr, "test " + str(testdom), 'ts',
                                 ts_tgt_loader, 'tgt')
                print()
                discr.train()
                if ag.mode.endswith("-da2"): discr_src.train()
        # end for
    except (KeyboardInterrupt, SystemExit):
        # Allow manual interruption; still summarize and save below.
        pass
    res.summary("val " + str(ag.traindom), 'val')
    if IS_OOD:
        for i, testdom in enumerate(ag.testdoms):
            res.summary("test " + str(testdom), 'ts', i)
    else:
        res.summary("test " + str(testdom), 'ts')
    if not ag.no_save:
        dirname = "ckpt_" + ag.mode + "/"
        os.makedirs(dirname, exist_ok=True)
        for i, testdom in enumerate(ag.testdoms if IS_OOD else [testdom]):
            filename = unique_filename(
                dirname + ("ood" if IS_OOD else "da"), ".pt",
                n_digits=3) if ckpt is None else ckpt['filename']
            # Bundle run config (locals + args) with all model/optim states.
            dc_vars = edic(locals()).sub([
                'dirname', 'filename', 'testdom', 'shape_x', 'dim_x', 'dim_y',
                'i_bat'
            ]) | (edic(vars(ag)) - {'testdoms'}) | dc_state_dict(
                locals(), "discr", "opt")
            if ag.mode.endswith("-da2"):
                dc_vars.update(dc_state_dict(locals(), "discr_src"))
            if ag.optim == "SGD":
                dc_vars.update(dc_state_dict(locals(), "lrsched"))
            if ag.mode in MODES_GEN:
                dc_vars.update(dc_state_dict(locals(), "gen", "frame"))
            elif ag.mode in {"dann", "cdan", "dan", "mdd"}:
                dc_vars.update(dc_state_dict(locals(), "domdisc", "dalossobj"))
            if IS_OOD:
                # Per-domain lists ('ls_*') are unpacked to this domain's
                # entry; scalar results are stored as-is.
                dc_vars.update(
                    edic({
                        k: v
                        for k, v in res.dc.items() if not k.startswith('ls_')
                    }) | {
                        k[3:]: v[i]
                        for k, v in res.dc.items() if k.startswith('ls_')
                    })
            else:
                dc_vars.update(res.dc)
            tc.save(dc_vars, filename)
            # Fix: report the actual path saved (the original f-string had
            # no placeholder).
            print(f"checkpoint saved to '{filename}'.")
def get_visual(
        ag, ckpt, archtype, shape_x, dim_y,
        tr_src_loader, val_src_loader,
        ls_ts_tgt_loader=None,  # for ood
        tr_tgt_loader=None, ts_tgt_loader=None, testdom=None  # for da
):
    """Evaluate a (restored) model and save LIME explanation images.

    Mirrors `main_stem`'s setup (model construction, loss/method selection,
    evaluation), then, for up to 400 rounds, picks one correctly-classified
    test image and one correctly-classified train image, runs LIME on each,
    and writes "<mode>/train-<i>.png" and "<mode>/test-<i>.png".

    Args: same as `main_stem`; `ckpt` presumably holds a trained model to
    visualize — TODO confirm intended usage.
    """
    print(ag)
    IS_OOD = is_ood(ag.mode)
    device = tc.device("cuda:" +
                       str(ag.gpu) if tc.cuda.is_available() else "cpu")
    # Datasets
    dim_x = tc.tensor(shape_x).prod().item()
    if IS_OOD:
        n_per_epk = len(tr_src_loader)
    else:
        n_per_epk = max(len(tr_src_loader), len(tr_tgt_loader))
    # Models
    res = get_models(archtype, edic(locals()) | vars(ag), ckpt, device)
    if ag.mode.endswith("-da2"):
        discr, gen, frame, discr_src = res
        discr_src.train()
    else:
        discr, gen, frame = res
    # get pictures: inference only, so switch to eval mode.
    discr.eval()
    if gen is not None: gen.eval()
    # Methods and Losses
    if IS_OOD:
        lossfn = ood_methods(
            discr, frame, ag, dim_y, cnbb_actv="Sigmoid"
        )  # Actually the activation is ReLU, but there is no `is_treat` rule for ReLU in CNBB.
        domdisc = None
    else:
        lossfn, domdisc, dalossobj = da_methods(
            discr, frame, ag, dim_x, dim_y, device, ckpt,
            discr_src if ag.mode.endswith("-da2") else None)
    epk0 = 1
    i_bat0 = 1
    if ckpt is not None:
        epk0 = ckpt['epochs'][-1] + 1 if ckpt['epochs'] else 1
        i_bat0 = ckpt['i_bat']
    res = ResultsContainer(
        len(ag.testdoms) if IS_OOD else None, frame, ag, dim_y == 1, device,
        ckpt)
    print(f"Run in mode '{ag.mode}' for {ag.n_epk:3d} epochs:")
    try:
        # Pick the discriminator used for validation (same rule as training).
        if ag.mode.endswith("-da2"):
            discr_src.eval()
            true_discr = discr_src
        elif ag.mode in MODES_TWIST and ag.true_sup_val:
            true_discr = partial(frame.logit_y1x_src, n_mc_q=ag.n_mc_q)
        else:
            true_discr = discr
        res.evaluate(true_discr, "val " + str(ag.traindom), 'val',
                     val_src_loader, 'src')
        if IS_OOD:
            for i, (testdom, ts_tgt_loader) in enumerate(
                    zip(ag.testdoms, ls_ts_tgt_loader)):
                res.evaluate(discr, "test " + str(testdom), 'ts',
                             ts_tgt_loader, 'tgt', i)
        else:
            res.evaluate(discr, "test " + str(testdom), 'ts', ts_tgt_loader,
                         'tgt')
        print()

        def batch_predict(images):
            # Classification hook for LIME: accepts a list of images and
            # returns class probabilities as a numpy array.
            import torch.nn.functional as F
            # Channels-last (H, W, 3) inputs are converted to CHW tensors.
            if tc.tensor(images[0]).size()[-1] == 3:
                images = [
                    tc.tensor(pic, dtype=tc.float).permute(2, 0, 1)
                    for pic in images
                ]
            batch = tc.stack(tuple(i for i in images), dim=0)
            batch = batch.to(device)
            logits = discr(batch)
            probs = F.softmax(logits, dim=1)
            return probs.detach().cpu().numpy()

        # Only the first OOD test domain is visualized.
        if IS_OOD:
            test_loader = ls_ts_tgt_loader[0]
        else:
            test_loader = ts_tgt_loader
        iter_tr, iter_ts = iter(tr_src_loader), iter(test_loader)
        train_batch, train_label = next(iter_tr)
        test_batch, test_label = next(iter_ts)
        os.makedirs(ag.mode, exist_ok=True)
        # search for the first accurate predict:
        # cursors walk through batches; when a batch is exhausted, the next
        # one is fetched.  NOTE(review): `next()` can raise StopIteration if
        # a loader is exhausted before 400 correct samples are found.
        cursor_train, cursor_test = 0, 0
        for i in range(400):
            cursor_test += 1
            if cursor_test >= test_batch.size()[0]:
                cursor_test = 0
                test_batch, test_label = next(iter_ts)
            # Advance until the model classifies the test sample correctly.
            while True:
                x_test = test_batch[cursor_test]
                test_pred = batch_predict([x_test])
                if cursor_test < test_batch.size()[0] and test_label[
                        cursor_test] == test_pred.squeeze().argmax():
                    break
                else:
                    cursor_test = cursor_test + 1
                    if cursor_test >= test_batch.size()[0]:
                        cursor_test = 0
                        test_batch, test_label = next(iter_ts)
            selected_pic, selected_label = test_batch[cursor_test], test_label[
                cursor_test]
            cursor_train += 1
            if cursor_train >= train_batch.size()[0]:
                cursor_train = 0
                train_batch, train_label = next(iter_tr)
            # Same search over the training loader.
            while True:
                x_train = train_batch[cursor_train]
                test_pred = batch_predict([x_train])
                if cursor_train < train_batch.size()[0] and train_label[
                        cursor_train] == test_pred.squeeze().argmax():
                    break
                else:
                    cursor_train = cursor_train + 1
                    if cursor_train >= train_batch.size()[0]:
                        cursor_train = 0
                        train_batch, train_label = next(iter_tr)
            selected_train_pic = train_batch[cursor_train]
            from lime import lime_image
            import numpy as np
            explainer = lime_image.LimeImageExplainer()
            # LIME expects HWC double arrays, hence the permute + cast.
            explanation = explainer.explain_instance(
                np.array(selected_pic.permute(1, 2, 0), dtype=np.double),
                batch_predict,  # classification function
                top_labels=5,
                hide_color=0,
                num_samples=1000
            )  # number of images that will be sent to classification function
            from skimage.segmentation import mark_boundaries
            test_pic, mask_test_pic = explanation.get_image_and_mask(
                explanation.top_labels[0],
                positive_only=True,
                num_features=5,
                hide_rest=False)
            explanation_train = explainer.explain_instance(
                np.array(selected_train_pic.permute(1, 2, 0),
                         dtype=np.double),
                batch_predict,  # classification function
                top_labels=5,
                hide_color=0,
                num_samples=1000
            )  # number of images that will be sent to classification function
            train_pic, mask_train_pic = explanation_train.get_image_and_mask(
                explanation_train.top_labels[0],
                positive_only=True,
                num_features=5,
                hide_rest=False)

            def vis_pic_trans(pic):
                # Undo the ImageNet-style normalization (mean/std values
                # match the standard 0.485/0.456/0.406 and 0.229/0.224/0.225
                # constants) so the image displays with natural colors.
                pic = tc.tensor(pic).permute(2, 0, 1)
                invTrans = transforms.Compose([
                    transforms.Normalize(mean=[0., 0., 0.],
                                         std=[1 / 0.229, 1 / 0.224,
                                              1 / 0.225]),
                    transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                         std=[1., 1., 1.]),
                ])
                pic = invTrans(pic.unsqueeze(0)).squeeze()
                return pic.permute(1, 2, 0).numpy()

            # Overlay the LIME superpixel boundaries and save both images.
            test_pic = mark_boundaries(vis_pic_trans(test_pic), mask_test_pic)
            train_pic = mark_boundaries(vis_pic_trans(train_pic),
                                        mask_train_pic)
            import matplotlib.pyplot as plt
            plt.imshow(train_pic)
            plt.savefig(ag.mode + "/train-" + str(i) + ".png")
            plt.imshow(test_pic)
            plt.savefig(ag.mode + "/test-" + str(i) + ".png")
    except (KeyboardInterrupt, SystemExit):
        pass