Beispiel #1
0
    def receiveResults(self):
        """Receive one serialized ``Results`` message and enqueue it.

        Reads a single message from ``self.receiver`` via ``recv()``, parses
        it into a ``results_pb2.Results`` instance and puts that instance on
        ``self.resultQueue``. Progress is reported on stdout. Returns None.
        """
        parsed = results_pb2.Results()

        print('Waiting results...')
        print(self.receiver)
        payload = self.receiver.recv()
        print('Got results')

        # Decode the raw bytes, then hand the message over to the consumer.
        parsed.ParseFromString(payload)
        self.resultQueue.put(parsed)
def load_results(result_directory):  # noqa
    """Collect per-method MSE curves from every results file in a directory.

    Every entry of ``result_directory`` whose name does not end in ``.err``
    is read as a serialized ``results_pb2.Results`` message. For each method
    in the message, its ``mse`` values are appended (as a numpy array) to the
    list stored under the key ``'<method_name>_mse'``.

    Args:
        result_directory: Path of the directory containing the result files.

    Returns:
        dict mapping ``'<method_name>_mse'`` to a list of numpy arrays, one
        array per file in which that method appears. Empty if the directory
        holds no result files.

    Raises:
        Whatever ``ParseFromString`` raises when a non-``.err`` file is not a
        valid serialized ``Results`` message; the error propagates unchanged.
    """
    data = {}
    for filename in os.listdir(result_directory):
        # Cluster stderr files sit next to the results. Skip them by name
        # instead of attempting to parse them, so a log that happens to
        # decode as a message can never leak into the data.
        if filename.endswith('.err'):
            continue

        with open(os.path.join(result_directory, filename), 'rb') as f:
            payload = f.read()

        # A fresh message per file keeps each file's contents independent.
        results = results_pb2.Results()
        results.ParseFromString(payload)

        for method in results.methods:
            mse_key = '%s_mse' % method.method_name
            data.setdefault(mse_key, []).append(np.array(method.mse))

    return data
Beispiel #3
0
# Script fragment: reads a list of result files from stdin and starts writing
# a heat-map table to '<sys.argv[1]>/aco-heatmap'.
# NOTE(review): nAnts, rhos and results are used below but not defined in this
# fragment — presumably two sets and a dict created earlier; confirm.
instances = set()
lines = sys.stdin.readlines()

idx = 0
# Each stdin line has the form: <n> <rho> <instance> <result file> ...
# The first two fields are parsed as floats; every remaining field is the
# path of a serialized results_pb2.Results message.
# NOTE(review): a blank line yields an empty `words` list and words[0] raises
# IndexError.
while idx < len(lines):
    words = lines[idx].strip().split()
    n = float(words[0])
    rho = float(words[1])
    inst = words[2]
    nAnts.add(n)
    rhos.add(rho)
    instances.add(inst)
    # Parse every result file listed for this (instance, n, rho) combination.
    rs = []
    for rf in words[3:]:
        rb = results_pb2.Results()
        with open(rf, "rb") as f:
            rb.ParseFromString(f.read())
            rs.append(rb)
    results[(inst, n, rho)] = rs
    idx = idx + 1

# Header row: 'nAnts rho' followed by one 'inst<k>' column per distinct
# instance (columns are numbered, not named after the instance).
# NOTE(review): no close() for outfile is visible in this fragment.
outfile = open(sys.argv[1] + '/aco-heatmap', 'w')
outfile.write('nAnts rho')
for inst in range(len(instances)):
    outfile.write(' ' + 'inst' + str(inst))
outfile.write('\n')

# One row per (n, rho) pair; each row starts with the two parameter values.
# The rest of the row is written by code beyond the end of this fragment.
for n in nAnts:
    for rho in rhos:
        outfile.write(str(n) + ' ' + str(rho))
Beispiel #4
0
def timeToBaseQ(s):
  """Return the mean time-to-baseline-quality ratio of solver ``s``.

  For every ``(baseline, solver_runs)`` value stored in the module-level
  ``results`` mapping, the last point of the baseline is taken and
  ``timeTo(last.qual, solver_runs[s])`` is divided by ``last.time``. The
  function returns the arithmetic mean of those ratios.

  Raises ZeroDivisionError if ``results`` is empty or a baseline's last point
  has ``time == 0``.
  """
  fracs = []
  # Only the values are needed; .values() also works on Python 3, where
  # dict.iteritems() no longer exists.
  for (baseline, solverRuns) in results.values():
    baseR = baseline.point[-1]
    fracs.append(float(timeTo(baseR.qual, solverRuns[s])) / float(baseR.time))
  return sum(fracs) / float(len(fracs))

# Script fragment (Python 2 syntax): walks the `lines` list as a sequence of
# records. Each record is an instance name line, a baseline results file line,
# then (solver name, solver results file) line pairs up to a blank line.
# NOTE(review): lines, results and solvers are not defined in this fragment —
# presumably created earlier in the script; confirm.
idx=0
while idx < len(lines):
  # read an instance: its name, then the path of its baseline results file
  instanceName = lines[idx].strip()
  print "reading " + instanceName
  idx=idx+1
  instanceBaseLineFile = lines[idx].strip()
  idx=idx+1
  baseLineR = results_pb2.Results()
  with open(instanceBaseLineFile, "rb") as f:
    baseLineR.ParseFromString(f.read())

  # Only instances whose baseline has at least one point are registered; the
  # empty dict is the second element of the stored pair.
  if len(baseLineR.point) > 0:
    results[instanceName] = (baseLineR,{})

  # read all alternative methods results (pairs of lines until a blank line
  # or the end of input)
  while idx < len(lines) and lines[idx].strip() != '':
    solverName = lines[idx].strip()
    solvers.add(solverName)
    idx=idx+1
    # TODO: Take an average if needed
    solverResultsFile = lines[idx].strip()
    idx=idx+1
    # The remainder of this loop body lies beyond the end of this fragment.
    solverResults = results_pb2.Results()
def _add_method_result(results, name, estimates, mses, variances=None):
    """Append one method entry (name plus its curves) to ``results.methods``.

    ``variances`` is optional because not every method records them.
    """
    method = results.methods.add()
    method.method_name = name
    method.estimates.extend(estimates)
    method.mse.extend(mses)
    if variances is not None:
        method.variances.extend(variances)


def main():  # noqa
    """Run the 'RIS_experiment' and optionally save and plot its results.

    Builds a ``SinglePathMDP`` (2 actions, horizon 5, stochastic) and takes
    policies 0 and 1 from it (``pie`` and ``pib``). The value of ``pie``
    returned by ``mdps.evaluate`` is printed as the true value. Paths are then
    sampled with ``pib`` for ``FLAGS.num_iters`` iterations. Every
    ``FLAGS.eval_freq`` iterations, and once more after the loop, all
    estimators (IS, weighted IS, REG and the RIS variants for n = 1..horizon)
    are evaluated on the paths collected so far through ``eval_estimators``.

    The collected curves are stored in a ``results_pb2.Results`` message,
    which is written to ``FLAGS.result_file`` when that flag is set. When
    ``FLAGS.plot`` is set, the MSE curves are plotted with matplotlib.
    """
    if FLAGS.seed is not None:
        np.random.seed(FLAGS.seed)

    n_actions = 2
    horizon = 5
    # NOTE(review): only 2 * horizon RIS methods are created below, so the
    # last two rows of the per-method arrays stay zero. The size is kept
    # because eval_estimators receives these arrays — confirm before shrinking.
    n_methods = 2 * (horizon + 1)
    mdp = mdps.SinglePathMDP(n_actions, horizon, stochastic=True)
    pie = mdp.get_policy(0)
    pib = mdp.get_policy(1)
    print(pie)
    print(pib)
    paths = []

    results = results_pb2.Results()
    results.experiment_name = 'RIS_experiment'

    t_val = mdps.evaluate(mdp, pie)
    print('True value: %f' % t_val)

    IS = estimators.ISEstimate(pie, pib)
    WIS = estimators.ISEstimate(pie, pib, weighted=True)
    REG = estimators.REGEstimate(pie, mdp)

    # Two RIS estimators (plain and weighted) per value of n in 1..horizon;
    # labels[i] names methods[i].
    methods = []
    labels = []
    for n in range(1, horizon + 1):
        methods.append(estimators.RISEstimate(pie, n_actions, horizon, n=n))
        methods.append(
            estimators.RISEstimate(pie, n_actions, horizon, n=n,
                                   weighted=True))
        labels.append('RIS(%d)' % n)
        labels.append('Weighted RIS(%d)' % n)

    # One slot per in-loop checkpoint (itr = eval_freq, 2 * eval_freq, ...
    # below num_iters) plus one for the final evaluation after the loop.
    # int(num_iters / eval_freq) is one slot short whenever num_iters is not
    # a multiple of eval_freq.
    # NOTE(review): assumes eval_estimators writes slot `idx` of each array.
    n_evals = max(1, int((FLAGS.num_iters - 1) // FLAGS.eval_freq) + 1)

    is_mses = np.zeros(n_evals)
    is_estimates = np.zeros(n_evals)
    is_variances = np.zeros(n_evals)
    wis_mses = np.zeros(n_evals)
    wis_estimates = np.zeros(n_evals)
    reg_mses = np.zeros(n_evals)
    reg_estimates = np.zeros(n_evals)
    mses = np.zeros((n_methods, n_evals))
    variances = np.zeros((n_methods, n_evals))
    estimates = np.zeros((n_methods, n_evals))

    def evaluate_at(slot):
        # Single place for the long eval_estimators call, used for both the
        # periodic and the final evaluation.
        eval_estimators(paths, slot, t_val, IS, is_estimates, is_mses,
                        is_variances, WIS, wis_estimates, wis_mses, REG,
                        reg_estimates, reg_mses, methods, estimates, mses,
                        variances, labels)

    idx = 0
    for itr in range(FLAGS.num_iters):
        path, _ = mdps.sample(mdp, pib)
        paths.append(path)
        if itr % FLAGS.eval_freq == 0 and itr > 0:
            evaluate_at(idx)
            idx += 1

    # Final evaluation on every collected path.
    evaluate_at(idx)

    # Normal IS methods
    _add_method_result(results, 'IS', is_estimates, is_mses, is_variances)
    _add_method_result(results, 'WIS', wis_estimates, wis_mses)
    _add_method_result(results, 'REG', reg_estimates, reg_mses)

    # Add RIS methods
    for i, label in enumerate(labels):
        _add_method_result(results, label, estimates[i], mses[i],
                           variances[i])

    if FLAGS.result_file:
        with open(FLAGS.result_file, 'wb') as w:
            w.write(results.SerializeToString())

    # Plotting code
    if FLAGS.plot:
        lens = np.arange(n_evals) * FLAGS.eval_freq
        plt.plot(lens, is_mses, label='IS')
        plt.plot(lens, wis_mses, label='WIS')
        plt.plot(lens, reg_mses, label='REG')
        for i, label in enumerate(labels):
            plt.plot(lens, mses[i], label=label)
        plt.legend()
        plt.show()