Example #1
    def maskData(self, data):
        """Apply the binary NIfTI mask to a 3D or 4D data array.

        Args:
            data: 3D volume or 4D (x, y, z, t) array whose spatial shape
                matches the mask.

        Returns:
            np.ndarray: in-mask voxels as a 1D array for 3D input, or a
            (t, n_voxels) array for 4D input.
        """
        msk = nib.load(self.mask)
        mskD = msk.get_fdata()  # get_data() is deprecated in recent nibabel
        if not np.all(np.bitwise_or(mskD == 0, mskD == 1)):
            raise ValueError("Mask has incorrect values.")
        if data.shape[0:3] != mskD.shape:
            raise ValueError((data.shape, mskD.shape))

        msk_f = mskD.flatten()
        msk_idx = np.where(msk_f == 1)[0]

        if len(data.shape) == 3:
            # Single volume: keep only the in-mask voxels.
            data_masked = data.flatten()[msk_idx]
        elif len(data.shape) == 4:
            # Time series: move time to the first axis, then mask each volume.
            data = np.transpose(data, (3, 0, 1, 2))
            data_masked = np.zeros((data.shape[0], int(mskD.sum())))
            for i, x in enumerate(data):
                data_masked[i] = x.flatten()[msk_idx]
        else:
            raise ValueError("data must be 3D or 4D.")

        return np.array(data_masked)
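
A hedged usage sketch (not part of the source): the class that owns maskData is not shown, so Masker and the file names below are placeholders. The method keeps only the voxels where the binary mask equals 1, returning one flattened row per volume for 4D input.

import nibabel as nib

masker = Masker(mask="brain_mask.nii.gz")    # hypothetical owning class
bold = nib.load("bold.nii.gz").get_fdata()   # 4D array (x, y, z, t)
masked = masker.maskData(bold)               # 2D array (t, n_voxels_in_mask)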
Example #2
def train(epoch, model, optimizer, train_loader, device, scaling, vlog, elog,
          log_var_std):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        data_flat = data.flatten(start_dim=1).repeat(1, scaling)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data_flat)
        loss, kl, rec = loss_function(recon_batch, data_flat, mu, logvar,
                                      log_var_std)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
            # vlog.show_value(torch.mean(kl).item(), name="Kl-loss", tag="Losses")
            # vlog.show_value(torch.mean(rec).item(), name="Rec-loss", tag="Losses")
            # vlog.show_value(loss.item(), name="Total-loss", tag="Losses")

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
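
Both train and the test function further down call a loss_function that is not included in these snippets. A minimal sketch consistent with how it is used (a scalar loss for backward(), per-sample kl and rec vectors), assuming a Gaussian reconstruction term with fixed log-variance log_var_std plus the standard VAE KL term:

import torch
import torch.distributions as dist

def loss_function(recon_x, x, mu, logvar, log_var_std):
    # Hypothetical sketch: per-sample Gaussian NLL with a fixed output
    # log-variance, plus the analytic KL to a standard normal prior.
    log_std = torch.ones_like(recon_x) * log_var_std
    rec = -dist.Normal(recon_x, torch.exp(log_std)).log_prob(x).sum(dim=1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    loss = torch.mean(kl + rec)
    return loss, kl, rec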
Example #4
def get_values(data_frame):
    n = len(data_frame)
    data = np.zeros((n, MAXLEN, 21), dtype=np.float32)
    for i, (_, row) in enumerate(data_frame.iterrows()):
        ind = row['indexes']
        m = len(ind)
        s = 0  # left-aligned; (MAXLEN - m) // 2 would center the sequence
        # Position-wise one-hot: position s + j gets a 1 at channel ind[j].
        # (A plain slice, data[i, s:s + m, ind] = 1, would instead fill the
        # whole m-by-m block via broadcasting.)
        data[i, np.arange(s, s + m), ind] = 1
    return data.flatten()
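
A hypothetical usage sketch (the DataFrame contents are made up; MAXLEN follows the snippet): each row's indexes list is one-hot encoded along the 21-wide channel axis, one position per index.

import numpy as np
import pandas as pd

MAXLEN = 5  # assumed module-level constant
df = pd.DataFrame({'indexes': [[0, 3, 20], [1, 1]]})
values = get_values(df)
print(values.shape)  # (2 * MAXLEN * 21,) = (210,)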
Example #5
def prep():

    images_path = f'{utilz.PARENT_DIR}data/images/'
    print(images_path)
    images = image_loader.Images(images_path, transforms=edits)

    data = images[0]
    dataloader = DataLoader(images, shuffle=True, batch_size=BATCH_SIZE)

    print(data.shape)
    dim = data.flatten().shape[0]
    if USE_LOGGER:
        writer = SummaryWriter(
            f"runs/image_gen_test_MID_{MIDDLE}_BOTTLE_{BOTTLENECK}_{TIME}")
    else:
        writer = None

    model = models.VAEConv2d(dim, middle=MIDDLE,
                             bottleneck=BOTTLENECK).to(device)
    print(model)

    if LOAD_MODEL:
        model = utilz.load_model(model, MODEL_FN)

    if LR is not None:
        optimizer = optim.Adam(model.parameters(), lr=LR)
    else:
        optimizer = optim.Adam(model.parameters())

    write_to = f'samples/images/image_gen_{TIME}'

    os.makedirs(write_to)
    d = {
        'write_to': write_to,
        'writer': writer,
        'dataloader': dataloader,
        'model': model,
        'optimizer': optimizer,
        'set': images,
        'model_fn': MODEL_FN
    }
    return d
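
A short usage sketch (hypothetical): prep() wires the dataset, model, and optimizer together and hands them back as a dict for the training loop.

ctx = prep()
for batch in ctx['dataloader']:
    print(batch.shape)  # one batch of images from the DataLoader
    break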
Example #6
def plot_log_dist(data,
                  logger,
                  name="undefined",
                  save_dir="./",
                  cumulative=False,
                  hist=True,
                  kde=False,
                  bins=60):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    data = data.flatten()
    data_log = []
    count_0 = 0
    count = data.shape[0]
    for dat in data:
        if (dat > 0.0):
            data_log.append(math.log(dat, 10))
        elif (dat < 0.0):
            print("%s contains negative values; cannot take log." % (name))
            return False
        else:
            count_0 += 1
    data_log = np.array(data_log)
    note = "non-zero rate: %.4f(%d/%d)" % (
        (count - count_0) / count, count - count_0, count)
    if count_0 == count:
        logger.write("%s is all-zero." % (name))
        return True
    plot_dist(data=data_log,
              logger=logger,
              name=name + "(log)",
              save_dir=save_dir,
              notes=note,
              cumulative=cumulative,
              hist=hist,
              kde=kde,
              bins=bins)
    return True
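
A minimal usage sketch, assuming only that the project's logger exposes a write() method (PrintLogger below is a stand-in):

import numpy as np

class PrintLogger:  # hypothetical stand-in for the project's logger
    def write(self, msg):
        print(msg)

weights = np.abs(np.random.randn(1000))
plot_log_dist(weights, PrintLogger(), name="weights", save_dir="./plots/")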
Example #7
def update_model(participate, n_total):
    """
    Update the local model by securely averaging the models of all
    selected clients.

    Args:
        participate: whether this worker participates in this averaging
            round (1) or contributes zeros (0).
        n_total: total number of samples used in this averaging round.

    Returns:
        None. The averaged parameters are loaded into the global model.
    """
    global model, client_socket, worker_number, args, dataset_length
    parameters = model.state_dict()
    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    for key in parameters.keys():
        data = parameters[key].cpu().numpy()
        shape = parameters[key].shape
        if participate == 0:
            data = np.zeros(shape=shape)
        data = data.flatten()
        ratio = dataset_length / n_total
        data = data * ratio
        if participate == 1:
            print("Worker {} conducting secret sharing on key {}".format(worker_number, key))
        p = multiprocessing.Process(target=get_secret_sum, args=(client_socket, data, return_dict))
        p.start()
        p.join()
        result = np.reshape(return_dict[0], shape)
        if args.gpu == -1:
            sum_result = torch.from_numpy(result).cpu()
        else:
            sum_result = torch.from_numpy(result).cuda()
        parameters[key] = sum_result
        if participate == 1:
            print("Worker {} finished secret sharing on key {}".format(worker_number, key))

    model.load_state_dict(parameters)
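
Illustrative arithmetic (not from the source): because each worker scales its flattened parameters by dataset_length / n_total before the secret-sharing sum, the summed result across workers is exactly the sample-weighted FedAvg update.

import numpy as np

n = [60, 40]                       # hypothetical per-worker sample counts
theta = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
n_total = sum(n)
avg = sum((nk / n_total) * tk for nk, tk in zip(n, theta))
print(avg)  # [1.8 2.8]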
Example #8
    def __init__(self,
                 memmap_item,
                 indices,
                 dimension,
                 mean=None,
                 var=None,
                 min_=None,
                 max_=None,
                 normalize=False,
                 normalization_type="global_scale"):

        self.memmap_item = memmap_item
        self.indices = list(indices)
        self.dimension = dimension
        self.normalize = normalize
        self.ntype = normalization_type
        self.max_ = max_
        self.min_ = min_
        self.mean = mean
        self.var = var

        if self.normalize:
            if self.ntype == "global_scale":

                if (self.max_ is None) or (self.min_ is None):
                    print("Computing global max and min ... ")

                    self.max_ = float(np.finfo(np.float32).min)
                    self.min_ = float(np.finfo(np.float32).max)

                    for index in indices:
                        data = self.memmap_item[index]
                        max_ = float(np.amax(data.flatten()))
                        min_ = float(np.amin(data.flatten()))

                        if self.max_ < max_:
                            self.max_ = max_

                        if self.min_ > min_:
                            self.min_ = min_

                    print("Obtained global max and min ", self.max_, self.min_)

            elif self.ntype == "global":

                if (self.mean is None) or (self.var is None):
                    print("Computing global mean and standard deviation")

                    mean = 0
                    num_items = 0

                    for index in indices:
                        data = self.memmap_item[index]
                        mean, num_items = update_mean(mean,
                                                      num_items,
                                                      data,
                                                      lambda x: x,
                                                      offset=0)

                    var = 0
                    num_items = 0

                    for index in indices:
                        data = self.memmap_item[index]
                        var, num_items = update_mean(var,
                                                     num_items,
                                                     data,
                                                     lambda x: x**2,
                                                     offset=mean)

                    # After this step `var` holds the standard deviation,
                    # matching the message printed below.
                    var = var**(0.5)

                    print("Obtained global mean and standard deviation", mean,
                          var)

                self.mean = mean
                self.var = var
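
The helper update_mean is not shown in this snippet. A hypothetical sketch consistent with how it is called (identity transform with offset 0 for the mean pass, x**2 with offset=mean for the variance pass):

import numpy as np

def update_mean(current, num_items, data, transform, offset=0):
    # Hypothetical: streaming mean of transform(x - offset) over all
    # elements seen so far.
    values = transform(np.asarray(data, dtype=np.float64).flatten() - offset)
    total = current * num_items + values.sum()
    num_items += values.size
    return total / num_items, num_items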
Example #9
def plot_dist(data,
              logger,
              name="undefined",
              ran=[],
              save_path="./",
              bins=100,
              kde=False,
              hist=True,
              notes="",
              redo=False,
              rug_max=False,
              stat=True,
              cumulative=False):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    data = data.flatten()
    data = np.sort(data)
    stat_data = "mean=%.4f var=%.4f, min=%.4f, max=%.4f mid=%.4f" % (np.mean(
        data), np.var(data), np.min(data), np.max(data), np.median(data))

    num_max = 100000
    if data.shape[0] > num_max:
        print("data is too large. sampling %d elements." % num_max)
        data = np.array(random.sample(data.tolist(), num_max))
    logger.write(stat_data)
    img_dir = {}

    y_interv = 0.05
    y_line = 0.95
    #hist_plot

    if hist:
        sig = True
        sig_check = False
        while sig:
            try:
                title = name + " distribution"
                if (cumulative == True):
                    title += "(cumu)"
                hist_dir = save_dir + title + ".jpg"
                img_dir["hist"] = hist_dir
                if (redo == False and os.path.exists(hist_dir)):
                    print("image already exists.")
                else:
                    plt.figure()
                    if (ran != []):
                        plt.xlim(ran[0], ran[1])
                    else:
                        set_lim(data, None, plt)
                    #sns.distplot(data, bins=bins, color='b',kde=False)
                    #method="sns"
                    plt.hist(data,
                             bins=bins,
                             color="b",
                             density=True,
                             cumulative=cumulative)
                    method = "hist"

                    plt.title(title, fontsize=large_fontsize)
                    if (stat == True):
                        plt.annotate(stat_data,
                                     xy=(0.02, y_line),
                                     xycoords='axes fraction')
                        y_line -= y_interv
                    print_notes(notes, y_line, y_interv)
                    plt.savefig(hist_dir)
                    plt.close()
                sig = False
            except Exception:
                if (sig_check == True):
                    raise Exception("exception in plot dist.")
                data, note = check_array_1d(data, logger)
                if (isinstance(notes, str)):
                    notes = [notes, note]
                elif (isinstance(notes, list)):
                    notes = notes + [note]
                sig_check = True
    try:  #kde_plot
        y_line = 0.95
        if kde:
            title = name + " dist(kde)"
            if (cumulative == True):
                title += "(cumu)"
            kde_dir = save_dir + title + ".jpg"
            img_dir["kde"] = kde_dir
            if (redo == False and os.path.exists(kde_dir)):
                print("image already exists.")
            else:
                plt.figure()
                if (ran != []):
                    plt.xlim(ran[0], ran[1])
                else:
                    set_lim(data, None, plt)
                sns.kdeplot(data, shade=True, color='b', cumulative=cumulative)
                if rug_max:
                    sns.rugplot(data[(data.shape[0] -
                                      data.shape[0] // 500):-1],
                                height=0.2,
                                color='r')
                plt.title(title, fontsize=large_fontsize)
                if stat:
                    plt.annotate(stat_data,
                                 xy=(0.02, y_line),
                                 xycoords='axes fraction')
                    y_line -= 0.05
                print_notes(notes, y_line, y_interv)
                plt.savefig(kde_dir)
                plt.close()
    except Exception:
        print("exception in kde plot %s" % (kde_dir))
    return img_dir
Example #10
def test(model, test_loader, test_loader_abnorm, device, scaling, vlog, elog,
         image_size, batch_size, log_var_std):
    model.eval()
    test_loss = []
    kl_loss = []
    rec_loss = []
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            data_flat = data.flatten(start_dim=1).repeat(1, scaling)
            recon_batch, mu, logvar = model(data_flat)
            loss, kl, rec = loss_function(recon_batch, data_flat, mu, logvar,
                                          log_var_std)
            test_loss += (kl + rec).tolist()
            kl_loss += kl.tolist()
            rec_loss += rec.tolist()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([
                    data[:n],
                    recon_batch[:, :image_size].view(batch_size, 1, 28, 28)[:n]
                ])
                # vlog.show_image_grid(comparison.cpu(), name='reconstruction')

    # vlog.show_value(np.mean(kl_loss), name="Norm-Kl-loss", tag="Anno")
    # vlog.show_value(np.mean(rec_loss), name="Norm-Rec-loss", tag="Anno")
    # vlog.show_value(np.mean(test_loss), name="Norm-Total-loss", tag="Anno")
    # elog.show_value(np.mean(kl_loss), name="Norm-Kl-loss", tag="Anno")
    # elog.show_value(np.mean(rec_loss), name="Norm-Rec-loss", tag="Anno")
    # elog.show_value(np.mean(test_loss), name="Norm-Total-loss", tag="Anno")

    test_loss_ab = []
    kl_loss_ab = []
    rec_loss_ab = []
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader_abnorm):
            data = data.to(device)
            data_flat = data.flatten(start_dim=1).repeat(1, scaling)
            recon_batch, mu, logvar = model(data_flat)
            loss, kl, rec = loss_function(recon_batch, data_flat, mu, logvar,
                                          log_var_std)
            test_loss_ab += (kl + rec).tolist()
            kl_loss_ab += kl.tolist()
            rec_loss_ab += rec.tolist()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([
                    data[:n],
                    recon_batch[:, :image_size].view(batch_size, 1, 28, 28)[:n]
                ])
                # vlog.show_image_grid(comparison.cpu(), name='reconstruction2')

    print('====> Test set loss: {:.4f}'.format(np.mean(test_loss)))

    # vlog.show_value(np.mean(kl_loss_ab), name="Unorm-Kl-loss", tag="Anno")
    # vlog.show_value(np.mean(rec_loss_ab), name="Unorm-Rec-loss", tag="Anno")
    # vlog.show_value(np.mean(test_loss_ab), name="Unorm-Total-loss", tag="Anno")
    # elog.show_value(np.mean(kl_loss_ab), name="Unorm-Kl-loss", tag="Anno")
    # elog.show_value(np.mean(rec_loss_ab), name="Unorm-Rec-loss", tag="Anno")
    # elog.show_value(np.mean(test_loss_ab), name="Unorm-Total-loss", tag="Anno")

    kl_roc, kl_pr = elog.get_classification_metrics(
        kl_loss + kl_loss_ab,
        [0] * len(kl_loss) + [1] * len(kl_loss_ab),
    )[0]
    rec_roc, rec_pr = elog.get_classification_metrics(
        rec_loss + rec_loss_ab,
        [0] * len(rec_loss) + [1] * len(rec_loss_ab),
    )[0]
    loss_roc, loss_pr = elog.get_classification_metrics(
        test_loss + test_loss_ab,
        [0] * len(test_loss) + [1] * len(test_loss_ab),
    )[0]

    # vlog.show_value(np.mean(kl_roc), name="KL-loss", tag="ROC")
    # vlog.show_value(np.mean(rec_roc), name="Rec-loss", tag="ROC")
    # vlog.show_value(np.mean(loss_roc), name="Total-loss", tag="ROC")
    # elog.show_value(np.mean(kl_roc), name="KL-loss", tag="ROC")
    # elog.show_value(np.mean(rec_roc), name="Rec-loss", tag="ROC")
    # elog.show_value(np.mean(loss_roc), name="Total-loss", tag="ROC")

    # vlog.show_value(np.mean(kl_pr), name="KL-loss", tag="PR")
    # vlog.show_value(np.mean(rec_pr), name="Rec-loss", tag="PR")
    # vlog.show_value(np.mean(loss_pr), name="Total-loss", tag="PR")

    return kl_roc, rec_roc, loss_roc, kl_pr, rec_pr, loss_pr
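
elog.get_classification_metrics is not defined in these snippets; as an assumption, an equivalent computation with scikit-learn would score how well each loss separates normal (0) from abnormal (1) samples:

from sklearn.metrics import roc_auc_score, average_precision_score

def classification_metrics(scores, labels):
    # Hypothetical equivalent: ROC-AUC and average precision over the
    # per-sample anomaly scores (1 = abnormal).
    return roc_auc_score(labels, scores), average_precision_score(labels, scores)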
Example #11

if __name__ == "__main__":

    WORKER_SIZE = 2
    BATCH_SIZE = 20

    kwargs = {'num_workers': WORKER_SIZE, 'pin_memory': True}
    train_loader = torch.utils.data.DataLoader(OPLoader(
        WORKER_SIZE, BATCH_SIZE),
                                               batch_size=WORKER_SIZE,
                                               shuffle=False,
                                               **kwargs)

    for batch_idx, (data, label) in enumerate(train_loader):
        # Merge the DataLoader batch dim with OPLoader's internal batch dim.
        data = data.flatten(0, 1)
        label = label.flatten(0, 1)

        time.sleep(2)
        print("BatchIDX: " + str(batch_idx), data.shape, label.shape)

        for i in range(0, data.shape[0]):
            img_viz = data.detach().cpu().numpy().copy()[i, 0, :, :]
            cv2.putText(img_viz, str(batch_idx), (0, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 2, 255)
            cv2.imshow("img", img_viz + 0.5)
            cv2.waitKey(15)
            break

        #break