示例#1
0
def exp_cevae(model="dlvm",
              n=1000,
              d=3,
              p=100,
              prop_miss=0.1,
              citcio=False,
              seed=0,
              d_cevae=20,
              n_epochs=402,
              method="glm",
              **kwargs):
    """Estimate the ATE with CEVAE on imputed simulated data.

    A dataset is simulated with ``gen_lrmf`` or ``gen_dlvm``. Its covariates
    are amputed with ``ampute`` and imputed with sklearn's ``Imputer``
    (default settings). ``cevae_tf`` is then fitted on the imputed covariates.

    Args:
        model: ``"lrmf"`` or ``"dlvm"``; selects the data generator.
        n: number of samples, forwarded to the generator.
        d: latent dimension, forwarded to the generator.
        p: ambient (covariate) dimension, forwarded to the generator.
        prop_miss: proportion of missing values, forwarded to the generator
            and to ``ampute``.
        citcio: forwarded to the generator.
        seed: random seed forwarded to the generator and to ``ampute``.
        d_cevae: passed to ``cevae_tf`` as ``d_cevae``.
        n_epochs: passed to ``cevae_tf`` as ``n_epochs``.
        method: accepted for signature compatibility with the other
            ``exp_*`` helpers; not used here.
        **kwargs: ignored.

    Returns:
        ``np.mean(y1_hat - y0_hat)``, where ``y0_hat`` and ``y1_hat`` are the
        two outputs of ``cevae_tf``.

    Raises:
        NotImplementedError: if ``model`` is neither ``"lrmf"`` nor ``"dlvm"``.
    """
    # Imported here because this experiment runs under a different sklearn
    # version than the rest of the module.
    # NOTE(review): ``sklearn.preprocessing.Imputer`` was removed in
    # scikit-learn 0.22; ``sklearn.impute.SimpleImputer`` replaces it.
    from cevae_tf import cevae_tf
    from sklearn.preprocessing import Imputer

    # Only the covariates, treatment and outcome are needed here; the latent
    # Z and the true propensity are discarded.
    if model == "lrmf":
        _, X, w, y, _ = gen_lrmf(n=n,
                                 d=d,
                                 p=p,
                                 citcio=citcio,
                                 prop_miss=prop_miss,
                                 seed=seed)
    elif model == "dlvm":
        _, X, w, y, _ = gen_dlvm(n=n,
                                 d=d,
                                 p=p,
                                 citcio=citcio,
                                 prop_miss=prop_miss,
                                 seed=seed)
    else:
        raise NotImplementedError(
            "Other data generating models not implemented here yet.")

    X_miss = ampute(X, prop_miss=prop_miss, seed=seed)
    X_imp = Imputer().fit_transform(X_miss)

    y0_hat, y1_hat = cevae_tf(X_imp, w, y, d_cevae=d_cevae, n_epochs=n_epochs)

    # ATE as the average difference between the two CEVAE outcome predictions.
    return np.mean(y1_hat - y0_hat)
示例#2
0
def exp_mi(model="dlvm",
           n=1000,
           d=3,
           p=100,
           prop_miss=0.1,
           citcio=False,
           seed=0,
           m=10,
           d_cevae=20,
           n_epochs=402,
           method="glm",
           **kwargs):
    """Estimate the ATE by multiple imputation on simulated incomplete data.

    Data are drawn from the chosen generator, the covariates are amputed
    with ``ampute``, and ``tau_mi`` is run on the amputed covariates.

    Args:
        model: ``"lrmf"`` (uses ``gen_lrmf``) or ``"dlvm"`` (uses ``gen_dlvm``).
        n, d, p: sample size, latent and ambient dimension for the generator.
        prop_miss: proportion of missing values, given to the generator and
            to ``ampute``.
        citcio: forwarded to the generator.
        seed: seed given to the generator and to ``ampute``.
        m: number of imputations, forwarded to ``tau_mi``.
        d_cevae, n_epochs: accepted for signature compatibility; unused.
        method: forwarded to ``tau_mi``.
        **kwargs: ignored.

    Returns:
        Tuple ``(tau_dr_mi, tau_ols_mi, tau_ols_ps_mi, tau_resid_mi)`` taken
        from the result of ``tau_mi``.

    Raises:
        NotImplementedError: for any ``model`` other than ``"lrmf"``/``"dlvm"``.
    """
    if model not in ("lrmf", "dlvm"):
        raise NotImplementedError(
            "Other data generating models not implemented here yet.")

    generate = gen_lrmf if model == "lrmf" else gen_dlvm
    _, covariates, treatment, outcome, _ = generate(n=n,
                                                    d=d,
                                                    p=p,
                                                    citcio=citcio,
                                                    prop_miss=prop_miss,
                                                    seed=seed)

    covariates_miss = ampute(covariates, prop_miss=prop_miss, seed=seed)

    dr_est, ols_est, ols_ps_est, resid_est = tau_mi(covariates_miss,
                                                    treatment,
                                                    outcome,
                                                    m=m,
                                                    method=method)
    return dr_est, ols_est, ols_ps_est, resid_est
示例#3
0
def main(unused_argv):
    """Run the benchmark grid and append one CSV row per method/setting.

    For every combination of the data-generating parameters below (each can
    be pinned through the matching FLAG) and every latent dimension ``d``
    derived for the current ambient dimension ``p``, this:

      1. simulates ``(Z, X, w, y, ps, mu0, mu1)`` with ``gen_lrmf`` or
         ``gen_dlvm`` and amputes ``X`` with ``ampute``;
      2. evaluates the complete-data baselines (``Z``, ``X``), mean
         imputation, multiple imputation, matrix factorisation and
         MissDeepCausal (``MDC.process`` / ``MDC.mi``);
      3. writes the rows with ``log_res`` to the output CSV.

    Settings already present in the output CSV are skipped.

    Args:
        unused_argv: command-line arguments; ignored.

    Raises:
        NotImplementedError: if a requested model is neither ``"lrmf"`` nor
            ``"dlvm"``.
    """

    def _run_key(run_args):
        """Return the key identifying a setting in ``previous_runs``.

        Uses exactly the fields, order and Python types produced when the
        existing CSV is parsed below. This means extra entries of
        ``run_args`` (``nuisance``, ``time``, ``training_time``) and numpy
        scalar reprs cannot prevent a match.
        """
        return str({
            'model': run_args['model'],
            'citcio': run_args['citcio'],
            'n': int(run_args['n']),
            'p': int(run_args['p']),
            'y_snr': float(run_args['y_snr']),
            'x_snr': float(run_args['x_snr']),
            'mu_z': float(run_args['mu_z']),
            'sig_z': float(run_args['sig_z']),
            'prop_miss': float(run_args['prop_miss']),
            'regularize': run_args['regularize'],
            'seed': int(run_args['seed']),
            'd': int(run_args['d']),
            'sig_xgivenz': float(run_args['sig_xgivenz'])
        })

    def _make_row(head, run_args, tau_part, nu_part, truth, extra=None):
        """Assemble one output row.

        The row is filled in this order: ``head`` (method name and
        method-specific fields), the current settings, optional ``extra``
        parameters, and the estimates ``tau_part``. When ``nu_part`` is not
        None, the MSE of each estimated nuisance against ``truth`` is added
        last; ``truth`` is ``(ps, mu0, mu1)``.
        """
        row = dict(head)
        row.update(run_args)
        if extra is not None:
            row.update(extra)
        row.update(tau_part)
        if nu_part is not None:
            ps_true, mu0_true, mu1_true = truth
            row.update(
                {'ps_hat_mse': mean_squared_error(ps_true, nu_part['ps_hat'])})
            row.update(
                {'y0_hat_mse': mean_squared_error(mu0_true, nu_part['y0_hat'])})
            row.update(
                {'y1_hat_mse': mean_squared_error(mu1_true, nu_part['y1_hat'])})
        return row

    # Data generating process parameters
    exp_parameter_grid = {
        'model': ["dlvm", "lrmf"] if FLAGS.model is None else [FLAGS.model],
        'citcio': [
            False,
        ],
        'nuisance': [
            True,
        ],
        'n': [500, 1000, 5000, 10000]
        if FLAGS.n_observations is None else [FLAGS.n_observations],
        'p':
        [5, 10, 50, 100] if FLAGS.p_ambient is None else [FLAGS.p_ambient],
        'y_snr': [5.] if FLAGS.y_snr is None else [FLAGS.y_snr],
        'x_snr': [2.] if FLAGS.x_snr is None else [FLAGS.x_snr],
        'mu_z': [0.] if FLAGS.mu_z is None else [FLAGS.mu_z],
        'sig_z': [1.] if FLAGS.sig_z is None else [FLAGS.sig_z],
        'sig_xgivenz':
        [0.001] if FLAGS.sig_xgivenz is None else [FLAGS.sig_xgivenz],
        'prop_miss':
        [0.0, 0.1, 0.3, 0.5] if FLAGS.prop_miss is None else [FLAGS.prop_miss],
        'regularize':
        [False] if FLAGS.regularize is None else [FLAGS.regularize],
        'seed':
        np.arange(FLAGS.n_seeds),
    }

    # Latent dimensions: either fixed through --d_latent (a scalar or a
    # list), or derived per value of p from the d/p ratios below.
    if FLAGS.d_latent is None:
        fixed_range_d = None
    elif isinstance(FLAGS.d_latent, (list, tuple)):
        fixed_range_d = list(FLAGS.d_latent)
    else:
        fixed_range_d = [FLAGS.d_latent]
    range_d_over_p = [0.002, 0.01, 0.1
                      ] if FLAGS.d_over_p is None else [FLAGS.d_over_p]

    # MDC parameters
    range_d_offset = [0, 5] if FLAGS.miwae_d_offset is None else [
        FLAGS.miwae_d_offset
    ]

    mdc_parameter_grid = {
        'mu_prior':
        [0.] if FLAGS.miwae_mu_prior is None else [FLAGS.miwae_mu_prior],
        'sig_prior':
        [1.] if FLAGS.miwae_sig_prior is None else [FLAGS.miwae_sig_prior],
        'num_samples_zmul': [500] if FLAGS.miwae_n_samples_zmul is None else
        [FLAGS.miwae_n_samples_zmul],
        'learning_rate': [
            0.0001,
        ]
        if FLAGS.miwae_learning_rate is None else [FLAGS.miwae_learning_rate],
        'n_epochs': [
            5000,
        ] if FLAGS.miwae_n_epochs is None else [FLAGS.miwae_n_epochs],
    }

    # MI parameters
    range_m = [
        10,
    ] if FLAGS.n_imputations is None else [FLAGS.n_imputations]

    # Experiment and output file name
    output = f'results/{FLAGS.exp_name}.csv' if FLAGS.output is None else FLAGS.output

    FLAGS.log_dir = './sessions/logging/' if FLAGS.log_path is None else FLAGS.log_path
    logging.get_absl_handler().use_absl_log_file()

    logging.info('*' * 20)
    logging.info(f'Starting exp: {FLAGS.exp_name}')
    logging.info('*' * 20)

    exp_arguments = [
        dict(zip(exp_parameter_grid.keys(), vals))
        for vals in itertools.product(*exp_parameter_grid.values())
    ]

    previous_runs = set()
    if tf.io.gfile.exists(output):
        with tf.io.gfile.GFile(output, mode='r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                # Note: we need to do this conversion because DictReader creates an
                # OrderedDict, and reads all values as str instead of bool or int.
                previous_runs.add(
                    str({
                        'model': row['model'],
                        'citcio': row['citcio'] == 'True',
                        'n': int(row['n']),
                        'p': int(row['p']),
                        'y_snr': float(row['y_snr']),
                        'x_snr': float(row['x_snr']),
                        'mu_z': float(row['mu_z']),
                        'sig_z': float(row['sig_z']),
                        'prop_miss': float(row['prop_miss']),
                        'regularize': row['regularize'] == 'True',
                        'seed': int(row['seed']),
                        'd': int(row['d']),
                        'sig_xgivenz': float(row['sig_xgivenz'])
                    }))
    logging.info('Previous runs')
    logging.info(previous_runs)

    for args in exp_arguments:
        # For the current p, create a range for d such that 1 < d < p,
        # starting from the given d/p ratios. This is recomputed for every
        # setting because the grid may contain several values of p.
        if fixed_range_d is not None:
            range_d = fixed_range_d
        else:
            candidates = np.array([
                np.maximum(2, int(np.floor(args['p'] * x)))
                for x in range_d_over_p
            ])
            range_d = np.unique(candidates[candidates < args['p']].tolist())

        exp_time = time.time()
        for d in range_d:
            args['d'] = d
            # We only consider cases where latent dimension <= ambient dimension
            if args['d'] > args['p']:
                continue
            res = []

            if _run_key(args) in previous_runs:
                logging.info(f'Skipped {args}')
                continue
            logging.info(f'running exp with {args}')

            if args['model'] == "lrmf":
                Z, X, w, y, ps, mu0, mu1 = gen_lrmf(
                    n=args['n'],
                    d=args['d'],
                    p=args['p'],
                    y_snr=args['y_snr'],
                    x_snr=args['x_snr'],
                    citcio=args['citcio'],
                    prop_miss=args['prop_miss'],
                    seed=args['seed'],
                    sig_xgivenz=args['sig_xgivenz'])
            elif args['model'] == "dlvm":
                Z, X, w, y, ps, mu0, mu1 = gen_dlvm(
                    n=args['n'],
                    d=args['d'],
                    p=args['p'],
                    y_snr=args['y_snr'],
                    citcio=args['citcio'],
                    prop_miss=args['prop_miss'],
                    seed=args['seed'],
                    mu_z=args['mu_z'],
                    sig_z=args['sig_z'],
                    x_snr=args['x_snr'],
                    sig_xgivenz=args['sig_xgivenz'])
            else:
                # Without this, data from the previous setting would be reused.
                raise NotImplementedError(
                    f"Data generating model {args['model']!r} not implemented.")

            X_miss = ampute(X, prop_miss=args['prop_miss'], seed=args['seed'])
            nuisance = args['nuisance']
            truth = (ps, mu0, mu1)

            # On complete data
            t0 = time.time()
            if nuisance:
                tau, nu = exp_complete(Z, X, w, y, args['regularize'],
                                       nuisance)
            else:
                tau = exp_complete(Z, X, w, y, args['regularize'], nuisance)
            args['time'] = int(time.time() - t0)
            res.append(
                _make_row({'Method': 'Z'}, args, tau['Z'],
                          nu['Z'] if nuisance else None, truth))
            print(tau['Z'])
            res.append(
                _make_row({'Method': 'X'}, args, tau['X'],
                          nu['X'] if nuisance else None, truth))

            # Mean-imputation
            t0 = time.time()
            if nuisance:
                tau, nu = exp_mean(X_miss, w, y, args['regularize'], nuisance)
            else:
                tau = exp_mean(X_miss, w, y, args['regularize'])
            args['time'] = int(time.time() - t0)
            res.append(
                _make_row({'Method': 'Mean_imp'}, args, tau,
                          nu if nuisance else None, truth))

            # Multiple imputation
            for m in range_m:
                t0 = time.time()
                if nuisance:
                    tau, nu = exp_mi(X_miss,
                                     w,
                                     y,
                                     regularize=args['regularize'],
                                     m=m,
                                     nuisance=nuisance)
                else:
                    tau = exp_mi(X_miss,
                                 w,
                                 y,
                                 regularize=args['regularize'],
                                 m=m)
                args['time'] = int(time.time() - t0)
                res.append(
                    _make_row({'Method': 'MI', 'm': m}, args, tau,
                              nu if nuisance else None, truth))

            # Matrix Factorization
            t0 = time.time()
            if nuisance:
                tau, nu, r, zhat = exp_mf(X_miss,
                                          w,
                                          y,
                                          args['regularize'],
                                          nuisance,
                                          return_zhat=True)
            else:
                tau, r = exp_mf(X_miss, w, y, args['regularize'])
            args['time'] = int(time.time() - t0)
            res.append(
                _make_row({'Method': 'MF', 'r': r}, args, tau,
                          nu if nuisance else None, truth))

            # MissDeepCausal
            mdc_parameter_grid['d_miwae'] = [
                args['d'] + x for x in range_d_offset
            ]

            mdc_arguments = [
                dict(zip(mdc_parameter_grid.keys(), vals))
                for vals in itertools.product(*mdc_parameter_grid.values())
            ]

            for mdc_arg in mdc_arguments:
                t0 = time.time()
                # The MIWAE prior mean is tied to the simulated latent mean.
                mdc_arg['mu_prior'] = args['mu_z']
                # NOTE: the doubled underscore before "_sigXgivenZ" is kept so
                # that existing session files are still found.
                session_file = ('./sessions/' + args['model'] + '_' +
                                '_sigXgivenZ' + str(args['sig_xgivenz']) +
                                '_n' + str(args['n']) +
                                '_p' + str(args['p']) +
                                '_d' + str(args['d']) +
                                '_ysnr' + str(args['y_snr']) +
                                '_xsnr' + str(args['x_snr']) +
                                '_propNA' + str(args['prop_miss']) +
                                '_seed' + str(args['seed']))
                session_file_complete = (session_file +
                                         '_dmiwae' + str(mdc_arg['d_miwae']) +
                                         '_sigprior' + str(mdc_arg['sig_prior']))
                mdc_kwargs = dict(
                    d_miwae=mdc_arg['d_miwae'],
                    mu_prior=mdc_arg['mu_prior'],
                    sig_prior=mdc_arg['sig_prior'],
                    num_samples_zmul=mdc_arg['num_samples_zmul'],
                    learning_rate=mdc_arg['learning_rate'],
                    n_epochs=mdc_arg['n_epochs'],
                    regularize=args['regularize'],
                    return_zhat=True,
                    save_session=True,
                    session_file=session_file,
                    session_file_complete=session_file_complete)
                if nuisance:
                    tau, nu, elbo, zhat, zhat_mul = exp_mdc(
                        X_miss, w, y, nuisance=nuisance, **mdc_kwargs)
                else:
                    tau, elbo, zhat, zhat_mul = exp_mdc(
                        X_miss, w, y, **mdc_kwargs)

                args['training_time'] = int(time.time() - t0)
                for mdc_method in ('MDC.process', 'MDC.mi'):
                    res.append(
                        _make_row({'Method': mdc_method, 'elbo': elbo},
                                  args,
                                  tau[mdc_method],
                                  nu[mdc_method] if nuisance else None,
                                  truth,
                                  extra=mdc_arg))

            log_res(output, res, ['Method'] + list(args.keys()) +
                    l_method_params + l_tau + l_nu)
            logging.info('........... DONE')
            logging.info(f'in {time.time() - exp_time} s \n\n')

    logging.info('*' * 20)
    logging.info(f'Exp: {FLAGS.exp_name} succesfully ended.')
    logging.info('*' * 20)
示例#4
0
def main(unused_argv):
    """Train MIWAE over the simulation grid and cache its outputs to disk.

    For every combination of the data-generating parameters below (each can
    be pinned through the matching FLAG) and every latent dimension ``d``
    derived for the current ``p``:

      * data are simulated with ``gen_lrmf``/``gen_dlvm`` and amputed with
        ``ampute``;
      * for every MIWAE hyper-parameter combination, ``miwae_es`` is trained,
        unless a file matching ``<session_file_complete>.*`` already exists;
      * ``[xhat, zhat, zhat_mul, elbo, epochs]`` is pickled to
        ``<session_file_complete>.pkl``.

    Settings already listed in the output CSV are skipped entirely.

    Args:
        unused_argv: command-line arguments; ignored.

    Raises:
        NotImplementedError: if a requested model is neither ``"lrmf"`` nor
            ``"dlvm"``.
    """

    def _run_key(run_args):
        """Return the key identifying a setting in ``previous_runs``.

        Uses exactly the fields, order and Python types produced when the
        existing CSV is parsed below, so dict ordering and numpy scalar
        reprs cannot prevent a match.
        """
        return str({
            'model': run_args['model'],
            'citcio': run_args['citcio'],
            'n': int(run_args['n']),
            'p': int(run_args['p']),
            'y_snr': float(run_args['y_snr']),
            'x_snr': float(run_args['x_snr']),
            'mu_z': float(run_args['mu_z']),
            'sig_z': float(run_args['sig_z']),
            'prop_miss': float(run_args['prop_miss']),
            'regularize': run_args['regularize'],
            'seed': int(run_args['seed']),
            'd': int(run_args['d']),
            'sig_xgivenz': str(run_args['sig_xgivenz'])
        })

    # Data generating process parameters
    exp_parameter_grid = {
        'model': ["dlvm", "lrmf"] if FLAGS.model is None else [FLAGS.model],
        'citcio': [False, ],
        'n': [1000, 10000, 100000] if FLAGS.n_observations is None else [FLAGS.n_observations],
        'p': [10, 100, 1000] if FLAGS.p_ambient is None else [FLAGS.p_ambient],
        'y_snr': [5.] if FLAGS.y_snr is None else [FLAGS.y_snr],
        'x_snr': 1. * np.arange(2, 20, 4) if FLAGS.x_snr is None else [FLAGS.x_snr],
        'mu_z': [0.] if FLAGS.mu_z is None else [FLAGS.mu_z],
        'sig_z': [1.] if FLAGS.sig_z is None else [FLAGS.sig_z],
        'sig_xgivenz': ["fixed", ] if FLAGS.sig_xgivenz is None else [FLAGS.sig_xgivenz],
        'prop_miss': [0.0, 0.1, 0.3, 0.5] if FLAGS.prop_miss is None else [FLAGS.prop_miss],
        'regularize': [False] if FLAGS.regularize is None else [FLAGS.regularize],
        'seed': np.arange(FLAGS.n_seeds),
    }

    # Latent dimensions: either fixed through --d_latent (a scalar or a
    # list), or derived per value of p from the d/p ratios below.
    if FLAGS.d_latent is None:
        fixed_range_d = None
    elif isinstance(FLAGS.d_latent, (list, tuple)):
        fixed_range_d = list(FLAGS.d_latent)
    else:
        fixed_range_d = [FLAGS.d_latent]
    range_d_over_p = [0.002, 0.01, 0.1] if FLAGS.d_over_p is None else [FLAGS.d_over_p]

    # MDC parameters
    range_d_offset = [0, 5, 10] if FLAGS.miwae_d_offset is None else [FLAGS.miwae_d_offset]

    mdc_parameter_grid = {
        'mu_prior': [0.] if FLAGS.miwae_mu_prior is None else [FLAGS.miwae_mu_prior],
        'sig_prior': [1.] if FLAGS.miwae_sig_prior is None else [FLAGS.miwae_sig_prior],
        'num_samples_zmul': [500] if FLAGS.miwae_n_samples_zmul is None else [FLAGS.miwae_n_samples_zmul],
        'learning_rate': [0.0001, ] if FLAGS.miwae_learning_rate is None else [FLAGS.miwae_learning_rate],
        'n_epochs': [5000, ] if FLAGS.miwae_n_epochs is None else [FLAGS.miwae_n_epochs],
    }

    # Experiment and output file name
    output = f'results/{FLAGS.exp_name}.csv' if FLAGS.output is None else FLAGS.output

    FLAGS.log_dir = './sessions/logging/' if FLAGS.log_path is None else FLAGS.log_path
    logging.get_absl_handler().use_absl_log_file()
    logging.info('*' * 20)
    logging.info(f'Starting exp: {FLAGS.exp_name}')
    logging.info('*' * 20)

    exp_arguments = [dict(zip(exp_parameter_grid.keys(), vals))
                     for vals in itertools.product(*exp_parameter_grid.values())]

    previous_runs = set()
    if tf.io.gfile.exists(output):
        with tf.io.gfile.GFile(output, mode='r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                # Note: we need to do this conversion because DictReader creates an
                # OrderedDict, and reads all values as str instead of bool or int.
                previous_runs.add(str({
                    'model': row['model'],
                    'citcio': row['citcio'] == 'True',
                    'n': int(row['n']),
                    'p': int(row['p']),
                    'y_snr': float(row['y_snr']),
                    'x_snr': float(row['x_snr']),
                    'mu_z': float(row['mu_z']),
                    'sig_z': float(row['sig_z']),
                    'prop_miss': float(row['prop_miss']),
                    'regularize': row['regularize'] == 'True',
                    'seed': int(row['seed']),
                    'd': int(row['d']),
                    'sig_xgivenz': row['sig_xgivenz']
                }))
    logging.info('Previous runs')
    logging.info(previous_runs)

    for args in exp_arguments:
        # For the current p, create a range for d such that 1 < d < p,
        # starting from the given d/p ratios. This is recomputed for every
        # setting because the grid may contain several values of p.
        if fixed_range_d is not None:
            range_d = fixed_range_d
        else:
            candidates = np.array([np.maximum(2, int(np.floor(args['p'] * x)))
                                   for x in range_d_over_p])
            range_d = np.unique(candidates[candidates < args['p']].tolist())

        exp_time = time.time()
        for d in range_d:
            args['d'] = d
            # We only consider cases where latent dimension <= ambient dimension
            if args['d'] > args['p']:
                continue

            if _run_key(args) in previous_runs:
                logging.info(f'Skipped {args}')
                continue
            logging.info(f'running exp with {args}')

            if args['model'] == "lrmf":
                Z, X, w, y, ps, mu0, mu1 = gen_lrmf(n=args['n'], d=args['d'], p=args['p'],
                                                    y_snr=args['y_snr'], x_snr=args['x_snr'],
                                                    citcio=args['citcio'],
                                                    prop_miss=args['prop_miss'],
                                                    seed=args['seed'])
            elif args['model'] == "dlvm":
                Z, X, w, y, ps, mu0, mu1 = gen_dlvm(n=args['n'], d=args['d'], p=args['p'],
                                                    y_snr=args['y_snr'], citcio=args['citcio'],
                                                    prop_miss=args['prop_miss'],
                                                    seed=args['seed'],
                                                    mu_z=args['mu_z'],
                                                    sig_z=args['sig_z'],
                                                    x_snr=args['x_snr'],
                                                    sig_xgivenz=args['sig_xgivenz'])
            else:
                # Without this, data from the previous setting would be reused.
                raise NotImplementedError(
                    f"Data generating model {args['model']!r} not implemented.")

            X_miss = ampute(X, prop_miss=args['prop_miss'], seed=args['seed'])

            # MIWAE
            mdc_parameter_grid['d_miwae'] = [args['d'] + x for x in range_d_offset]

            mdc_arguments = [dict(zip(mdc_parameter_grid.keys(), vals))
                             for vals in itertools.product(*mdc_parameter_grid.values())]

            for mdc_arg in mdc_arguments:
                # The MIWAE prior mean is tied to the simulated latent mean.
                mdc_arg['mu_prior'] = args['mu_z']
                # str() so that a numeric --sig_xgivenz does not break the
                # concatenation; string values produce the same path as before.
                session_file = ('./sessions/' + args['model'] + '_' +
                                str(args['sig_xgivenz']) + 'Sigma' +
                                '_n' + str(args['n']) +
                                '_p' + str(args['p']) +
                                '_d' + str(args['d']) +
                                '_ysnr' + str(args['y_snr']) +
                                '_xsnr' + str(args['x_snr']) +
                                '_propNA' + str(args['prop_miss']) +
                                '_seed' + str(args['seed']))
                session_file_complete = (session_file +
                                         '_dmiwae' + str(mdc_arg['d_miwae']) +
                                         '_sigprior' + str(mdc_arg['sig_prior']))

                # Skip hyper-parameters whose outputs are already on disk.
                if glob.glob(session_file_complete + '.*'):
                    continue

                # Start each training from an empty default graph. No Session
                # is opened here: miwae_es is given everything it needs.
                tf.reset_default_graph()
                xhat, zhat, zhat_mul, elbo, epochs = miwae_es(
                    X_miss,
                    d_miwae=mdc_arg['d_miwae'],
                    mu_prior=mdc_arg['mu_prior'],
                    sig_prior=mdc_arg['sig_prior'],
                    num_samples_zmul=mdc_arg['num_samples_zmul'],
                    l_rate=mdc_arg['learning_rate'],
                    n_epochs=mdc_arg['n_epochs'],
                    save_session=True,
                    session_file=session_file)
                with open(session_file_complete + '.pkl', 'wb') as file_data:
                    pickle.dump([xhat, zhat, zhat_mul, elbo, epochs], file_data)

            logging.info('........... DONE')
            logging.info(f'in {time.time() - exp_time} s \n\n')

    logging.info('*' * 20)
    logging.info(f'Exp: {FLAGS.exp_name} successfully ended.')
    logging.info('*' * 20)
# Script fragment: sweeps the simulation grid and collects one flat record
# per method into `l_scores`.
# NOTE(review): `exp_name`, `args`, `range_model`, `range_citcio`, `range_n`,
# `range_p`, `range_d_over_p`, `range_prop_miss` and `range_seed` are defined
# outside the visible code; `output` is not used in the visible part.
# Names of the four ATE estimates.
l_tau = ['tau_dr', 'tau_ols', 'tau_ols_ps', 'tau_resid']
output = '../results/'+exp_name+'.csv'
l_scores = []

# Each `for args[key] in ...` writes the current grid value into the shared
# `args` dict in place, so `args` always holds the current setting.
for args['model'] in range_model:
    for args['citcio'] in range_citcio:
        for args['n'] in range_n:
            for args['p'] in range_p:
                # Latent dimensions derived from the d/p ratios for this p.
                # NOTE(review): unlike the other drivers there is no
                # np.maximum(2, ...) clamp, so a small p * ratio gives d == 0 —
                # confirm this is intended.
                range_d = [int(np.floor(args['p']*x)) for x in range_d_over_p]
                for args['d'] in range_d:
                    for args['prop_miss'] in range_prop_miss:
                        for args['seed'] in range_seed:
                            print(args)
                            # NOTE(review): no `else` branch — an unknown model
                            # leaves Z/X/w/y/ps from the previous iteration (or
                            # undefined on the first one).
                            if args['model'] == "lrmf":
                                Z, X, w, y, ps = gen_lrmf(n=args['n'], d=args['d'], p=args['p'], 
                                                          citcio = args['citcio'], prop_miss = args['prop_miss'], 
                                                          seed = args['seed'])
                            elif args['model'] == "dlvm":
                                Z, X, w, y, ps = gen_dlvm(n=args['n'], d=args['d'], p=args['p'], 
                                                          citcio = args['citcio'], prop_miss = args['prop_miss'], 
                                                          seed = args['seed'])
                            
                            X_miss = ampute(X, prop_miss = args['prop_miss'], seed = args['seed'])

                            # Complete-data estimators; results are keyed 'Z' and 'X'.
                            t0 = time.time()
                            tau = exp_complete(Z, X, w, y)
                            args['time'] = int(time.time() - t0)
                            # Record = method label, current values of `args`
                            # (including 'time'), 7 None entries, then the
                            # estimates. NOTE(review): the 7 None entries
                            # presumably pad method-specific parameter columns —
                            # confirm; mixing str/None in np.concatenate yields a
                            # non-numeric array, so confirm downstream expects that.
                            l_scores.append(np.concatenate((['Z'], list(args.values()), [None]*7, tau['Z'])))
                            l_scores.append(np.concatenate((['X'], list(args.values()), [None]*7, tau['X'])))
                            
示例#6
0
def main(unused_argv):
    """Run the benchmark grid and append one CSV row per method/setting.

    For every combination of the data-generating parameters below (each can
    be pinned through the matching FLAG) and every latent dimension ``d``
    derived for the current ``p``, this simulates data with
    ``gen_lrmf``/``gen_dlvm``, amputes it, and evaluates:

      * the complete-data baselines (``Z``, ``X``);
      * mean imputation;
      * multiple imputation;
      * matrix factorisation;
      * MissDeepCausal (``MDC.process`` / ``MDC.mi``).

    Rows are written with ``log_res``. Settings already present in the
    output CSV are skipped.

    Args:
        unused_argv: command-line arguments; ignored.

    Raises:
        NotImplementedError: if a requested model is neither ``"lrmf"`` nor
            ``"dlvm"``.
    """

    def _run_key(run_args):
        """Return the key identifying a setting in ``previous_runs``.

        Uses exactly the fields, order and Python types produced when the
        existing CSV is parsed below. This means the ``'time'`` entry added
        to ``run_args`` during a run, and numpy scalar reprs of ``seed``/``d``,
        cannot prevent a match.
        """
        return str({
            'model': run_args['model'],
            'citcio': run_args['citcio'],
            'n': int(run_args['n']),
            'p': int(run_args['p']),
            'snr': float(run_args['snr']),
            'prop_miss': float(run_args['prop_miss']),
            'regularize': run_args['regularize'],
            'seed': int(run_args['seed']),
            'd': int(run_args['d']),
        })

    # Data generating process parameters
    exp_parameter_grid = {
        'model': ["dlvm", "lrmf"] if FLAGS.model is None else [FLAGS.model],
        'citcio': [
            False,
        ],
        'n': [1000, 10000, 100000]
        if FLAGS.n_observations is None else [FLAGS.n_observations],
        'p': [10, 100, 1000] if FLAGS.p_ambient is None else [FLAGS.p_ambient],
        'snr': [1., 5., 10.] if FLAGS.snr is None else [FLAGS.snr],
        'prop_miss': [0.0, 0.1, 0.3, 0.5, 0.7, 0.9]
        if FLAGS.prop_miss is None else [FLAGS.prop_miss],
        'regularize':
        [False, True] if FLAGS.regularize is None else [FLAGS.regularize],
        'seed':
        np.arange(FLAGS.n_seeds),
    }
    range_d_over_p = [0.002, 0.01, 0.1
                      ] if FLAGS.d_over_p is None else [FLAGS.d_over_p]

    # MDC parameters
    range_d_offset = [0, 5, 10] if FLAGS.miwae_d_offset is None else [
        FLAGS.miwae_d_offset
    ]

    mdc_parameter_grid = {
        'sig_prior': [0.1, 1, 10]
        if FLAGS.miwae_sig_prior is None else [FLAGS.miwae_sig_prior],
        'num_samples_zmul': [50, 500] if FLAGS.miwae_n_samples_zmul is None
        else [FLAGS.miwae_n_samples_zmul],
        'learning_rate': [
            0.0001,
        ]
        if FLAGS.miwae_learning_rate is None else [FLAGS.miwae_learning_rate],
        'n_epochs': [
            500,
        ] if FLAGS.miwae_n_epochs is None else [FLAGS.miwae_n_epochs],
    }

    # MI parameters
    range_m = [10, 20, 50
               ] if FLAGS.n_imputations is None else [FLAGS.n_imputations]

    # Experiment and output file name
    output = f'results/{FLAGS.exp_name}.csv' if FLAGS.output is None else FLAGS.output

    logging.info('*' * 20)
    logging.info(f'Starting exp: {FLAGS.exp_name}')
    logging.info('*' * 20)

    exp_arguments = [
        dict(zip(exp_parameter_grid.keys(), vals))
        for vals in itertools.product(*exp_parameter_grid.values())
    ]

    previous_runs = set()
    if tf.io.gfile.exists(output):
        with tf.io.gfile.GFile(output, mode='r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                # Note: we need to do this conversion because DictReader creates an
                # OrderedDict, and reads all values as str instead of bool or int.
                previous_runs.add(
                    str({
                        'model': row['model'],
                        'citcio': row['citcio'] == 'True',
                        'n': int(row['n']),
                        'p': int(row['p']),
                        'snr': float(row['snr']),
                        'prop_miss': float(row['prop_miss']),
                        'regularize': row['regularize'] == 'True',
                        'seed': int(row['seed']),
                        'd': int(row['d']),
                    }))
    logging.info('Previous runs')
    logging.info(previous_runs)

    for args in exp_arguments:
        # For given p, create range for d such that 1 < d < p
        # starting with given ratios for d/p
        candidates = np.array([
            np.maximum(2, int(np.floor(args['p'] * x))) for x in range_d_over_p
        ])
        range_d = np.unique(candidates[candidates < args['p']].tolist())
        exp_time = time.time()
        for d in range_d:
            args['d'] = d
            res = []

            if _run_key(args) in previous_runs:
                logging.info(f'Skipped {args}')
                continue
            logging.info(f'running exp with {args}')

            if args['model'] == "lrmf":
                Z, X, w, y, ps = gen_lrmf(n=args['n'],
                                          d=args['d'],
                                          p=args['p'],
                                          y_snr=args['snr'],
                                          citcio=args['citcio'],
                                          prop_miss=args['prop_miss'],
                                          seed=args['seed'])
            elif args['model'] == "dlvm":
                Z, X, w, y, ps = gen_dlvm(n=args['n'],
                                          d=args['d'],
                                          p=args['p'],
                                          y_snr=args['snr'],
                                          citcio=args['citcio'],
                                          prop_miss=args['prop_miss'],
                                          seed=args['seed'])
            else:
                # Without this, data from the previous setting would be reused.
                raise NotImplementedError(
                    f"Data generating model {args['model']!r} not implemented.")

            X_miss = ampute(X, prop_miss=args['prop_miss'], seed=args['seed'])

            # On complete data
            t0 = time.time()
            tau = exp_complete(Z, X, w, y, args['regularize'])
            args['time'] = int(time.time() - t0)
            for method_name in ('Z', 'X'):
                row = {'Method': method_name}
                row.update(args)
                row.update(tau[method_name])
                res.append(row)

            # Mean-imputation
            t0 = time.time()
            tau = exp_mean(X_miss, w, y, args['regularize'])
            args['time'] = int(time.time() - t0)
            row = {'Method': 'Mean_imp'}
            row.update(args)
            row.update(tau)
            res.append(row)

            # Multiple imputation
            for m in range_m:
                t0 = time.time()
                tau = exp_mi(X_miss, w, y, regularize=args['regularize'], m=m)
                args['time'] = int(time.time() - t0)
                row = {'Method': 'MI', 'm': m}
                row.update(args)
                row.update(tau)
                res.append(row)

            # Matrix Factorization
            t0 = time.time()
            tau, r = exp_mf(X_miss, w, y, args['regularize'])
            args['time'] = int(time.time() - t0)
            row = {'Method': 'MF', 'r': r}
            row.update(args)
            row.update(tau)
            res.append(row)

            # MissDeepCausal
            mdc_parameter_grid['d_miwae'] = [
                args['d'] + x for x in range_d_offset
            ]

            mdc_arguments = [
                dict(zip(mdc_parameter_grid.keys(), vals))
                for vals in itertools.product(*mdc_parameter_grid.values())
            ]

            for mdc_arg in mdc_arguments:
                t0 = time.time()
                tau, elbo = exp_mdc(
                    X_miss,
                    w,
                    y,
                    d_miwae=mdc_arg['d_miwae'],
                    sig_prior=mdc_arg['sig_prior'],
                    num_samples_zmul=mdc_arg['num_samples_zmul'],
                    learning_rate=mdc_arg['learning_rate'],
                    n_epochs=mdc_arg['n_epochs'],
                    regularize=args['regularize'])
                args['time'] = int(time.time() - t0)
                for mdc_method in ('MDC.process', 'MDC.mi'):
                    row = {'Method': mdc_method, 'elbo': elbo}
                    row.update(args)
                    row.update(mdc_arg)
                    row.update(tau[mdc_method])
                    res.append(row)

            log_res(output, res,
                    ['Method'] + list(args.keys()) + l_method_params + l_tau)
            logging.info('........... DONE')
            logging.info(f'in {time.time() - exp_time} s \n\n')

    logging.info('*' * 20)
    logging.info(f'Exp: {FLAGS.exp_name} succesfully ended.')
    logging.info('*' * 20)
示例#7
0
def exp_miwae(model="dlvm",
              n=1000,
              d=3,
              p=100,
              prop_miss=0.1,
              citcio=False,
              seed=0,
              d_miwae=3,
              n_epochs=602,
              sig_prior=1,
              add_wy=False,
              num_samples_zmul=200,
              method="glm",
              **kwargs):
    """Estimate treatment effects from MIWAE latent codes of simulated data.

    Data are drawn from ``gen_lrmf`` or ``gen_dlvm``, amputed with ``ampute``
    and encoded with MIWAE.  Four estimators (doubly robust, OLS, OLS with
    propensity score, residuals) are computed on Zhat = E[Z|X] and also
    averaged over the posterior draws Zhat^(b).

    Args:
        model: "lrmf" or "dlvm", the data generating model.
        n, d, p: number of rows, latent dimension, ambient dimension.
        prop_miss: proportion of missing values, passed to the generator and
            to ``ampute``.
        citcio: forwarded to the data generator.
        seed: seed forwarded to the data generator and to ``ampute``.
        d_miwae, n_epochs, sig_prior, num_samples_zmul: MIWAE settings.
        add_wy: if True, ``w`` and ``y`` are also passed to MIWAE.
        method: forwarded to ``tau_dr`` and ``tau_residuals``.
        **kwargs: ignored.

    Returns:
        Tuple ``(res_tau_dr, res_tau_ols, res_tau_ols_ps, res_tau_resid,
        res_mul_tau_dr, res_mul_tau_ols, res_mul_tau_ols_ps,
        res_mul_tau_resid, dcor_zhat, dcor_zhat_mul)``.  ``dcor_zhat`` is NaN
        when ``zhat`` and ``Z`` do not have the same number of columns.

    Raises:
        NotImplementedError: if ``model`` is not "lrmf" or "dlvm".
    """
    from miwae import miwae

    if model == "lrmf":
        generate = gen_lrmf
    elif model == "dlvm":
        generate = gen_dlvm
    else:
        raise NotImplementedError(
            "Other data generating models not implemented here yet.")
    Z, X, w, y, ps = generate(n=n,
                              d=d,
                              p=p,
                              citcio=citcio,
                              prop_miss=prop_miss,
                              seed=seed)

    X_miss = ampute(X, prop_miss=prop_miss, seed=seed)

    # w and y are only handed to MIWAE when it is asked to use them.
    wy_kwargs = {'w': w, 'y': y} if add_wy else {}
    xhat, zhat, zhat_mul = miwae(X_miss,
                                 d=d_miwae,
                                 sig_prior=sig_prior,
                                 num_samples_zmul=num_samples_zmul,
                                 n_epochs=n_epochs,
                                 add_wy=add_wy,
                                 **wy_kwargs)

    def estimate(z):
        """Return (dr, ols, ols_ps, resid) effect estimates using covariates z."""
        ps_hat, y0_hat, y1_hat = get_ps_y01_hat(z, w, y)
        res_ols = tau_ols(z, w, y)
        res_ols_ps = tau_ols_ps(z, w, y)
        res_dr = tau_dr(y, w, y0_hat, y1_hat, ps_hat, method)
        y_hat = LinearRegression().fit(z, y).predict(z)
        res_resid = tau_residuals(y, w, y_hat, ps_hat, method)
        return res_dr, res_ols, res_ols_ps, res_resid

    # Tau estimated on Zhat = E[Z|X]
    res_tau_dr, res_tau_ols, res_tau_ols_ps, res_tau_resid = estimate(zhat)

    # Tau estimated on each Zhat^(b), b = 1, ..., B sampled from the
    # posterior, then averaged over b.
    mul = [estimate(zhat_b) for zhat_b in zhat_mul]
    res_mul_tau_dr = np.mean([r[0] for r in mul])
    res_mul_tau_ols = np.mean([r[1] for r in mul])
    res_mul_tau_ols_ps = np.mean([r[2] for r in mul])
    res_mul_tau_resid = np.mean([r[3] for r in mul])

    # Only compare zhat with the true Z when the dimensions agree; otherwise
    # report NaN instead of leaving dcor_zhat unbound.
    dcor_zhat = dcor(Z, zhat) if Z.shape[1] == zhat.shape[1] else np.nan
    dcor_zhat_mul = np.mean([dcor(Z, zhat_b) for zhat_b in zhat_mul])

    return (res_tau_dr, res_tau_ols, res_tau_ols_ps, res_tau_resid,
            res_mul_tau_dr, res_mul_tau_ols, res_mul_tau_ols_ps,
            res_mul_tau_resid, dcor_zhat, dcor_zhat_mul)
示例#8
0
def exp_baseline(model="dlvm",
                 n=1000,
                 d=3,
                 p=100,
                 prop_miss=0.1,
                 citcio=False,
                 seed=0,
                 full_baseline=False,
                 method="glm",
                 **kwargs):
    """Estimate treatment effects with baseline covariates on simulated data.

    The baselines are the true latent ``Z`` and the covariates ``X`` before
    amputation.  With ``full_baseline=True``, ``get_U_softimpute`` applied to
    the amputed ``X`` is added as ``'Z_mf'``.

    Args:
        model: "lrmf" or "dlvm", the data generating model.
        n, d, p: number of rows, latent dimension, ambient dimension.
        prop_miss: proportion of missing values, passed to the generator and
            to ``ampute``.
        citcio: forwarded to the data generator.
        seed: seed forwarded to the data generator and to ``ampute``.
        full_baseline: also evaluate the matrix-factorization baseline.
        method: forwarded to ``tau_dr`` and ``tau_residuals``.
        **kwargs: ignored.

    Returns:
        dict mapping baseline name to the tuple
        ``(res_tau_dr, res_tau_ols, res_tau_ols_ps, res_tau_resid)``.

    Raises:
        NotImplementedError: if ``model`` is not "lrmf" or "dlvm".
    """
    if model == "lrmf":
        Z, X, w, y, ps = gen_lrmf(n=n,
                                  d=d,
                                  p=p,
                                  citcio=citcio,
                                  prop_miss=prop_miss,
                                  seed=seed)
    elif model == "dlvm":
        Z, X, w, y, ps = gen_dlvm(n=n,
                                  d=d,
                                  p=p,
                                  citcio=citcio,
                                  prop_miss=prop_miss,
                                  seed=seed)
    else:
        raise NotImplementedError(
            "Other data generating models not implemented here yet.")

    X_miss = ampute(X, prop_miss=prop_miss, seed=seed)

    algo_name = ['Z', 'X']
    algo_ = [Z, X]

    if full_baseline:
        # Low-rank factor estimated from the incomplete covariates.
        algo_name.append('Z_mf')
        algo_.append(get_U_softimpute(X_miss))

    tau = dict()
    for name, zhat in zip(algo_name, algo_):
        ps_hat, y0_hat, y1_hat = get_ps_y01_hat(zhat, w, y)
        res_tau_ols = tau_ols(zhat, w, y)
        res_tau_ols_ps = tau_ols_ps(zhat, w, y)
        res_tau_dr = tau_dr(y, w, y0_hat, y1_hat, ps_hat, method)
        lr = LinearRegression()
        lr.fit(zhat, y)
        y_hat = lr.predict(zhat)
        res_tau_resid = tau_residuals(y, w, y_hat, ps_hat, method)

        tau[name] = res_tau_dr, res_tau_ols, res_tau_ols_ps, res_tau_resid

    return tau
示例#9
0
def _previous_run_key(record):
    """Return a hashable, type-normalized identity of one experiment run.

    ``record`` is either the ``args`` dict of the current run or a row read
    back from the results CSV by ``csv.DictReader`` (all values ``str``).
    A fixed field order and explicit type conversions make the two
    comparable; comparing ``str(dict)`` is not reliable because it depends on
    key insertion order, on extra keys such as ``training_time`` and on the
    repr of numpy scalars.
    """

    def as_bool(value):
        # The CSV stores booleans as the strings 'True' / 'False'.
        return value == 'True' if isinstance(value, str) else bool(value)

    return (
        str(record['model']),
        as_bool(record['citcio']),
        int(record['n']),
        int(record['p']),
        float(record['y_snr']),
        float(record['x_snr']),
        float(record['mu_z']),
        float(record['sig_z']),
        float(record['sig_xgivenz']),
        float(record['prop_miss']),
        as_bool(record['regularize']),
        int(record['seed']),
        int(record['d']),
    )


def _generate_data(args, seed):
    """Draw (Z, X, w, y, ps, mu0, mu1) from the generator named by args['model'].

    The training draw and every test draw use this same keyword set, so both
    come from the same data-generating process.

    Raises:
        NotImplementedError: if ``args['model']`` is not "lrmf" or "dlvm".
    """
    if args['model'] == "lrmf":
        return gen_lrmf(n=args['n'],
                        d=args['d'],
                        p=args['p'],
                        y_snr=args['y_snr'],
                        x_snr=args['x_snr'],
                        citcio=args['citcio'],
                        prop_miss=args['prop_miss'],
                        seed=seed)
    if args['model'] == "dlvm":
        return gen_dlvm(
            n=args['n'],
            d=args['d'],
            p=args['p'],
            y_snr=args['y_snr'],
            citcio=args['citcio'],
            prop_miss=args['prop_miss'],  # this argument is only used if citcio=True
            seed=seed,
            mu_z=args['mu_z'],
            sig_z=args['sig_z'],
            x_snr=args['x_snr'],
            sig_xgivenz=args['sig_xgivenz'])
    raise NotImplementedError(
        "Other data generating models not implemented here yet.")


def main(unused_argv):
    """Train MIWAE on simulated incomplete data and score it on new test sets.

    For each combination of data-generating parameters (from FLAGS or the
    default grids below) and each latent dimension ``d``, a training set is
    generated.  For each MDC hyper-parameter combination, a MIWAE model is
    trained with ``miwae_es``, or restored if a checkpoint matching
    ``session_file_complete`` already exists.  The model is then evaluated on
    ``FLAGS.n_test_seeds`` freshly generated test sets.  One row per (run,
    MDC setting, test seed) is passed to ``log_res`` for ``output``.  Runs
    whose parameters are already present in ``output`` are skipped.
    """
    # Data generating process parameters
    exp_parameter_grid = {
        'model': ["dlvm", "lrmf"] if FLAGS.model is None else [FLAGS.model],
        'citcio': [
            False,
        ],
        'n': [1000, 5000, 10000]
        if FLAGS.n_observations is None else [FLAGS.n_observations],
        'p': [10, 50] if FLAGS.p_ambient is None else [FLAGS.p_ambient],
        'y_snr': [5.] if FLAGS.y_snr is None else [FLAGS.y_snr],
        'x_snr': [2.] if FLAGS.x_snr is None else [FLAGS.x_snr],
        'mu_z': [0.] if FLAGS.mu_z is None else [FLAGS.mu_z],
        'sig_z': [1.] if FLAGS.sig_z is None else [FLAGS.sig_z],
        'sig_xgivenz':
        [0.001] if FLAGS.sig_xgivenz is None else [FLAGS.sig_xgivenz],
        'prop_miss': [
            0.0,
        ] if FLAGS.prop_miss is None else [FLAGS.prop_miss],
        'regularize':
        [False] if FLAGS.regularize is None else [FLAGS.regularize],
        'seed':
        np.arange(FLAGS.n_seeds),
    }
    range_d_over_p = [
        0.002, 0.01, 0.1
    ] if FLAGS.d_over_p is None and FLAGS.d_latent is None else [
        FLAGS.d_over_p
    ]
    # An explicit FLAGS.d_latent (scalar or list) fixes the latent dimensions
    # for every p; otherwise they are derived from range_d_over_p per p below.
    fixed_range_d = (None if FLAGS.d_latent is None else
                     [int(d) for d in np.atleast_1d(FLAGS.d_latent)])

    # MDC parameters
    range_d_offset = [0, 5, 10] if FLAGS.miwae_d_offset is None else [
        FLAGS.miwae_d_offset
    ]

    mdc_parameter_grid = {
        'mu_prior':
        [0.] if FLAGS.miwae_mu_prior is None else [FLAGS.miwae_mu_prior],
        'sig_prior':
        [1.] if FLAGS.miwae_sig_prior is None else [FLAGS.miwae_sig_prior],
        'num_samples_zmul': [500] if FLAGS.miwae_n_samples_zmul is None else
        [FLAGS.miwae_n_samples_zmul],
        'learning_rate': [
            0.0001,
        ]
        if FLAGS.miwae_learning_rate is None else [FLAGS.miwae_learning_rate],
        'n_epochs': [
            5000,
        ] if FLAGS.miwae_n_epochs is None else [FLAGS.miwae_n_epochs],
    }

    test_seeds = np.arange(FLAGS.n_test_seeds) + 1000
    save_test_data = True if FLAGS.save_test_data is None else FLAGS.save_test_data

    # Experiment and output file name
    output = f'results/{FLAGS.exp_name}.csv' if FLAGS.output is None else FLAGS.output

    FLAGS.log_dir = './sessions/logging/' if FLAGS.log_path is None else FLAGS.log_path
    logging.get_absl_handler().use_absl_log_file()
    logging.info('*' * 20)
    logging.info(f'Starting exp: {FLAGS.exp_name}')
    logging.info('*' * 20)

    exp_arguments = [
        dict(zip(exp_parameter_grid.keys(), vals))
        for vals in itertools.product(*exp_parameter_grid.values())
    ]

    previous_runs = set()
    if tf.io.gfile.exists(output):
        with tf.io.gfile.GFile(output, mode='r') as f:
            # DictReader reads every value as str; _previous_run_key converts
            # them back to the types stored in ``args``.
            previous_runs = {
                _previous_run_key(row) for row in csv.DictReader(f)
            }
    logging.info('Previous runs')
    logging.info(previous_runs)

    for args in exp_arguments:
        # Unless d was fixed explicitly, build the latent dimensions for THIS
        # p from the d/p ratios, keeping 2 <= d < p.
        if fixed_range_d is None:
            candidates = [
                max(2, int(np.floor(args['p'] * x))) for x in range_d_over_p
            ]
            range_d = np.unique([c for c in candidates
                                 if c < args['p']]).tolist()
        else:
            range_d = fixed_range_d

        for d in range_d:
            args['d'] = d
            # We only consider cases where latent dimension <= ambient dimension
            if args['d'] > args['p']:
                continue
            if _previous_run_key(args) in previous_runs:
                logging.info(f'Skipped {args}')
                continue
            logging.info(f'running exp with {args}')
            exp_time = time.time()
            res = []

            Z, X, w, y, ps, mu0, mu1 = _generate_data(args, args['seed'])
            X_miss = ampute(X, prop_miss=args['prop_miss'], seed=args['seed'])

            # MIWAE
            mdc_parameter_grid['d_miwae'] = [
                args['d'] + x for x in range_d_offset
            ]

            mdc_arguments = [
                dict(zip(mdc_parameter_grid.keys(), vals))
                for vals in itertools.product(*mdc_parameter_grid.values())
            ]

            for mdc_arg in mdc_arguments:
                t0 = time.time()
                # The prior mean always follows the data-generating mu_z.
                mdc_arg['mu_prior'] = args['mu_z']
                session_file = ('./sessions/' + args['model'] + '_' +
                                '_sigXgivenZ' + str(args['sig_xgivenz']) +
                                '_n' + str(args['n']) +
                                '_p' + str(args['p']) +
                                '_d' + str(args['d']) +
                                '_ysnr' + str(args['y_snr']) +
                                '_xsnr' + str(args['x_snr']) +
                                '_propNA' + str(args['prop_miss']) +
                                '_seed' + str(args['seed']))
                session_file_complete = (session_file +
                                         '_dmiwae' + str(mdc_arg['d_miwae']) +
                                         '_sigprior' + str(mdc_arg['sig_prior']))
                epochs = -1
                # Fresh default graph for this MDC setting; the session is
                # bound to it and closed once evaluation is done.
                tf.reset_default_graph()
                sess = tf.Session()
                if glob.glob(session_file_complete + '.*'):
                    # Restore the checkpoint saved by an earlier run.
                    new_saver = tf.train.import_meta_graph(
                        session_file_complete + '.meta')
                    new_saver.restore(sess, session_file_complete)
                else:
                    xhat, zhat, zhat_mul, elbo, epochs = miwae_es(
                        X_miss,
                        d_miwae=mdc_arg['d_miwae'],
                        mu_prior=mdc_arg['mu_prior'],
                        sig_prior=mdc_arg['sig_prior'],
                        num_samples_zmul=mdc_arg['num_samples_zmul'],
                        l_rate=mdc_arg['learning_rate'],
                        n_epochs=mdc_arg['n_epochs'],
                        save_session=True,
                        session_file=session_file)
                    new_saver = tf.train.import_meta_graph(
                        session_file_complete + '.meta')
                    new_saver.restore(sess, session_file_complete)
                    with open(session_file_complete + '.pkl',
                              'wb') as file_data:
                        pickle.dump([xhat, zhat, zhat_mul, elbo, epochs],
                                    file_data)

                args['training_time'] = int(time.time() - t0)

                # Evaluate performance of trained model on new test sets
                graph = tf.get_default_graph()

                K = graph.get_tensor_by_name('K:0')
                x = graph.get_tensor_by_name('x:0')
                batch_size = tf.shape(x)[0]
                xms = graph.get_tensor_by_name('xms:0')
                imp_weights = graph.get_tensor_by_name('imp_weights:0')
                xm = tf.einsum('ki,kij->ij', imp_weights, xms, name='xm')

                zgivenx_flat = graph.get_tensor_by_name('zgivenx_flat:0')
                zgivenx = tf.reshape(zgivenx_flat,
                                     [K, batch_size, zgivenx_flat.shape[1]])
                z_hat = tf.einsum('ki,kij->ij',
                                  imp_weights,
                                  zgivenx,
                                  name='z_hat')

                sir_logits = graph.get_tensor_by_name('sir_logits:0')
                sirz = tfd.Categorical(logits=sir_logits).sample(
                    mdc_arg['num_samples_zmul'])
                zmul = graph.get_tensor_by_name('zmul:0')

                for test_seed in test_seeds:
                    (Z_test, X_test, w_test, y_test, ps_test, mu0_test,
                     mu1_test) = _generate_data(args, test_seed)

                    # NOTE(review): the test mask uses the training seed, not
                    # test_seed; confirm this is intended.
                    X_miss_test = ampute(X_test,
                                         prop_miss=args['prop_miss'],
                                         seed=args['seed'])
                    # binary mask that indicates which values are observed
                    mask_test = np.isfinite(X_miss_test)

                    t0 = time.time()
                    testset_pkl = (session_file_complete + '_testset_eval' +
                                   str(test_seed) + '.pkl')
                    if glob.glob(testset_pkl):
                        with open(testset_pkl, 'rb') as f:
                            (xhat_test, zhat_test, zgivenx_test,
                             zhat_mul_test) = pickle.load(f)
                    else:
                        x_test_imp0 = np.copy(X_miss_test)
                        x_test_imp0[np.isnan(X_miss_test)] = 0

                        n_test = X_test.shape[0]
                        xhat_test = np.copy(x_test_imp0)
                        zhat_test = np.zeros([n_test, mdc_arg['d_miwae']])
                        zgivenx_test = np.tile(
                            zhat_test, [mdc_arg['num_samples_zmul'], 1, 1])
                        zhat_mul_test = np.tile(
                            zhat_test, [mdc_arg['num_samples_zmul'], 1, 1])

                        for i in range(n_test):
                            x_i = x_test_imp0[i, :].reshape([1, args['p']])
                            mask_i = mask_test[i, :].reshape([1, args['p']])
                            zgivenx_test[:, i, :] = np.squeeze(
                                zgivenx.eval(session=sess,
                                             feed_dict={
                                                 'x:0': x_i,
                                                 'K:0':
                                                 mdc_arg['num_samples_zmul'],
                                                 'xmask:0': mask_i
                                             })).reshape([
                                                 mdc_arg['num_samples_zmul'],
                                                 mdc_arg['d_miwae']
                                             ])
                            # Point estimates and SIR draws use K=10000
                            # importance samples.
                            feed = {'x:0': x_i, 'K:0': 10000, 'xmask:0': mask_i}
                            xhat_test[i, :] = xm.eval(session=sess,
                                                      feed_dict=feed)
                            zhat_test[i, :] = z_hat.eval(session=sess,
                                                         feed_dict=feed)
                            si, zmu = sess.run([sirz, zmul], feed_dict=feed)
                            zhat_mul_test[:, i, :] = np.squeeze(
                                zmu[si, :, :]).reshape(
                                    (mdc_arg['num_samples_zmul'],
                                     mdc_arg['d_miwae']))

                        if save_test_data:
                            with open(testset_pkl, 'wb') as file_data:
                                pickle.dump([
                                    xhat_test, zhat_test, zgivenx_test,
                                    zhat_mul_test
                                ], file_data)

                    evaluation_time = int(time.time() - t0)

                    if args['d'] == 1 and mdc_arg['d_miwae'] == 1:
                        row = {
                            'Z_cor':
                            pearsonr(Z_test.reshape([args['n']]),
                                     zhat_test.reshape([args['n']]))[0]
                        }
                    else:
                        row = {'Z_cor': np.nan}
                    if args['d'] == mdc_arg['d_miwae']:
                        row.update({'Z_mmd': mmd(Z_test, zhat_test, beta=1.)})
                        row.update({'Z_rvcoef': compute_rv(Z_test, zhat_test)})
                    else:
                        row.update({'Z_mmd': np.nan})
                        row.update({'Z_rvcoef': np.nan})
                    row.update(
                        {'X_mse': mean_squared_error(X_test, xhat_test)})
                    row.update({'X_mmd': mmd(X_test, xhat_test, beta=1.)})
                    row.update({'X_rvcoef': compute_rv(X_test, xhat_test)})
                    row.update(args)
                    row.update(mdc_arg)
                    row.update({'epochs': epochs})
                    row.update({'test_seed': test_seed})
                    row.update({'evaluation_time': evaluation_time})
                    res.append(row)

                # Release the session's resources before the next MDC setting.
                sess.close()

            log_res(
                output, res,
                l_metrics + list(args.keys()) + list(mdc_arg.keys()) +
                ['epochs', 'test_seed', 'evaluation_time'])
            logging.info('........... DONE')
            logging.info(f'in {time.time() - exp_time} s \n\n')

    logging.info('*' * 20)
    logging.info(f'Exp: {FLAGS.exp_name} succesfully ended.')
    logging.info('*' * 20)
示例#10
0
                X_new=X_reconstruction,
                missing_mask=missing_mask)
            X_filled[missing_mask] = X_reconstruction[missing_mask]
            if converged:
                break
        if self.verbose:
            print("[SoftImpute] Stopped after iteration %d for lambda=%f" % (
                i + 1,
                shrinkage_value))

            
        self.mae_obs = masked_mae(
                    X_true=X_init,
                    X_pred=X_reconstruction,
                    mask=observed_mask)

        return X_filled


if __name__=='__main__':

    from generate_data import gen_lrmf, gen_dlvm
    from generate_data import ampute
    import matplotlib.pyplot as plt
    import seaborn as sns

    Z, X, w, y, ps = gen_lrmf(d=3)
    X_obs = ampute(X)

    print('boxplot of get_U_softimpute with gen_lrmf(d=3)')
    U = get_U_softimpute(X_obs, boxplot=True)