def validation(model, val_loader):
    '''
        evaluation function for validation data
    '''

    print("Start validation.\n")
    val_loss_hist = Averager()

    with torch.no_grad():
        for images, targets, _ in val_loader:
            images = torch.stack(images).to(CFG.device).float()
            bboxes = [
                target['boxes'].to(CFG.device).float() for target in targets
            ]
            labels = [
                target['labels'].to(CFG.device).float() for target in targets
            ]
            batch_size = images.shape[0]

            target_res = dict()
            target_res['bbox'] = bboxes
            target_res['cls'] = labels
            # target_res["img_scale"] = torch.tensor([1.0] * batch_size, dtype=torch.float).to(CFG.device)
            # target_res["img_size"] = torch.tensor([images[0].shape[-2:]] * batch_size, dtype=torch.float).to(CFG.device)

            # forward pass & calculate loss
            output = model(images, target_res)
            loss_value = output['loss'].detach().item()
            val_loss_hist.update(loss_value, batch_size)

            del images, targets, bboxes

    return val_loss_hist.value
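These snippets all rely on an Averager helper for tracking running means, which is not shown. A minimal sketch consistent with the usage in this first example (update(value, n), .value, .reset()); other snippets below call slightly different variants (.avg, .val(), .get()), so treat this as an illustrative assumption rather than the project's actual utility:

class Averager:
    """Minimal running-average tracker (illustrative sketch, not the original utility)."""

    def __init__(self):
        self.current_total = 0.0
        self.iterations = 0.0

    def update(self, value, n=1):
        # accumulate a value observed over n samples
        self.current_total += value * n
        self.iterations += n

    @property
    def value(self):
        # running mean so far; 0 if nothing has been accumulated yet
        return self.current_total / self.iterations if self.iterations else 0.0

    def reset(self):
        self.current_total = 0.0
        self.iterations = 0.0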
Example #2
def evaluate_llhx(frame, input_loader, n_marg_llh, use_q_llh, mode, device):
    avgr = Averager()
    for x, y in input_loader:
        x = x.to(device)
        avgr.update(frame.llh(x, None, n_marg_llh, use_q_llh, mode),
                    nrep=len(x))
    return avgr.avg
Example #3
    def train(self, epoch, dataloader):
        use_cuda = torch.cuda.is_available()
        self.net.train()

        for m in self.net.modules():  ### "freeze" keeps BatchNorm layers in eval mode
            if isinstance(m, _BatchNorm):
                if self.args["train"]["freeze"]:
                    m.eval()

        loss_avg = Averager()
        lls_avg = Averager()

        for batch_idx, sample in enumerate(dataloader):

            data, target, name = sample[0], sample[1], sample[2]
            if use_cuda:
                data = data.cuda()
                target = target.cuda()

            if len(sample) > 3:
                args = sample[3:]
            else:
                args = None

            total_loss, loss_list = self.warp(
                data, target, False, args)  #### warp usually runs in parallel and returns the per-sample losses of a batch
            loss_avg.update(total_loss.mean().detach().cpu().numpy())
            loss_list = tuple([l.cpu().numpy() for l in loss_list])
            info = "Finish training %d out of %d, " % (batch_idx + 1,
                                                       len(dataloader))
            for lid, l in enumerate(loss_list):
                info += "loss %d: %.4f, " % (lid, float(np.mean(l)))
            print(info)
            lls_avg.update(loss_list)
            self.optimizer.zero_grad()
            loss_scalar = torch.mean(total_loss)
            if self.half:  # Automatic Mixed Precision
                with amp.scale_loss(loss_scalar,
                                    self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_scalar.backward()
            torch.nn.utils.clip_grad_value_(self.net.parameters(), 2)
            self.optimizer.step()
        self.scheduler.step(epoch=epoch)
        self.__writeLossLog(
            "Train",
            epoch,
            meanloss=loss_avg.val(),
            loss_list=lls_avg.val(),
            lr=self.scheduler.get_lr()[0],
        )
Example #4
class ParticleVI(object):
    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=10,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):

        self.algo = algo
        self.dataset = dataset
        self.kernel_fn = kernel_fn
        self.num_particles = num_particles
        print("running {} on {}".format(algo, dataset))

        if self.dataset == 'regression':
            self.data = toy.generate_regression_data(80, 200)
            (self.train_data,
             self.train_targets), (self.test_data,
                                   self.test_targets) = self.data
        elif self.dataset == 'classification':
            self.train_data, self.train_targets = toy.generate_classification_data(
                100)
            self.test_data, self.test_targets = toy.generate_classification_data(
                200)
        else:
            raise NotImplementedError

        if kernel_fn == 'rbf':
            self.kernel = rbf_fn
        else:
            raise NotImplementedError

        models = [base_model_fn().cuda() for _ in range(num_particles)]

        self.models = models
        param_set, state_dict = extract_parameters(self.models)

        self.state_dict = state_dict
        self.param_set = torch.nn.Parameter(param_set.clone(),
                                            requires_grad=True)

        self.optimizer = torch.optim.Adam([{
            'params': self.param_set,
            'lr': 1e-3
        }])

        if self.dataset == 'regression':
            self.loss_fn = torch.nn.MSELoss()
        elif self.dataset == 'classification':
            self.loss_fn = torch.nn.CrossEntropyLoss()
        self.kernel_width_averager = Averager(shape=())

    def kernel_width(self, dist):
        """Update kernel_width averager and get latest kernel_width. """
        if dist.ndim > 1:
            dist = torch.sum(dist, dim=-1)
            assert dist.ndim == 1, "dist must have dimension 1 or 2."
        width, _ = torch.median(dist, dim=0)
        width = width / np.log(len(dist))
        self.kernel_width_averager.update(width)
        return self.kernel_width_averager.get()

    def rbf_fn(self, x, y):
        Nx = x.shape[0]
        Ny = y.shape[0]
        x = x.view(Nx, -1)
        y = y.view(Ny, -1)
        Dx = x.shape[1]
        Dy = y.shape[1]
        assert Dx == Dy
        diff = x.unsqueeze(1) - y.unsqueeze(0)  # [Nx, Ny, D]
        dist_sq = torch.sum(diff**2, -1)  # [Nx, Ny]
        h = self.kernel_width(dist_sq.view(-1))
        kappa = torch.exp(-dist_sq / h)  # [Nx, Nx]
        kappa_grad = torch.einsum('ij,ijk->ijk', kappa,
                                  -2 * diff / h)  # [Nx, Ny, D]
        return kappa, kappa_grad

    def svgd_grad(self, loss_grad, params):
        """
        Compute particle gradients via SVGD, empirical expectation
        evaluated by splitting half of the sampled batch. 
        """
        num_particles = params.shape[0]
        params2 = params.detach().requires_grad_(True)
        kernel_weight, kernel_grad = self.rbf_fn(params2, params)
        if kernel_grad is None:
            kernel_grad = torch.autograd.grad(kernel_weight.sum(), params2)[0]

        kernel_logp = torch.matmul(kernel_weight.t().detach(),
                                   loss_grad) / num_particles
        grad = kernel_logp - kernel_grad.mean(0)
        return grad

    def test(self, eval_loss=True):
        for model in self.models:
            model.eval()
        correct = 0
        test_loss = 0
        preds = []
        loss = 0
        test_data = self.test_data.cuda()
        test_targets = self.test_targets.cuda()
        for model in self.models:
            outputs = model(test_data)
            if eval_loss:
                loss += self.loss_fn(outputs, test_targets)
            else:
                loss += 0
            preds.append(outputs)

        preds = torch.stack(preds)
        p_mean = preds.mean(0)
        if self.dataset == 'classification':
            preds = torch.nn.functional.softmax(preds, dim=-1)
            preds = preds.mean(0)
            vote = preds.argmax(-1).cpu()
            correct = vote.eq(
                test_targets.cpu().data.view_as(vote)).float().cpu().sum()
            correct /= len(test_targets)
        else:
            correct = 0
            test_loss += (loss / self.num_particles)
        outputs_all = preds
        test_loss /= len(self.models)
        for model in self.models:
            model.train()
        return outputs_all, (test_loss, correct)

    def train(self, epochs):
        for epoch in range(0, epochs):
            loss_epoch = 0
            neglogp = torch.zeros(self.num_particles)
            insert_items(self.models, self.param_set, self.state_dict)
            neglogp_grads = torch.zeros_like(self.param_set)
            outputs = []
            for i, model in enumerate(self.models):
                train_data = self.train_data.cuda()
                train_targets = self.train_targets.cuda()
                output = model(train_data)
                outputs.append(output)
                loss = self.loss_fn(outputs[-1], train_targets)
                loss.backward()
                neglogp[i] = loss
                g = []
                for name, param in model.named_parameters():
                    g.append(param.grad.view(-1))
                neglogp_grads[i] = torch.cat(g)
                model.zero_grad()

            par_vi_grad = self.svgd_grad(neglogp_grads, self.param_set)
            self.optimizer.zero_grad()
            self.param_set.grad = par_vi_grad
            self.optimizer.step()

            loss_step = neglogp.mean()
            loss_epoch += loss_step

            loss_epoch /= self.num_particles
            print('Train Epoch {} [cum loss: {}]\n'.format(epoch, loss_epoch))

            if epoch % 100 == 0:
                insert_items(self.models, self.param_set, self.state_dict)
                with torch.no_grad():
                    outputs, stats = self.test(eval_loss=False)
                test_loss, correct = stats
                print('Test Loss: {}'.format(test_loss))
                print('Test Acc: {}%'.format(correct * 100))
                if self.dataset == 'regression':
                    toy.plot_regression(self.models,
                                        self.data,
                                        epoch,
                                        tag='svgd')
                if self.dataset == 'classification':
                    toy.plot_classification(self.models, epoch, tag='svgd')

            print('*' * 86)
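The ParticleVI example above calls extract_parameters and insert_items, which are not included in the snippet. A minimal sketch of what they could look like, assuming each particle's parameters are flattened into one row of a [num_particles, dim] tensor (the real helpers may carry additional state):

import torch


def extract_parameters(models):
    # Flatten every model's parameters into one row of a [num_particles, dim] tensor.
    # state_dict here only records names, shapes and sizes so rows can be written back.
    rows = [
        torch.cat([p.detach().reshape(-1) for p in model.parameters()])
        for model in models
    ]
    state_dict = [(name, p.shape, p.numel())
                  for name, p in models[0].named_parameters()]
    return torch.stack(rows), state_dict


def insert_items(models, param_set, state_dict):
    # Copy each particle's row back into its model, slice by slice.
    with torch.no_grad():
        for model, row in zip(models, param_set):
            offset = 0
            named = dict(model.named_parameters())
            for name, shape, numel in state_dict:
                named[name].copy_(row[offset:offset + numel].view(shape))
                offset += numel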
Example #5
def train(model, optimizer, scheduler, train_loader, val_loader):
    '''
        train model
    '''

    print("Start training.\n")
    early_stopping = EarlyStopping(patience=CFG.patience,
                                   verbose=True,
                                   trace_func=print)
    train_loss_hist = Averager()

    for epoch in tqdm(range(1, CFG.nepochs + 1)):
        now = time.localtime()
        print("%04d/%02d/%02d %02d:%02d:%02d" %
              (now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min,
               now.tm_sec))

        model.train()
        train_loss_hist.reset()

        for step, (images, targets, _) in enumerate(train_loader):
            images = torch.stack(images).to(CFG.device).float()
            bboxes = [
                target['boxes'].to(CFG.device).float() for target in targets
            ]
            labels = [
                target['labels'].to(CFG.device).float() for target in targets
            ]
            batch_size = images.shape[0]

            target_res = dict()
            target_res['bbox'] = bboxes
            target_res['cls'] = labels

            # forward pass & calculate loss
            loss = model(images, target_res)['loss']
            loss_value = loss.detach().item()
            train_loss_hist.update(loss_value, batch_size)

            # backward
            optimizer.zero_grad()  # reset previous gradient
            loss.backward()  # backward propagation
            optimizer.step()  # parameters update

            # log train loss by step
            if (step + 1) % 25 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.6f}'.format(
                    epoch, CFG.nepochs, step + 1, len(train_loader),
                    loss_value))

            del images, targets, bboxes

        # scheduler step
        scheduler.step()

        # Print score after each epoch
        # Print score and run early stopping after each validation epoch
        if ((epoch % CFG.print_freq) == 0) or (epoch == CFG.nepochs):
            val_loss = validation(model, val_loader)
            print("epoch:[%d] train_loss:[%.6f] val_loss:[%.6f]" %
                  (epoch, train_loss_hist.value, val_loss))

            # wandb.log({
            #     "Train Loss": train_loss,
            #     "Val Loss": val_loss,
            #     "Val mIoU": mIoU,
            #     "Val pix_acc": acc,
            #     "Seg": fig_mask,
            # })

            # val_loss only exists on validation epochs, so early stopping is
            # evaluated inside this branch
            if early_stopping(model=model, val_loss=val_loss):
                best_metric = {'epoch': epoch, 'val_loss': val_loss}
                torch.save(
                    model.state_dict(),
                    os.path.join(
                        CFG.models_path, CFG.model_save_name,
                        f"{CFG.model_save_name}_{str(epoch).zfill(2)}.pt"))

            if early_stopping.early_stop or epoch == CFG.nepochs:
                print("best model information")
                print(f"epoch : {best_metric['epoch']}")
                print(f"val_loss : {best_metric['val_loss']}")
                break

    print("Done")
Example #6
def main():
    torch.backends.cudnn.benchmark = True

    # hyper-params initializing
    args = dictobj()
    args.gpu = torch.device('cuda:%d' % (6))
    timestamp = '%d-%d-%d-%d-%d-%d-%d-%d-%d' % time.localtime(time.time())
    args.log_name = '%s-pointflow' % timestamp
    writer = SummaryWriter(comment=args.log_name)

    args.use_latent_flow, args.prior_w, args.entropy_w, args.recon_w = True, 1., 1., 1.
    args.fin, args.fz = 3, 128
    args.use_deterministic_encoder = True
    args.distributed = False
    args.optimizer = optim.Adam
    args.batch_size = 16
    args.lr, args.beta1, args.beta2, args.weight_decay = 1e-3, 0.9, 0.999, 1e-4
    args.T, args.train_T, args.atol, args.rtol = 1., False, 1e-5, 1e-5
    args.layer_type = diffop.CoScaleLinear
    args.solver = 'dopri5'
    args.use_adjoint, args.bn = True, False
    args.dims, args.num_blocks = (512, 512), 1  # originally (512 * 3)
    args.latent_dims, args.latent_num_blocks = (256, 256), 1

    args.resume, args.resume_path = False, None
    args.end_epoch = 2000
    args.scheduler, args.scheduler_step_size = optim.lr_scheduler.StepLR, 20
    args.random_rotation = True
    args.save_freq = 10

    args.dataset_type = 'shapenet15k'
    args.cates = ['airplane']  # 'all' for all categories training
    args.tr_max_sample_points, args.te_max_sample_points = 2048, 2048
    args.dataset_scale = 1.0
    args.normalize_per_shape = False
    args.normalize_std_per_axis = False
    args.num_workers = 4
    args.data_dir = "/data/ShapeNetCore.v2.PC15k"

    torch.cuda.set_device(args.gpu)
    model = PointFlow(**args).cuda(args.gpu)

    # load milestone
    epoch = 0
    optimizer = model.get_optimizer(**args)
    if args.resume:
        model, optimizer, epoch = resume(args.resume_path,
                                         model,
                                         optimizer,
                                         strict=True)
        print("Loaded model from %s" % args.resume_path)

    # load data
    train_dataset, test_dataset = get_datasets(args)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               sampler=None,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              sampler=None,
                                              drop_last=False)

    if args.scheduler == optim.lr_scheduler.StepLR:
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=args.scheduler_step_size, gamma=0.65)
    else:
        raise NotImplementedError("Only StepLR supported")

    ent_rec, latent_rec, recon_rec = Averager(), Averager(), Averager()
    for e in trange(epoch, args.end_epoch):
        # record lr
        if writer is not None:
            writer.add_scalar('lr/optimizer', scheduler.get_lr()[0], e)

        # feed a batch, train
        for idx, data in enumerate(tqdm(train_loader)):
            idx_batch, tr_batch, te_batch = data['idx'], data[
                'train_points'], data['test_points']
            model.train()
            if args.random_rotation:
                # raise NotImplementedError('Random Rotation Augmentation not implemented yet')
                tr_batch, _, _ = apply_random_rotation(
                    tr_batch, rot_axis=train_loader.dataset.gravity_axis)
            inputs = tr_batch.cuda(args.gpu, non_blocking=True)
            step = idx + len(train_loader) * e  # batch step
            out = model(inputs, optimizer, step, writer, sample_gpu=args.gpu)
            entropy, prior_nats, recon_nats = out['entropy'], out[
                'prior_nats'], out['recon_nats']
            ent_rec.update(entropy)
            recon_rec.update(recon_nats)
            latent_rec.update(prior_nats)

        # update lr
        scheduler.step(epoch=e)

        # save milestones
        if e % args.save_freq == 0 and e != 0:
            save(model, optimizer, e, path='milestone-%d.save' % e)
            save(model, optimizer, e,
                 path='milestone-latest.save')  # save as the latest model
Example #7
def evaluate_acc(discr, input_loader, is_binary, device):
    avgr = Averager()
    for x, y in input_loader:
        x, y = x.to(device), y.to(device)
        avgr.update(acc_with_logits(discr, x, y, is_binary), nrep=len(y))
    return avgr.avg
Example #8
    def validate(self, epoch, val_loader, save=False):
        """
        Evaluate each CT scan; one iteration covers all patches from that CT.
        :param val_loader -> Dataloader(): all CT scans to evaluate
        """
        startt = time.time()
        self.net.eval()
        # print(vars(self.ioer))
        if not self.args.val:
            if epoch % self.args.output['save_frequency'] == 0:
                self.ioer.save_file(self.net, epoch, self.args, 0)
            if epoch % self.args.output['val_frequency'] != 0:
                return
        
        loss_avg = Averager()
        lls_avg = Averager()
        em_avg = Averager()
        bs =  self.args.train['batch_size']
        if save:
            savedir = os.path.join(self.ioer.save_dir, '%03d' % epoch)
            if not os.path.exists(savedir):
                #shutil.rmtree(savedir)
                os.mkdir(savedir)

        ap_pred_list = []
        ap_gt_list = []
        with torch.no_grad():

            xbatch_filler = Batch_Filler(bs, size = [self.args.prepare['channel_input']] + self.args.prepare['crop_size'], dtype=self.dtype)
            bbox_left = None
            infos_list = []

            for sample_idx, tmp in enumerate(val_loader):
                # zhw: shape of data
                data, zhw, name, fullab = tmp[0]

                infos_list.append([zhw, name, fullab])
                x_left = torch.from_numpy(data).cuda()

                # keep pulling batches of size bs from the current CT
                while x_left is not None:
                    isFull, xbatch, belong, idxlist, x_left = xbatch_filler.fill(sample_idx, x_left)

                    if len(val_loader) == sample_idx + 1:  # force the test pass to run for the last sample
                        isFull = True
                        if len(idxlist) != bs:
                            fill_length = bs - len(idxlist)
                            for i in range(fill_length):
                                idxlist.append(-1)

                    if isFull:
                        data = xbatch
                        logits = self.warp(data, calc_loss=False)
                        logits = self.clipmargin(list(logits))

                        box_batch = MaskableList()
                        thresh_lists = []
                        for i_batch in range(data.shape[0]):
                            # coordinates & confidence
                            box_iter, thresh_list = decode_bbox(logits, thresh=-2, idx=i_batch, config=self.args)
                            box_batch.append(box_iter)
                            thresh_lists.append(thresh_list)
                    

                        for i_idx, idx in enumerate(idxlist):
                            if idx == -1:
                                break
                            zhw, name, fullab = infos_list[i_idx]
                            fullab = torch.from_numpy(fullab).cuda()
                            zhw = torch.from_numpy(zhw)
                            bbox_pieces = box_batch[belong==idx]

                            if bbox_left is not None:
                                bbox_pieces = bbox_left + bbox_pieces
                                bbox_left = None

                            # map patch-level boxes back to original-image coordinates
                            # comb_pred: all predicted boxes [n, 8] [z1,y1,x1,z2,y2,x2,confidence,cls]
                            comb_pred = val_loader.dataset.split_comb.combine(bbox_pieces, zhw)
                            # print(comb_pred.shape)
                            # print(fullab.shape)
                            # print(comb_pred)
                            # print(comb_pred.shape)
                            # print(fullab)

                            # accumulate the evaluation counts
                            em_list = []
                            if self.emlist is not None:
                                for em_fun in self.emlist:
                                    # compute the hit status of every predicted box
                                    # fullab: all gt boxes [n, 10] [z,y,x,dz,dy,dx,cls,1,1,1]
                                    # iou == 0.2
                                    em_result, iou_info = em_fun(comb_pred, fullab)
                                    # print(em_result)
                                    # exit()

                                    if len(comb_pred) > 0:
                                        comb_pred_tmp = comb_pred.cpu().numpy()
                                        # predicted confidence
                                        pred_probs = comb_pred_tmp[:, 6:7]
                                        # use the z-axis extent as the box size
                                        bbox_size = comb_pred_tmp[:, 3:4] - comb_pred_tmp[:, 0:1]
                                        # [prob, diameter, hit-iou, [coords]]
                                        ap_pred_list.append(np.concatenate([pred_probs, bbox_size, iou_info[:, :1], comb_pred_tmp[:, :6]], axis=1))
                                        # print(ap_pred_list[0][1,2])
                                        # exit()
                                    else:
                                        ap_pred_list.append([])
                                    if fullab.shape[0] > 0:
                                        fullab_tmp = fullab.cpu().numpy()
                                        lab_size = fullab_tmp[:, 3:4]
                                        lab_center = fullab_tmp[:, :3]
                                        ap_gt_list.append(np.concatenate([lab_size, lab_center], axis=1))
                                    else:
                                        ap_gt_list.append([])
                                    em_list.append(em_result)
                                em_avg.update(tuple(em_list))

                            info = 'end %d out of %d, name %s, '%(idx, len(val_loader), name)
                            for lid, l in enumerate(em_list):
                                if isinstance(l,dict):
                                    for k,v in l.items():
                                        info += '%s: %.2f, '%(k, v)
                                else:
                                    info += '%d: %.2f, '%(lid, l)
                            threshs = np.array(thresh_lists).mean(axis=0)
                            info += 'thresh: '
                            for level, thresh in enumerate(threshs):
                                info += 'level %d= %.02f, '%(level, thresh)
                            print(info)
                            if save:
                                if isinstance(comb_pred, torch.Tensor):
                                    comb_pred = comb_pred.cpu().numpy()
                                try:
                                    np.save(os.path.join(savedir, name+'.npy'), np.concatenate([comb_pred, iou_info], axis=1))
                                except:
                                    print(name)
                        bbox_left_new = box_batch[(belong)==belong[-1]]
                        bbox_left, infos_list = restart_logit(x_left, bbox_left, bbox_left_new, infos_list)
                        xbatch_filler.restart()

        cPickle.dump(ap_pred_list, open(os.path.join(self.ioer.save_dir, 'ap_pred_list.pkl'), 'wb'))
        cPickle.dump(ap_gt_list, open(os.path.join(self.ioer.save_dir, 'ap_gt_list.pkl'), 'wb'))
        ap_small_bbox_list = []
        ap_big_bbox_list = []
        ap_small_gt_bbox_count = 0
        ap_big_gt_bbox_count = 0
        
        # both lists have the same length: the number of CTs
        assert len(ap_pred_list) == len(ap_gt_list)
        # iterate over all CTs
        for idx, ap_pred in enumerate(ap_pred_list):
            ap_gt = ap_gt_list[idx]
            
            # iterate over all predicted boxes
            for i, ap_pred_x in enumerate(ap_pred):
                # iou_info >= 0
                if ap_pred_x[2] >= 0:
                    bbox_size = ap_gt[int(ap_pred_x[2])][0]
                else:
                    bbox_size = ap_pred_x[1]

                # small objects
                if bbox_size < self.small_size:
                    # [prob, id(ct)_id(lgt), size]
                    ap_small_bbox_list.append([ap_pred_x[0], str(idx) + '_' + str(int(ap_pred_x[2])), bbox_size])
                else:
                    ap_big_bbox_list.append([ap_pred_x[0], str(idx) + '_' + str(int(ap_pred_x[2])), bbox_size])
            
            for i, ap_gt_x in enumerate(ap_gt):
                bbox_size = ap_gt_x[0]
                if bbox_size < self.small_size:
                    ap_small_gt_bbox_count += 1
                else:
                    ap_big_gt_bbox_count += 1

        # import pdb; pdb.set_trace()
        # cal froc
        froc_val = self.froc(bbox_info=ap_big_bbox_list, 
                             gt_count=ap_big_gt_bbox_count, 
                             fps=[0.5, 1, 2, 4, 8], 
                             n_ct=len(val_loader))
        # print('FROC: {}'.format(froc_val))
        self.printf('FROC: ' + str(froc_val))

        rp_list = []
        # do with small & big box
        # for ap_bbox_list, ap_gt_bbox_count in zip([ap_small_bbox_list, ap_big_bbox_list], \
        #                                           [ap_small_gt_bbox_count, ap_big_gt_bbox_count]):
        ap_bbox_list = ap_big_bbox_list
        ap_gt_bbox_count = ap_big_gt_bbox_count
        recall_level = 1
        rp = {}

        # sort by prob in descending order
        # ap_bbox_list: [prob, id_cls, size]
        ap_bbox_list.sort(key=lambda x: -x[0])
        gt_bbox_hits = []
        pred_bbox_hit_count = 0
        for idx, ap_bbox in enumerate(ap_bbox_list):
            bbox_tag = ap_bbox[1]
            if not bbox_tag.endswith('-1'):
                pred_bbox_hit_count += 1
                # if boxes tagged -1 were counted, one extra hit would be recorded
                if bbox_tag not in gt_bbox_hits:
                    gt_bbox_hits.append(bbox_tag)
            while len(gt_bbox_hits) / ap_gt_bbox_count >= recall_level*0.1 and recall_level <= 10:
                rp[recall_level] = [pred_bbox_hit_count / (idx + 1), ap_bbox[0]]
                recall_level += 1
        rp_list.append(rp)

        if self.emlist is not None:
            em_list = em_avg.val()
        endt = time.time()
        self.writeLossLog('Val', epoch, meanloss = 0, loss_list = [], em_list=em_list, time=(endt-startt)/60)
        # self.printf('small: ' + str(rp_list[0]))
        # self.printf('big: ' + str(rp_list[1]))
        self.printf('big: ' + str(rp_list))
Example #9
    def train(self, epoch, dataloader):
        use_cuda = torch.cuda.is_available()
        self.net.train()

        for m in self.net.modules():
            if isinstance(m, _BatchNorm) or isinstance(m, ABN):
            # if isinstance(m, _BatchNorm) or isinstance(m, InPlaceABNSync) or isinstance(m,InPlaceABN):
                if self.args.train['freeze']:
                    m.eval()

        lr = self.getLR(epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        loss_avg = Averager()
        lls_avg = Averager()

        startt = time.time()
        lastt0 = startt
        for batch_idx, (data, fpn_prob, fpn_coord_prob, fpn_coord_diff, fpn_diff, fpn_connects, names) in enumerate(dataloader):
            t0 = time.time()
            iter_time = t0-lastt0
            lastt0 = t0
            case_idxs = [dataloader.dataset.cases2idx[item] for item in names]
            if use_cuda:
                data = data.cuda()
                fpn_prob = [f.cuda() for f in fpn_prob]
                fpn_connects = [f.cuda() for f in fpn_connects]
                fpn_coord_prob = [f.cuda() for f in fpn_coord_prob]
                fpn_coord_diff = [f.cuda() for f in fpn_coord_diff]
                fpn_diff =  [f.cuda() for f in fpn_diff]
                case_idxs = torch.Tensor(case_idxs).cuda()

            losses, weights, pred_prob_list = self.warp(data, fpn_prob, fpn_coord_prob, fpn_coord_diff, fpn_diff, fpn_connects, case_idxs)
            # print(losses,'losses')
            # print(weights, 'weights')
            if pred_prob_list is not None:
                pred_prob_dict_pos = {}
                pred_prob_dict_neg = {}
                for pred_prob in pred_prob_list.cpu().numpy():
                    if pred_prob[0] == -1:
                        continue
                    case_idx, nodule_idx, n_weight = pred_prob
                    assert nodule_idx != 0
                    nodule_key = dataloader.dataset.cases[int(case_idx)] + '___' + str(abs(int(nodule_idx)))
                    if nodule_idx > 0:
                        if nodule_key not in pred_prob_dict_pos:
                            pred_prob_dict_pos[nodule_key] = n_weight
                        else:
                            pred_prob_dict_pos[nodule_key] = min(n_weight, pred_prob_dict_pos[nodule_key])
                    elif nodule_idx < 0:
                        if nodule_key not in pred_prob_dict_neg:
                            pred_prob_dict_neg[nodule_key] = n_weight
                        else:
                            pred_prob_dict_neg[nodule_key] = max(n_weight, pred_prob_dict_neg[nodule_key])
                for nodule_key, n_weight in pred_prob_dict_pos.items():
                    assert nodule_key in dataloader.dataset.sample_weights
                    dataloader.dataset.sample_weights[nodule_key][0] = n_weight
                    if n_weight > self.pos_weight_thresh:
                        dataloader.dataset.sample_weights[nodule_key][2] += 1
                        if dataloader.dataset.sample_weights[nodule_key][2] >= 3:
                            case_name, nodule_idx = nodule_key.split('___')
                            dataloader.dataset.lab_buffers[case_name][int(nodule_idx)-1][5] = 0

                #for nodule_key, n_weight in pred_prob_dict_neg.items():
                #    assert nodule_key in dataloader.dataset.neg_sample_weights
                #    dataloader.dataset.neg_sample_weights[nodule_key][0] = n_weight

            losses = losses.sum(dim=0)
            weights = weights.sum(dim=0)
            if weights.shape[0] > losses.shape[0]:
                assert weights.shape[0] == losses.shape[0] * 2
                fake_weights = weights[losses.shape[0]:]
                weights = weights[:losses.shape[0]]
            else:
                fake_weights = None
            total_loss = 0
            loss_list = []
            if fake_weights is not None:
                for l, w, fw in zip(losses, weights, fake_weights):
                    l_tmp = l / (1e-3 + w)
                    total_loss += l_tmp
                    fake_l_tmp = l / (1e-3 + fw)
                    loss_list.append(fake_l_tmp.detach().cpu().numpy())
            else:
                for l, w in zip(losses, weights):
                    l_tmp = (l/ (1e-3+w))
                    total_loss += l_tmp
                    loss_list.append(l_tmp.detach().cpu().numpy())

            loss_avg.update(total_loss.detach().cpu().numpy())
            info = 'end %d out of %d, '%(batch_idx, len(dataloader))

            

            for lid, l in enumerate(loss_list):
                info += 'loss %d: %.4f, '%(lid, np.mean(l))
            info += 'time: %.2f' %iter_time
            print(info)
            lls_avg.update(tuple(loss_list))
            self.optimizer.zero_grad()
            loss_scalar = total_loss
            if self.half:
                self.optimizer.backward(loss_scalar)
                #self.optimizer.clip_master_grads(1)
            else:
                loss_scalar.backward()
            if self.args.clip:
                torch.nn.utils.clip_grad_value_(self.warp.parameters(),1)
            self.optimizer.step()
        endt = time.time()
        self.writeLossLog('Train', epoch, meanloss = loss_avg.val(), loss_list = lls_avg.val(), lr = lr, time=(endt-startt)/60)

        return lls_avg.val()
Example #10
class ParticleVI(object):
    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=10,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):

        self.algo = algo
        self.dataset = dataset
        self.kernel_fn = kernel_fn
        self.num_particles = num_particles
        print("running {} on {}".format(algo, dataset))

        if self.dataset == 'mnist':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_mnist(
                split=True)
        elif self.dataset == 'cifar10':
            self.train_loader, self.test_loader, self.val_loader, = datagen.load_cifar10(
                split=True)
        else:
            raise NotImplementedError

        if kernel_fn == 'rbf':
            self.kernel = rbf_fn
            return_activations = False
        elif kernel_fn == 'cka':
            self.kernel = kernel_cka
            return_activations = True
        else:
            raise NotImplementedError

        models = [
            base_model_fn(num_classes=6,
                          return_activations=return_activations).cuda()
            for _ in range(num_particles)
        ]

        self.models = models
        param_set, state_dict = extract_parameters(self.models)

        self.state_dict = state_dict
        self.param_set = torch.nn.Parameter(param_set.clone(),
                                            requires_grad=True)

        self.optimizer = torch.optim.Adam([{
            'params': self.param_set,
            'lr': 1e-3,
            'weight_decay': 1e-4
        }])

        if resume:
            print('resuming from epoch {}'.format(resume_epoch))
            d = torch.load('saved_models/{}/{}2/model_epoch_{}.pt'.format(
                self.dataset, model_id, resume_epoch))
            for model, sd in zip(self.models, d['models']):
                model.load_state_dict(sd)
            self.param_set = d['params']
            self.state_dict = d['state_dict']
            self.optimizer = torch.optim.Adam([{
                'params': self.param_set,
                'lr': resume_lr,
                'weight_decay': 1e-4
            }])
            self.start_epoch = resume_epoch
        else:
            self.start_epoch = 0

        self.activation_length = self.models[0].activation_length
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.kernel_width_averager = Averager(shape=())

    def kernel_width(self, dist):
        """Update kernel_width averager and get latest kernel_width. """
        if dist.ndim > 1:
            dist = torch.sum(dist, dim=-1)
            assert dist.ndim == 1, "dist must have dimension 1 or 2."
        width, _ = torch.median(dist, dim=0)
        width = width / np.log(len(dist))
        self.kernel_width_averager.update(width)
        return self.kernel_width_averager.get()

    def svgd_grad(self, loss_grad, params):
        """
        Compute particle gradients via SVGD, empirical expectation
        evaluated by splitting half of the sampled batch. 
        """
        num_particles = params.shape[0]
        params2 = params.detach().requires_grad_(True)
        kernel_weight, kernel_grad = self.kernel(params2, params, self.kernel_width)
        if kernel_grad is None:
            kernel_grad = torch.autograd.grad(kernel_weight.sum(), params2)[0]

        kernel_logp = torch.matmul(kernel_weight.t().detach(),
                                   loss_grad) / num_particles
        grad = kernel_logp - kernel_grad.mean(0)
        return grad

    def test(self, test_loader, eval_loss=True):
        for model in self.models:
            model.eval()
        correct = 0
        test_loss = 0
        outputs_all = []
        for i, (inputs, targets) in enumerate(test_loader):
            preds = []
            loss = 0
            inputs = inputs.cuda()
            targets = targets.cuda()
            for model in self.models:
                outputs = model(inputs)
                #if self.kernel_fn == 'cka':
                # outputs, _ = outputs
                if eval_loss:
                    loss += self.loss_fn(outputs, targets)
                else:
                    loss += 0
                preds.append(torch.nn.functional.softmax(outputs, dim=-1))

            pred = torch.stack(preds)
            outputs_all.append(pred)
            preds = pred.mean(0)
            vote = preds.argmax(-1).cpu()
            correct += vote.eq(
                targets.cpu().data.view_as(vote)).float().cpu().sum()
            test_loss += (loss / self.num_particles)
        outputs_all = torch.cat(outputs_all, dim=1)
        test_loss /= i
        correct /= len(test_loader.dataset)
        for model in self.models:
            model.train()
        return outputs_all, (test_loss, correct)

    def train(self, epochs):
        for epoch in range(self.start_epoch, epochs):
            loss_epoch = 0
            for (inputs, targets) in self.train_loader:
                outputs = []
                activations = torch.zeros(len(self.models), len(targets),
                                          self.activation_length).cuda()
                neglogp = torch.zeros(self.num_particles)
                insert_items(self.models, self.param_set, self.state_dict)
                neglogp_grads = torch.zeros_like(self.param_set)
                for i, model in enumerate(self.models):
                    inputs = inputs.cuda()
                    targets = targets.cuda()
                    output, activation = model(inputs)
                    outputs.append(output)
                    activations[i, :, :] = activation
                    loss = self.loss_fn(outputs[-1], targets)
                    loss.backward()
                    neglogp[i] = loss
                    g = []
                    for name, param in model.named_parameters():
                        g.append(param.grad.view(-1))
                    neglogp_grads[i] = torch.cat(g)
                    model.zero_grad()

                par_vi_grad = self.svgd_grad(neglogp_grads, self.param_set)
                self.optimizer.zero_grad()
                self.param_set.grad = par_vi_grad
                self.optimizer.step()

                loss_step = neglogp.mean()
                loss_epoch += loss_step

            loss_epoch /= self.num_particles
            print('Train Epoch {} [cum loss: {}]\n'.format(epoch, loss_epoch))

            if epoch % 1 == 0:
                insert_items(self.models, self.param_set, self.state_dict)
                with torch.no_grad():
                    outputs, stats = self.test(self.val_loader)
                    outputs2, _ = self.test(self.test_loader, eval_loss=False)
                test_loss, correct = stats
                print('Test Loss: {}'.format(test_loss))
                print('Test Acc: {}%'.format(correct * 100))

                uncertainties = uncertainty(outputs)
                entropy, variance = uncertainties
                uncertainties2 = uncertainty(outputs2)
                entropy2, variance2 = uncertainties2
                auc_entropy = auc_score(entropy, entropy2)
                auc_variance = auc_score(variance, variance2)

                print('Test AUC Entropy: {}'.format(auc_entropy))
                print('Test AUC Variance: {}'.format(auc_variance))

                params = {
                    'params': self.param_set,
                    'state_dict': self.state_dict,
                    'models': [m.state_dict() for m in self.models],
                    'optimizer': self.optimizer.state_dict()
                }
                save_dir = 'saved_models/{}/{}2/'.format(
                    self.dataset, model_id)
                fn = 'model_epoch_{}.pt'.format(epoch)
                print('saving model: {}'.format(fn))
                os.makedirs(save_dir, exist_ok=True)
                torch.save(params, save_dir + fn)
            print('*' * 86)
Example #11
    def validate(self, epoch, val_loader):
        # During training, validate only every few epochs and only once epoch > 15
        if not self.args["val"]:
            if epoch % self.args["output"]['save_frequency'] != 0 and epoch > 0:
                return
            if 0 < epoch < 15:
                return
        self.net.eval()

        loss_avg = Averager()
        lls_avg = Averager()
        em_avg = Averager()
        bs = self.args["train"]["batch_size"]
        ###
        data_filler = BatchFiller(bs)
        target_filler = BatchFiller(bs)
        pred_filler = BatchFiller()
        full_target_filler = BatchFiller()

        val_results = []

        with torch.no_grad():
            pred_idx = 0
            use_cuda = torch.cuda.is_available()
            total_sample = len(val_loader)
            for sample_idx, sample in enumerate(
                    val_loader):  ### keep the original order; samples are fed according to batch size

                data, target, name = sample[0], sample[1], sample[2]

                if len(sample) > 3:
                    sequence_idx = sample[3]
                else:
                    args = None
                    sequence_idx = 0

                data = data.squeeze(0)
                target = target.squeeze(0)
                data_pieces, split_position = self.splitcomb.split(data)
                target_pieces, split_position = self.splitcomb.split(target)

                data_filler.enqueue(
                    sample=list(data_pieces),
                    name=[name[0] for _ in range(data_pieces.shape[0])],
                    shape=[
                        split_position for _ in range(data_pieces.shape[0])
                    ],
                    sequence_idx=[
                        sequence_idx for _ in range(data_pieces.shape[0])
                    ],
                )
                target_filler.enqueue(
                    sample=list(target_pieces),
                    name=[name[0] for _ in range(data_pieces.shape[0])],
                    shape=[
                        split_position for _ in range(data_pieces.shape[0])
                    ],
                    sequence_idx=[
                        sequence_idx for _ in range(data_pieces.shape[0])
                    ],
                )

                full_target_filler.enqueue(sample=[target],
                                           name=name,
                                           shape=[None],
                                           sequence_idx=[sequence_idx])

                if sample_idx + 1 == total_sample:
                    pad_num = max(bs - len(data_filler.sample_queue) % bs, 1)
                    data_filler.enqueue(
                        sample=[
                            np.zeros_like(data_pieces[0, :])
                            for _ in range(pad_num)
                        ],
                        name=["padding" for _ in range(pad_num)],
                        shape=[None for _ in range(pad_num)],
                        sequence_idx=[sequence_idx for _ in range(pad_num)],
                    )
                    target_filler.enqueue(
                        sample=[(np.zeros_like(target_pieces[0, :]))
                                for _ in range(pad_num)],
                        name=["padding" for _ in range(pad_num)],
                        shape=[None for _ in range(pad_num)],
                        sequence_idx=[sequence_idx for _ in range(pad_num)],
                    )
                    full_target_filler.enqueue(
                        sample=[np.zeros_like(target)],
                        name=["padding"],
                        shape=[None],
                        sequence_idx=[sequence_idx],
                    )
                while data_filler.isFull(mode="batch"):

                    data_batch, name, shape, sequence_idx = data_filler.dequeue(
                        mode="batch")
                    target_batch, name, shape, sequence_idx = target_filler.dequeue(
                        mode="batch")

                    if use_cuda:
                        data_batch = torch.from_numpy(
                            np.stack(data_batch, axis=0)).cuda()
                        target_batch = torch.from_numpy(
                            np.stack(target_batch, axis=0)).cuda()

                    total_loss, loss_list, logits = self.warp(
                        data_batch, target_batch, True, sequence_idx)

                    pred_filler.enqueue(sample=list(logits),
                                        name=name,
                                        shape=shape,
                                        sequence_idx=sequence_idx)
                    loss_avg.update(total_loss.mean().detach().cpu().numpy())
                    loss_list = tuple([l.cpu().numpy() for l in loss_list])
                    lls_avg.update(loss_list)

                while pred_filler.isFull(mode="sample"):

                    pred_full, _, shape, sequence_idx = pred_filler.dequeue(
                        mode="sample")
                    target_full, name, _, sequence_idx = full_target_filler.dequeue(
                        mode="sample")
                    pred_full = self.splitcomb.combine(pred_full, shape[0])
                    pred_full = choose_top1_connected_component(
                        model_pred=pred_full, choose_top1=self.choose_top1)
                    # pred_full = dynamic_choose_topk_vessel_connected_component(model_pred=pred_full, choose_topk=self.choose_topk)

                    em_list = []
                    if self.emlist is not None:
                        for em_fun in self.emlist:
                            em_list.extend(em_fun(pred_full, target_full[0]))
                        em_list = tuple(
                            [l.cpu().squeeze().numpy() for l in em_list])
                        em_avg.update(em_list)

                    curr_case_nii_path = os.path.join(self.testdir,
                                                      name[0]) + "_pred.nii.gz"
                    os.makedirs(os.path.dirname(curr_case_nii_path),
                                exist_ok=True)
                    self.writer.SetFileName(curr_case_nii_path)
                    self.writer.Execute(
                        sitk.GetImageFromArray(
                            (pred_full.cpu().squeeze(0).numpy()).astype(
                                np.uint8)))

                    curr_case_npy_path = os.path.join(self.save_dir, 'val_out',
                                                      '{}.npy'.format(name[0]))
                    os.makedirs(os.path.dirname(curr_case_npy_path),
                                exist_ok=True)
                    np.save(
                        curr_case_npy_path,
                        pred_full.cpu().squeeze(0).numpy().astype(np.uint8))

                    info = "Finish validation %d out of %d, name %s, " % (
                        pred_idx + 1,
                        len(val_loader),
                        name[0],
                    )
                    pred_idx += 1
                    for lid, l in enumerate(em_list):
                        info += "em %d: %.4f, " % (lid, l)
                    print(info)
                    val_results.append(info)

        if not self.args["val"]:
            if epoch % self.args["output"]["save_frequency"] == 0:
                self.ioer.save_file(self.net, epoch, self.args, 0)
            else:
                return

        if self.emlist is not None:
            em_list = em_avg.val()
        self.__writeLossLog(
            "Val",
            epoch,
            meanloss=loss_avg.val(),
            loss_list=lls_avg.val(),
            em_list=em_list,
        )

        with open(os.path.join(self.save_dir, '{}_val.txt'.format(epoch)),
                  'a') as f_out:
            f_out.write('\n'.join(val_results) + '\n\n')
Example #12
class ParticleVI(object):
    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=50,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):

        self.algo = algo
        self.dataset = dataset
        self.kernel_fn = kernel_fn
        self.num_particles = num_particles
        print("running {} on {}".format(algo, dataset))

        self._use_wandb = False
        self._save_model = False

        if self.dataset == 'mnist':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_mnist(
                split=True)
        elif self.dataset == 'cifar10':
            self.train_loader, self.test_loader, self.val_loader, = datagen.load_cifar10(
                split=True)
        else:
            raise NotImplementedError

        if kernel_fn == 'rbf':
            self.kernel = rbf_fn
        else:
            raise NotImplementedError

        models = [
            base_model_fn(num_classes=6).cuda() for _ in range(num_particles)
        ]

        self.models = models
        param_set, state_dict = extract_parameters(self.models)

        self.state_dict = state_dict
        self.param_set = torch.nn.Parameter(param_set.clone(),
                                            requires_grad=True)

        self.optimizer = torch.optim.Adam([{
            'params': self.param_set,
            'lr': 1e-3,
            'weight_decay': 1e-4
        }])

        if resume:
            print('resuming from epoch {}'.format(resume_epoch))
            d = torch.load('saved_models/{}/{}2/model_epoch_{}.pt'.format(
                self.dataset, model_id, resume_epoch))
            for model, sd in zip(self.models, d['models']):
                model.load_state_dict(sd)
            self.param_set = d['params']
            self.state_dict = d['state_dict']
            self.optimizer = torch.optim.Adam([{
                'params': self.param_set,
                'lr': resume_lr,
                'weight_decay': 1e-4
            }])
            self.start_epoch = resume_epoch
        else:
            self.start_epoch = 0

        loss_type = 'ce'
        if loss_type == 'ce':
            self.loss_fn = torch.nn.CrossEntropyLoss()
        elif loss_type == 'kliep':
            self.loss_fn = MattLoss().get_loss_dict()['kliep']
        self.kernel_width_averager = Averager(shape=())

        if self._use_wandb:
            wandb.init(project="open-category-experiments",
                       name="SVGD {}".format(self.dataset))
            for model in models:
                wandb.watch(model)
            config = wandb.config
            config.algo = algo
            config.dataset = dataset
            config.kernel_fn = kernel_fn
            config.num_particles = num_particles
            config.loss_fn = loss_type

    def kernel_width(self, dist):
        """Update kernel_width averager and get latest kernel_width. """
        if dist.ndim > 1:
            dist = torch.sum(dist, dim=-1)
            assert dist.ndim == 1, "dist must have dimension 1 or 2."
        width, _ = torch.median(dist, dim=0)
        width = width / np.log(len(dist))
        self.kernel_width_averager.update(width)
        return self.kernel_width_averager.get()

    def rbf_fn(self, x, y):
        Nx = x.shape[0]
        Ny = y.shape[0]
        x = x.view(Nx, -1)
        y = y.view(Ny, -1)
        Dx = x.shape[1]
        Dy = y.shape[1]
        assert Dx == Dy
        diff = x.unsqueeze(1) - y.unsqueeze(0)  # [Nx, Ny, D]
        dist_sq = torch.sum(diff**2, -1)  # [Nx, Ny]
        h = self.kernel_width(dist_sq.view(-1))
        kappa = torch.exp(-dist_sq / h)  # [Nx, Nx]
        kappa_grad = torch.einsum('ij,ijk->ijk', kappa,
                                  -2 * diff / h)  # [Nx, Ny, D]
        return kappa, kappa_grad

    def svgd_grad(self, loss_grad, params):
        """
        Compute particle gradients via SVGD, empirical expectation
        evaluated by splitting half of the sampled batch. 
        """
        num_particles = params.shape[0]
        params2 = params.detach().requires_grad_(True)
        kernel_weight, kernel_grad = self.rbf_fn(params2, params)
        if kernel_grad is None:
            kernel_grad = torch.autograd.grad(kernel_weight.sum(), params2)[0]

        kernel_logp = torch.matmul(kernel_weight.t().detach(),
                                   loss_grad) / num_particles
        grad = kernel_logp - kernel_grad.mean(0)
        return grad

    def test(self, test_loader, eval_loss=True):
        for model in self.models:
            model.eval()
        correct = 0
        test_loss = 0
        outputs_all = []
        for i, (inputs, targets) in enumerate(test_loader):
            preds = []
            loss = 0
            inputs = inputs.cuda()
            targets = targets.cuda()
            for model in self.models:
                outputs = model(inputs)
                if eval_loss:
                    loss += self.loss_fn(outputs, targets)
                else:
                    loss += 0
                preds.append(torch.nn.functional.softmax(outputs, dim=-1))

            pred = torch.stack(preds)
            outputs_all.append(pred)
            preds = pred.mean(0)
            vote = preds.argmax(-1).cpu()
            correct += vote.eq(
                targets.cpu().data.view_as(vote)).float().cpu().sum()
            test_loss += (loss / self.num_particles)
        outputs_all = torch.cat(outputs_all, dim=1)
        test_loss /= i
        correct /= len(test_loader.dataset)
        for model in self.models:
            model.train()
        return outputs_all, (test_loss, correct)

    def train(self, epochs):
        for epoch in range(self.start_epoch, epochs):
            loss_epoch = 0
            for (inputs, targets) in self.train_loader:
                outputs = []
                neglogp = torch.zeros(self.num_particles)
                insert_items(self.models, self.param_set, self.state_dict)
                neglogp_grads = torch.zeros_like(self.param_set)
                for i, model in enumerate(self.models):
                    inputs = inputs.cuda()
                    targets = targets.cuda()
                    output = model(inputs)
                    outputs.append(output)
                    loss = self.loss_fn(outputs[-1], targets)
                    loss.backward()
                    neglogp[i] = loss
                    g = []
                    for name, param in model.named_parameters():
                        g.append(param.grad.view(-1))
                    neglogp_grads[i] = torch.cat(g)
                    model.zero_grad()

                par_vi_grad = self.svgd_grad(neglogp_grads, self.param_set)
                self.optimizer.zero_grad()
                self.param_set.grad = par_vi_grad
                self.optimizer.step()

                loss_step = neglogp.mean()
                loss_epoch += loss_step

            loss_epoch /= self.num_particles
            print('Train Epoch {} [cum loss: {}]\n'.format(epoch, loss_epoch))

            if epoch % 1 == 0:
                insert_items(self.models, self.param_set, self.state_dict)
                with torch.no_grad():
                    outputs, stats = self.test(self.val_loader)
                    outputs2, _ = self.test(self.test_loader, eval_loss=False)
                test_loss, correct = stats
                print('Test Loss: {}'.format(test_loss))
                print('Test Acc: {}%'.format(correct * 100))

                uncertainties = uncertainty(outputs)
                entropy, variance = uncertainties
                uncertainties2 = uncertainty(outputs2)
                entropy2, variance2 = uncertainties2
                auc_entropy = auc_score(entropy, entropy2)
                auc_variance = auc_score(variance, variance2)

                print('Test AUC Entropy: {}'.format(auc_entropy))
                print('Test AUC Variance: {}'.format(auc_variance))

                if self._use_wandb:
                    wandb.log({"Test Loss": test_loss})
                    wandb.log({"Train Loss": loss_epoch})
                    wandb.log({"Test Acc": correct * 100})

                    wandb.log({"Test AUC (entropy)": auc_entropy})
                    wandb.log({"Test AUC (variance)": auc_variance})

                if self._save_model:
                    params = {
                        'params': self.param_set,
                        'state_dict': self.state_dict,
                        'models': [m.state_dict() for m in self.models],
                        'optimizer': self.optimizer.state_dict()
                    }
                    save_dir = 'saved_models/{}/{}2/'.format(
                        self.dataset, model_id)
                    fn = 'model_epoch_{}.pt'.format(epoch)
                    print('saving model: {}'.format(fn))
                    os.makedirs(save_dir, exist_ok=True)
                    torch.save(params, save_dir + fn)

            print('*' * 86)
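A sketch of how one of the ParticleVI variants above might be driven end to end; base_model_fn is an assumed stand-in (any callable that builds the per-particle network), while datagen and model_id are expected to come from the surrounding project:

import torchvision


def base_model_fn(num_classes=6):
    # assumed per-particle backbone; the original project may use a different network
    return torchvision.models.resnet18(num_classes=num_classes)


runner = ParticleVI(algo='svgd',
                    dataset='cifar10',
                    kernel_fn='rbf',
                    base_model_fn=base_model_fn,
                    num_particles=10)
runner.train(epochs=100)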