Beispiel #1
0
def generate(output_directory, ckpt_path, ckpt_epoch, n, T, beta_0, beta_T,
             unet_config):
    """
    Generate images using the pretrained UNet model and save them as
    img_<i>.jpg files

    Parameters:

    output_directory (str):     output generated images to this path;
                                created if it does not exist yet
    ckpt_path (str):            path of the checkpoints
    ckpt_epoch (int or 'max'):  the pretrained model checkpoint to be loaded;
                                automatically selects the maximum epoch if 'max' is selected
    n (int):                    number of images to generate
    T (int):                    the number of diffusion steps
    beta_0 and beta_T (float):  diffusion parameters
    unet_config (dict):         dictionary of UNet parameters

    Raises:

    Exception:  'No valid model found' if the checkpoint cannot be read or
                cannot be loaded into the model; the underlying error is
                attached as the cause
    """

    # Compute diffusion hyperparameters
    #   Alpha_bar[t]  = Alpha[0] * ... * Alpha[t]
    #   Beta_tilde[0] = Beta[0]
    #   Beta_tilde[t] = Beta[t] * (1 - Alpha_bar[t-1]) / (1 - Alpha_bar[t])
    # Computed with whole-tensor ops instead of a per-element Python loop.
    Beta = torch.linspace(beta_0, beta_T, T).cuda()
    Alpha = 1 - Beta
    Alpha_bar = torch.cumprod(Alpha, dim=0)
    Beta_tilde = Beta.clone()
    Beta_tilde[1:] = Beta[1:] * ((1 - Alpha_bar[:-1]) / (1 - Alpha_bar[1:]))
    Sigma = torch.sqrt(Beta_tilde)

    # Predefine model (built once; the checkpoint is loaded into it below)
    net = UNet(**unet_config).cuda()
    print_size(net)

    # Load checkpoint
    if ckpt_epoch == 'max':
        ckpt_epoch = find_max_epoch(ckpt_path, 'unet_ckpt')
    model_path = os.path.join(ckpt_path,
                              'unet_ckpt_' + str(ckpt_epoch) + '.pkl')
    try:
        # NOTE(review): torch.load unpickles the file -- only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location='cpu')
        print('Model at epoch %s has been trained for %s seconds' %
              (ckpt_epoch, checkpoint['training_time_seconds']))
        net.load_state_dict(checkpoint['model_state_dict'])
    except Exception as err:
        # Same exception type and message as before, but keep the real
        # reason (missing file, missing key, shape mismatch, ...) attached.
        raise Exception('No valid model found') from err

    # Make sure the destination exists before the sampling run, so a missing
    # directory cannot cause the generated samples to be lost at save time.
    if output_directory:
        os.makedirs(output_directory, exist_ok=True)

    # Generation (the shape handed to `sampling` is fixed to n x 3 x 256 x 256)
    time0 = time.time()
    X_gen = sampling(net, (n, 3, 256, 256), T, Alpha, Alpha_bar, Sigma)
    print('generated %s samples at epoch %s in %s seconds' %
          (n, ckpt_epoch, int(time.time() - time0)))

    # Save generated images
    for i in range(n):
        save_image(rescale(X_gen[i]),
                   os.path.join(output_directory, 'img_{}.jpg'.format(i)))
    print('saved generated samples at epoch %s' % ckpt_epoch)
Beispiel #2
0
def generalTrain():
    """General training against the final testing set.

    Loads the general training data and the general test data through
    `util`, stacks the arrays returned by `util.generalTransform` into a
    single X / Y pair, then runs `util.sampling`, `util.pca` and
    `util.scaler` (in that order) before calling `training` with the
    prepared data and the name "general_pred".
    """
    filename = "general_pred"
    data, indexes = util.loadGeneralTrainData()
    Xt = util.loadGeneralTestData()
    Xs, Ys = util.generalTransform(data, indexes)

    # Stack every array in Xs / Ys (presumably one per subject) along axis 0
    # with one call each, instead of concatenating pairwise through reduce(),
    # which re-copies the accumulated rows at every step.
    X = np.concatenate(Xs, axis=0)
    Y = np.concatenate(Ys, axis=0)

    X, Y = util.sampling(X, Y)
    # The training and test features go through util.pca and util.scaler
    # together, each call returning the (train, test) pair.
    pcaX, pcaXt = util.pca(X, Xt)
    stdX, stdXt = util.scaler(pcaX, pcaXt)

    training(stdX, Y, stdXt, filename)
Beispiel #3
0
def generalCVTrain():
    """General training with cross validation.

    Prints the shape and the number of labels equal to 1 for each of the
    four entries in the transformed training data, then uses each entry in
    turn as the validation set: the data is split with
    `util.splitTrainTest`, the training part goes through `util.sampling`,
    both parts go through `util.pca` and `util.scaler`, and the result is
    handed to `sampleTraining`.
    """
    raw, raw_indexes = util.loadGeneralTrainData()
    Xs, Ys = util.generalTransform(raw, raw_indexes)
    subject_ids = [1, 4, 6, 9]

    print("\n--> General Training Data Set Shape")
    for pos, subject in enumerate(subject_ids):
        # To also write ground truth files, call
        # util.saveGroundTruth(Ys[pos], "groundTruth{0}.csv".format(subject))
        x_shape = Xs[pos].shape
        y_shape = Ys[pos].shape
        ones = Ys[pos].tolist().count(1)
        summary = "Subject {0}. X: {1}, Y: {2}, # of 1: {3}".format(
            subject, x_shape, y_shape, ones)
        print(summary)

    for pos, subject in enumerate(subject_ids):
        banner = "\n=========== Subject {0} As Validation Set ===========".format(subject)
        print(banner)

        # presumably holds out entry `pos` as the validation split -- confirm
        # against util.splitTrainTest
        train_x, train_y, valid_x, valid_y = util.splitTrainTest(Xs, Ys, pos)
        train_x, train_y = util.sampling(train_x, train_y)
        reduced_train, reduced_valid = util.pca(train_x, valid_x)
        scaled_train, scaled_valid = util.scaler(reduced_train, reduced_valid)

        sampleTraining(scaled_train, train_y, scaled_valid, valid_y)