Example #1
def main(noeval, **args):

    # args should carry the info needed to specify the params
    # for a given experiment, but only params should be used below
    params = fill_params(**args)

    utils.set_gpus(params["gpus"])

    net = utils.create_network(**params)
    if not noeval:
        net.eval()

    utils.log_tagged_modules(params["modules_used"], params["log_dir"],
                             params["log_tag"], params["chkpt_num"])

    for dset in params["dsets"]:
        print(dset)

        fs = make_forward_scanner(dset, **params)

        output = forward.forward(net,
                                 fs,
                                 params["scan_spec"],
                                 activation=params["activation"])

        save_output(output, dset, **params)
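Example #1 only defines main; below is a minimal driver sketch for invoking it from the command line. The flag and parameter names (expt_name, chkpt_num, gpus) are illustrative assumptions, not part of the original example, and would have to match whatever fill_params actually accepts.

# Hypothetical CLI wrapper for the main() shown above; argument names are assumptions.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run inference for one experiment")
    parser.add_argument("--noeval", action="store_true",
                        help="skip switching the network to eval mode")
    parser.add_argument("--expt_name", required=True)        # assumed experiment identifier
    parser.add_argument("--chkpt_num", type=int, default=0)  # assumed checkpoint number
    parser.add_argument("--gpus", default="0")               # assumed GPU spec
    opts = vars(parser.parse_args())
    main(opts.pop("noeval"), **opts)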
Example #2
def main_fwd(noeval, **args):

    # args should carry the info needed to specify the params
    # for a given experiment, but only params should be used below
    params = fill_params_fwd(**args)

    utils.set_gpus(params["gpus"])

    net = utils.create_network(**params)
    if not noeval:
        net.eval()

    utils.log_tagged_modules(params["modules_used"], params["log_dir"],
                             params["log_tag"], params["chkpt_num"])

    #lightsheet mods - input folder contains list of our "big" patches
    input_fld = os.path.join(params["data_dir"],
                             "input_patches")  #set input directory
    output_fld = os.path.join(params["data_dir"],
                              "cnn_output")  #set output directory

    if not os.path.exists(output_fld):
        os.mkdir(output_fld)
    jobid = 0  #for demo only

    #find files that need to be processed
    fls = [os.path.join(input_fld, xx) for xx in os.listdir(input_fld)]
    fls.sort()

    #select the file to process for this batch job
    if jobid > len(fls) - 1:
        #essentially kill job if too high - doing this to hopefully help with karma score although might not make a difference
        sys.stdout.write("\njobid {} > number of files {}\n".format(
            jobid, len(fls)))
        sys.stdout.flush()
    else:
        dset = fls[jobid]

        start = time.time()

        fs = make_forward_scanner(dset, **params)
        sys.stdout.write("\nstriding by: {}".format(fs.stride))
        sys.stdout.flush()

        output = forward.forward(
            net,
            fs,
            params["scan_spec"],  #runs forward pass
            activation=params["activation"])

        save_output(output, dset, output_fld, jobid, params["output_tag"],
                    params["chkpt_num"])  #saves tif
        fs._init()  #clear out scanner

        sys.stdout.write("\npatch {}: {} min\n".format(
            jobid + 1, round((time.time() - start) / 60, 1)))
        sys.stdout.flush()
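In Example #2 the jobid is hard-coded to 0 for demonstration. In a real batch run each array task would normally derive its own index from the scheduler; a minimal sketch, assuming a SLURM array job (SLURM_ARRAY_TASK_ID is the standard SLURM variable, but its use here is an assumption, not part of the original example):

import os

# Hypothetical: derive the patch index from the SLURM array task ID,
# falling back to 0 when running outside the scheduler.
jobid = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))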
Example #3
def main(noeval, **args):

    # args should carry the info needed to specify the params
    # for a given experiment, but only params should be used below
    params = fill_params(**args)

    utils.set_gpus(params["gpus"])

    net = utils.create_network(**params)
    if not noeval:
        net.eval()

    utils.log_tagged_modules(params["modules_used"], params["log_dir"],
                             params["log_tag"], params["chkpt_num"])

    #lightsheet mods - input folder contains list of our "big" patches
    input_fld = os.path.join(params["data_dir"],
                             "input_chnks")  #set patches directory
    sys.stdout.write("running inference on: \n{}\n".format(
        os.path.basename(params["data_dir"])))
    sys.stdout.flush()
    output_fld = os.path.join(params["data_dir"],
                              "output_chnks")  #set output directory

    jobid = int(params["jobid"])  #set patch no. to run through cnn

    #find files that need to be processed
    fls = [os.path.join(input_fld, xx) for xx in os.listdir(input_fld)]
    fls.sort()

    #select the file to process for this array job
    if jobid > len(fls) - 1:
        sys.stdout.write("\njobid {} > number of files {}".format(
            jobid, len(fls)))
        sys.stdout.flush()
    else:
        start = time.time()
        dset = fls[jobid]

        fs = make_forward_scanner(dset, **params)

        output = forward.forward(
            net,
            fs,
            params["scan_spec"],  #runs forward pass
            activation=params["activation"])

        save_output(output, dset, output_fld, **params)  #saves tif
        fs._init()  #clear out scanner

        sys.stdout.write("patch {}: {} min\n".format(
            jobid + 1, round((time.time() - start) / 60, 1)))
        sys.stdout.flush()
Example #4
def main(**args):

    # args should carry the info needed to specify the params
    # for a given experiment, but only params should be used below
    params = fill_params(**args)

    utils.set_gpus(params["gpus"])

    utils.make_required_dirs(**params)

    utils.log_tagged_modules(params["modules_used"],
                             params["log_dir"], "train",
                             chkpt_num=params["chkpt_num"])

    start_training(**params)
Example #5
def run(args):

    print('\nSettings: \n', args, '\n')

    args.model_signature = str(dt.datetime.now())[0:19].replace(' ', '_')
    args.model_signature = args.model_signature.replace(':', '_')

    ########## Find GPUs
    (gpu_config, n_gpu_used) = set_gpus(args.n_gpu)

    ########## Data, model, and optimizer setup
    mnist = MNIST(args)

    x = tf.placeholder(tf.float32, [None, 28, 28, 1])

    if args.model == 'hvae':
        if not args.K:
            raise ValueError('Must set number of flow steps when using HVAE')
        elif not args.temp_method:
            raise ValueError('Must set tempering method when using HVAE')
        model = HVAE(args, mnist.avg_logit)
    elif args.model == 'cnn':
        model = VAE(args, mnist.avg_logit)
    else:
        raise ValueError('Invalid model choice')

    elbo = model.get_elbo(x, args)
    nll = model.get_nll(x, args)

    optimizer = AdamaxOptimizer(learning_rate=args.learn_rate,
                                eps=args.adamax_eps)
    opt_step = optimizer.minimize(-elbo)

    ########## Tensorflow and saver setup
    sess = tf.Session(config=gpu_config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()
    savepath = os.path.join(args.checkpoint_dir, args.model_signature,
                            'model.ckpt')

    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    ########## Test that GPU memory is sufficient
    if n_gpu_used > 0:
        try:
            x_test = mnist.next_test_batch()
            (t_e, t_n) = sess.run((elbo, nll), {x: x_test})
            mnist.batch_idx_test = 0  # Reset batch counter if it works
        except:
            raise MemoryError("""
                Likely insufficient GPU memory
                Reduce test batch by lowering the -tbs parameter
                """)

    ########## Training Loop

    train_elbo_hist = []
    val_elbo_hist = []

    # For early stopping
    best_elbo = -np.inf
    es_epochs = 0
    epoch = 0

    train_times = []

    for epoch in range(1, args.epochs + 1):

        t0 = time.time()
        train_elbo = train(epoch, mnist, opt_step, elbo, x, args, sess)
        train_elbo_hist.append(train_elbo)
        train_times.append(time.time() - t0)
        print('One epoch took {:.2f} seconds'.format(time.time() - t0))

        val_elbo = validate(mnist, elbo, x, sess)
        val_elbo_hist.append(val_elbo)

        if val_elbo > best_elbo:

            # Save the model that currently generalizes best
            es_epochs = 0
            best_elbo = val_elbo
            saver.save(sess, savepath)
            best_model_epoch = epoch

        elif args.early_stopping_epochs > 0:

            es_epochs += 1

            if es_epochs >= args.early_stopping_epochs:
                print('***** STOPPING EARLY ON EPOCH {} of {} *****'.format(
                    epoch, args.epochs))
                break

        print('--> Early stopping: {}/{} (Best ELBO: {:.4f})'.format(
            es_epochs, args.early_stopping_epochs, best_elbo))
        print('\t Current val ELBO: {:.4f}\n'.format(val_elbo))

        if np.isnan(val_elbo):
            raise ValueError('NaN encountered!')

    train_times = np.array(train_times)
    mean_time = np.mean(train_times)
    std_time = np.std(train_times)
    print('Average train time per epoch: {:.2f} +/- {:.2f}'.format(
        mean_time, std_time))

    ########## Evaluation

    # Restore the best-performing model
    saver.restore(sess, savepath)

    test_elbos = np.zeros(args.n_nll_runs)
    test_nlls = np.zeros(args.n_nll_runs)

    for i in range(args.n_nll_runs):

        print('\n---- Test run {} of {} ----\n'.format(i + 1, args.n_nll_runs))
        (test_elbos[i], test_nlls[i]) = evaluate(mnist, elbo, nll, x, args,
                                                 sess)

    mean_elbo = np.mean(test_elbos)
    std_elbo = np.std(test_elbos)

    mean_nll = np.mean(test_nlls)
    std_nll = np.std(test_nlls)

    print('\nTest ELBO: {:.2f} +/- {:.2f}'.format(mean_elbo, std_elbo))
    print('Test NLL: {:.2f} +/- {:.2f}'.format(mean_nll, std_nll))

    ########## Logging, Saving, and Plotting

    with open(args.logfile, 'a') as ff:
        print('----------------- Test ID {} -----------------'.format(
            args.model_signature),
              file=ff)
        print(args, file=ff)
        print('Stopped after {} epochs'.format(epoch), file=ff)
        print('Best model from epoch {}'.format(best_model_epoch), file=ff)
        print('Average train time per epoch: {:.2f} +/- {:.2f}'.format(
            mean_time, std_time),
              file=ff)

        print('FINAL VALIDATION ELBO: {:.2f}'.format(val_elbo_hist[-1]),
              file=ff)
        print('Test ELBO: {:.2f} +/- {:.2f}'.format(mean_elbo, std_elbo),
              file=ff)
        print('Test NLL: {:.2f} +/- {:.2f}\n'.format(mean_nll, std_nll),
              file=ff)

    if not os.path.exists(args.pickle_dir):
        os.makedirs(args.pickle_dir)

    train_dict = {
        'train_elbo': train_elbo_hist,
        'val_elbo': val_elbo_hist,
        'args': args
    }
    pickle.dump(
        train_dict,
        open(os.path.join(args.pickle_dir, args.model_signature + '.p'), 'wb'))

    if not os.path.exists(args.plot_dir):
        os.makedirs(args.plot_dir)

    tf_gen_samples = model.get_samples(args)
    np_gen_samples = sess.run(tf_gen_samples)
    plot_digit_samples(np_gen_samples, args)

    plot_training_curve(train_elbo_hist, val_elbo_hist, args)

    ########## Email notification upon test completion

    try:

        msg_text = """Test completed for ID {0}.

        Parameters: {1}

        Test ELBO: {2:.2f} +/- {3:.2f}
        Test NLL: {4:.2f} +/- {5:.2f} """.format(args.model_signature, args,
                                                 mean_elbo, std_elbo, mean_nll,
                                                 std_nll)

        msg = MIMEText(msg_text)
        msg['Subject'] = 'Test ID {0} Complete'.format(args.model_signature)
        msg['To'] = args.receiver
        msg['From'] = args.sender

        s = smtplib.SMTP('localhost')
        s.sendmail(args.sender, [args.receiver], msg.as_string())
        s.quit()

    except:

        print('Unable to send email from sender {0} to receiver {1}'.format(
            args.sender, args.receiver))
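run(args) in Example #5 expects an argparse-style namespace. A minimal sketch of a wrapper that builds one follows; the flag names mirror attributes the function reads, but the flags and defaults are illustrative assumptions rather than the project's actual CLI, and several attributes used elsewhere (e.g. batch sizes consumed by MNIST) are omitted.

import argparse

# Hypothetical CLI wrapper for run(args); flags mirror attributes read above,
# defaults are illustrative only.
parser = argparse.ArgumentParser()
parser.add_argument('--model', choices=['hvae', 'cnn'], default='hvae')
parser.add_argument('-K', type=int, help='number of flow steps (required for HVAE)')
parser.add_argument('--temp_method', help='tempering method (required for HVAE)')
parser.add_argument('--n_gpu', type=int, default=1)
parser.add_argument('--learn_rate', type=float, default=1e-3)
parser.add_argument('--adamax_eps', type=float, default=1e-7)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--early_stopping_epochs', type=int, default=10)
parser.add_argument('--n_nll_runs', type=int, default=3)
parser.add_argument('--checkpoint_dir', default='checkpoints')
parser.add_argument('--logfile', default='run.log')
parser.add_argument('--pickle_dir', default='pickles')
parser.add_argument('--plot_dir', default='plots')
parser.add_argument('--sender', default='sender@example.com')
parser.add_argument('--receiver', default='receiver@example.com')

if __name__ == '__main__':
    run(parser.parse_args())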