Example 1
def gen_testloss(args):
    # load data and model
    params = utils.load_params(args.model_dir)
    ckpt_dir = os.path.join(args.model_dir, 'checkpoints')
    ckpt_paths = [int(e[11:-3]) for e in os.listdir(ckpt_dir) if e[-3:] == ".pt"]
    ckpt_paths = np.sort(ckpt_paths)
    
    # csv
    headers = ["epoch", "step", "loss", "discrimn_loss_e", "compress_loss_e", 
        "discrimn_loss_t",  "compress_loss_t"]
    csv_path = utils.create_csv(args.model_dir, 'losses_test.csv', headers)
    print('writing to:', csv_path)

    # load data
    test_transforms = tf.load_transforms('test')
    testset = tf.load_trainset(params['data'], test_transforms, train=False)
    testloader = DataLoader(testset, batch_size=params['bs'], shuffle=False, num_workers=4)
    
    # save loss
    criterion = MaximalCodingRateReduction(gam1=params['gam1'], gam2=params['gam2'], eps=params['eps'])
    for epoch, ckpt_path in enumerate(ckpt_paths):
        net, epoch = tf.load_checkpoint(args.model_dir, epoch=epoch, eval_=True)
        for step, (batch_imgs, batch_lbls) in enumerate(testloader):
            with torch.no_grad():  # no gradients are needed when evaluating the test loss
                features = net(batch_imgs.cuda())
                loss, loss_empi, loss_theo = criterion(features, batch_lbls,
                                                       num_classes=testset.num_classes)  # pass the class count directly; len() of an int would fail
            utils.save_state(args.model_dir, epoch, step, loss.item(),
                *loss_empi, *loss_theo, filename='losses_test.csv')
    print("Finished generating test loss.")
Example 2
def gen_training_accuracy(args):
    # load data and model
    params = utils.load_params(args.model_dir)
    ckpt_dir = os.path.join(args.model_dir, 'checkpoints')
    ckpt_paths = [int(e[11:-3]) for e in os.listdir(ckpt_dir) if e[-3:] == ".pt"]
    ckpt_paths = np.sort(ckpt_paths)
    
    # csv
    headers = ["epoch", "acc_train", "acc_test"]
    csv_path = utils.create_csv(args.model_dir, 'accuracy.csv', headers)

    for epoch, ckpt_path in enumerate(ckpt_paths):
        if epoch % 5 != 0:
            continue
        net, epoch = tf.load_checkpoint(args.model_dir, epoch=epoch, eval_=True)
        # load data
        train_transforms = tf.load_transforms('test')
        trainset = tf.load_trainset(params['data'], train_transforms, train=True)
        trainloader = DataLoader(trainset, batch_size=500, num_workers=4)
        train_features, train_labels = tf.get_features(net, trainloader, verbose=False)

        test_transforms = tf.load_transforms('test')
        testset = tf.load_trainset(params['data'], test_transforms, train=False)
        testloader = DataLoader(testset, batch_size=500, num_workers=4)
        test_features, test_labels = tf.get_features(net, testloader, verbose=False)

        acc_train, acc_test = svm(args, train_features, train_labels, test_features, test_labels)
        utils.save_state(args.model_dir, epoch, acc_train, acc_test, filename='accuracy.csv')
    print("Finished generating accuracy.")
Example 3
def quit(path, out):
    '''Terminate the script without writing to the output file and
    save a checkpoint so the work can be resumed later.'''
    out.close()
    cv2.destroyAllWindows()

    # save current work
    utils.save_state(path)
def update_raw_string_bucket(update_clicks, modal_validity, bucket_id,
                             bucket_name, search_results, result_clicks):
    modal_validity = load_state(modal_validity)
    validity = modal_validity.get('validity')

    data = load_state(search_results).get('data')
    button_clicked = load_state(result_clicks).get("button_clicked")

    if validity is None:
        return save_state({'updated': False})

    if validity == INVALID:
        return save_state({'updated': False})

    bucket_id = bucket_id if bucket_id is not None else ""
    bucket_name = bucket_name if bucket_name is not None else ""

    valid_bucket_id = modal_validity.get("bucket_id")
    valid_bucket_name = modal_validity.get("bucket_name")

    # It's okay if valid bucket name is none. Means it's a new bucket.
    if validity == NEW:
        if valid_bucket_id != bucket_id:
            return save_state({'updated': False})

    if validity == EXISTS:
        if valid_bucket_id != bucket_id or valid_bucket_name != bucket_name:
            return save_state({'updated': False})

    # double single quotes so the value can be embedded in the SQL string literals below
    raw_string = data["raw string"][button_clicked].replace("'", "''")

    if sql.execute(
            f"SELECT count(*) FROM reference.organization_buckets_edits WHERE raw_string = '{raw_string}';"
    )['count'][0] > 0:
        sql.execute(
            f"UPDATE reference.organization_buckets_edits SET bucket = '{bucket_name}' WHERE raw_string = '{raw_string}';"
        )
        sql.execute(
            f"UPDATE reference.organization_buckets_edits SET bucket_id = '{bucket_id}' WHERE raw_string = '{raw_string}';"
        )

    else:
        sql.execute(
            f"INSERT INTO reference.organization_buckets_edits (raw_string, bucket, bucket_id, time) VALUES ('{raw_string}', '{bucket_name}', '{bucket_id}', GETDATE());"
        )

    print(f"Database updated: {raw_string} to {bucket_id} ({bucket_name})")
    return save_state({'updated': True})
def merge_button_press(*args):
    results = load_state(args[-2])
    data = results.get("data")
    previous_clicks = load_state(args[-1]).get("clicks")
    clicks = list(args[:-2])

    button_clicked = None

    for i, click in enumerate(clicks):
        if click is None:
            clicks[i] = 0

    # If there is no previous click state yet, treat all previous counts as zero.
    if not previous_clicks:
        previous_clicks = [0] * len(clicks)

    for i, click, previous_click in zip(range(len(clicks)), clicks,
                                        previous_clicks):
        # print(f"{click} vs {previous_click}")
        if click > previous_click:
            button_clicked = i
            break

    state = {
        "clicks": clicks,
        "button_clicked": button_clicked,
        "bucket_id": data["bucket id"][button_clicked],
    }

    return [merge.render_modal(button_clicked, data), save_state(state)]
def update_validity(bucket_id_submits, validate_clicks, is_open, bucket_id):
    bucket_name = None

    if bucket_id is None:
        validity = INVALID

    elif len(bucket_id) < 3:
        validity = INVALID

    elif not is_open:
        validity = INVALID
        bucket_id = None

    else:
        print("Checking if exists via database.")
        exists = sql.execute(
            f"SELECT bucket FROM staging.organization_buckets WHERE bucket_id = '{bucket_id}' LIMIT 1;"
        )

        if exists.shape[0] > 0:
            validity = EXISTS
            bucket_name = exists['bucket'][0]

        else:
            validity = NEW

    return save_state({
        "validity": validity,
        "bucket_name": bucket_name,
        "bucket_id": bucket_id
    })
def update_update_button_clicks(*args):

    clicks = args[:-1]
    previous_state = load_state(args[-1])

    button_clicked = -1

    # If previous state doesn't exist, then just look for a click:
    if not args[-1]:
        for i, click in enumerate(clicks):
            if click:
                button_clicked = i
                break

    # If previous state does exist, we need to find the click that increased.
    else:
        previous_clicks = previous_state.get("clicks")
        for i, n_click, last_n_click in zip(range(len(clicks)), clicks,
                                            previous_clicks):
            if n_click:
                if n_click > last_n_click:
                    button_clicked = i

    # Ensure we've got no "None" clicks to store:
    clicks = list(clicks)
    for i, click in enumerate(clicks):
        if not click:
            clicks[i] = 0

    # Save state:
    return save_state({"clicks": clicks, "button_clicked": button_clicked})
Example 8
def train(config):
    train_acc_per_epoch = OrderedDict()  # store train accuracy history
    print('number of trained parameters', sum(p.numel() for p in config['network'].parameters() if p.requires_grad))
    for epoch in range(1, config['epochs'] + 1):
        time1 = time()  # time per epoch
        config['network'].train()  # train mode
        acc = 0
        for idx, (X_batch, y_batch) in enumerate(config['data_train']):
            config['optimizer'].zero_grad()
            X_batch = X_batch.to(config['device']).view(X_batch.shape[0], -1)
            y_batch = y_batch.to(config['device'])
            y_pred = config['network'](X_batch).to(config['device'])
            loss = F.nll_loss(y_pred, y_batch)
            loss.backward()
            config['optimizer'].step()
            predicted = torch.argmax(y_pred.data, 1)
            acc += (predicted == y_batch).sum().item()
        if epoch % config['save_every'] == 0:
            train_acc_per_epoch[str(epoch)] = utils.save_state(config, epoch, acc)
        print(f"Epoch {epoch} - "
              f"Accuracy: {round(100 * acc / config['len_train'], 3):.3f}% - "
              f"Time(epoch): {timedelta(seconds=time() - time1)} - "
              f"Time(tot): {timedelta(seconds=time() - time0)}")

    return train_acc_per_epoch
def update_state_search(n_clicks, raw_string_n_submits, bucket_name_n_submits,
                        bucket_id_n_submits, database_updated, previous_state,
                        raw_string, bucket_name, bucket_id):
    # print(f"Called Update State of Search Bar {raw_string}")
    previous_state = load_state(previous_state)

    # Determine Whether to Run Search:
    # We require one of n_clicks/n_submits to be not None
    # The callback is called once before anything is populated.
    # In this case, we just ignore it. Otherwise, we want to run search.

    run_search = bool(n_clicks or raw_string_n_submits
                      or bucket_name_n_submits or bucket_id_n_submits)

    # If we run the search, do it!
    if run_search:
        raw_string = None if raw_string == '' else raw_string
        bucket_name = None if bucket_name == '' else bucket_name
        bucket_id = None if bucket_id == '' else bucket_id

        if raw_string or bucket_name or bucket_id:
            data = sql.execute(f"""
                SELECT 
                    raw_string as "Raw String"
                    , coalesce(edits.bucket, original.bucket) as "Bucket Name"
                    , coalesce(edits.bucket_id, original.bucket_id) as "Bucket ID"
                    , original.bucket_id as "Original Bucket ID"
                    , original.has_new_bucket as "New Bucket"
                FROM staging.organization_buckets original
                LEFT JOIN reference.organization_buckets_edits edits USING(raw_string)
                WHERE original.raw_string <> '' AND original.bucket <> '' {
                    f" AND (original.raw_string ~* '{raw_string}' ) " if raw_string else ''
                }{
                    f" AND (original.bucket ~* '{bucket_name}' ) " if bucket_name else ''
                }
                {
                    f" AND (original.bucket_id ~* '{bucket_id}' ) " if bucket_id else ''
                }
                ORDER BY CASE WHEN original.has_new_bucket THEN 1 ELSE 0 END DESC, 3,2,1
                LIMIT {results.N_ROWS}
            """)
        else:
            data = None
    else:
        data = None

    # Write out the current search state.
    new_state = {
        "raw_string": raw_string,
        "bucket_name": bucket_name,
        "bucket_id": bucket_id,
        "run_search": run_search,
        "data": data
    }

    return (save_state(new_state),
            results.generate_results_table(new_state.get("data")))
Example 10
def get_active_exp(env, threshold, ae, xm, xs, render, take_max = False, max_act_steps = 20):

    err_avg = 0
    for i in range(20):
        state = env.reset()
        state = robot_reset(env)
        error = ae.error((np.concatenate((state["observation"],
                                    state["achieved_goal"],
                                    state["desired_goal"])).reshape((1,-1)) - xm)/xs)
        err_avg += error
    err_avg /= 20

    state = env.reset()
    error = ae.error((np.concatenate((state["observation"],
                                    state["achieved_goal"],
                                    state["desired_goal"])).reshape((1,-1)) - xm)/xs)
    #print("predicted error", error)

    if not take_max:
        tried = 0
        while error <= threshold * err_avg:
            tried += 1
            state = env.reset()
            state = robot_reset(env)
            error = ae.error((np.concatenate((state["observation"],
                                            state["achieved_goal"],
                                            state["desired_goal"])).reshape((1,-1)) - xm)/xs)
            # print("predicted error", error.numpy(), err_avg.numpy())
        # print("Tried ", tried, " initial states")
        new_states, new_acts = man_controller.get_demo(env, state, CTRL_NORM, render)

        return new_states, new_acts

    else:
        errs_states = []
        for k in range(max_act_steps):
            state = env.reset()
            state = robot_reset(env)
            error = ae.error((np.concatenate((state["observation"],
                                            state["achieved_goal"],
                                            state["desired_goal"])).reshape((1,-1)) - xm)/xs)
            s, g = utils.save_state(env)
            errs_states.append([s, g, error])

        # pick the sampled reset state with the largest predicted error
        max_key = max(range(len(errs_states)), key=lambda el: errs_states[el][2])

        new_env = utils.set_state(env, errs_states[max_key][0], errs_states[max_key][1])
        state, *_ = new_env.step(np.zeros(4))

        new_states, new_acts = man_controller.get_demo(new_env, state, CTRL_NORM, render)

        return new_states, new_acts
Example 11
    def save(self, path=None):
        """Save model to a pickle located at `path`"""
        if path is None:
            path = os.path.join(logger.get_dir(), "model.pkl")

        with tempfile.TemporaryDirectory() as td:
            save_state(os.path.join(td, "model"))
            arc_name = os.path.join(td, "packed.zip")
            with zipfile.ZipFile(arc_name, 'w') as zipf:
                for root, dirs, files in os.walk(td):
                    for fname in files:
                        file_path = os.path.join(root, fname)
                        if file_path != arc_name:
                            zipf.write(file_path,
                                       os.path.relpath(file_path, td))
            with open(arc_name, "rb") as f:
                model_data = f.read()
        with open(path, "wb") as f:
            cloudpickle.dump((model_data, self._act_params), f)
Example 12
def train(train, model, criterion, optimizer, n_lettres, n_epochs, log_dir,
          checkpoint_path):
    losses = []
    writer = SummaryWriter(log_dir=log_dir)

    pbar = tqdm(range(n_epochs), total=n_epochs, file=sys.stdout)

    state = load_state(checkpoint_path, model, optimizer)

    for i in pbar:
        l = []
        for x, y in train:

            x = x.squeeze(-1).permute(1, 0, -1).to(device)
            seq_len, batch_size, embedding = x.shape

            y = y.view(seq_len * batch_size).to(device)

            o = state.model(x, state.model.initHidden(batch_size).to(device))
            d = state.model.decode(o).view(seq_len * batch_size, embedding)

            loss = criterion(d, y)
            loss.backward()

            state.optimizer.step()
            state.optimizer.zero_grad()

            l.append(loss.item())

            state.iteration += 1

        state.epoch += 1
        save_state(checkpoint_path, state)

        lo = np.mean(l)
        losses.append(lo)
        # \tTest: Loss: {np.round(test_lo, 4)}
        pbar.set_description(f'Train: Loss: {np.round(lo, 4)}')

        writer.add_scalar('Loss/train', lo, i)

    return losses
Example 13
    def train(self):
        self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
        while True:
            info = self.agent.step()
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if self.agent.rollout.stats['tcount'] == 0:
                fname = os.path.join(self.hps['save_dir'], 'checkpoints')
                if os.path.exists(fname+'.index'):
                    load_state(fname)
                    print('load successfully')
                else:
                    print('fail to load')
            # NOTE: int(self.num_timesteps / self.num_timesteps) is always 1, so this
            # condition holds on every step and a checkpoint is saved each iteration.
            if self.agent.rollout.stats['tcount'] % int(self.num_timesteps / self.num_timesteps) == 0:
                fname = os.path.join(self.hps['save_dir'], 'checkpoints')
                save_state(fname)
            if self.agent.rollout.stats['tcount'] > self.num_timesteps:
                break
            # print(self.agent.rollout.stats['tcount'])

        self.agent.stop_interaction()
def merge_bucket(n_clicks, new_bucket_id, merge_state):
    merge_state = load_state(merge_state)

    is_valid = merge.validate(new_bucket_id)
    status = False
    old_bucket_id = None
    if is_valid:
        old_bucket_id = merge_state.get("bucket_id")
        status = merge.merge(old_bucket_id, new_bucket_id)

    # print(status)
    return save_state({
        'status': status,
    })
Example 15
def run(dispatcher, server, iterations, filename):
    """
    Run the experiment using provided dispatcher and server
    """

    for iteration in range(iterations):

        # fetch the results of a single client run
        (grads, client_stamp, client) = dispatcher.fetch_update(iteration)

        # give those results to the parameter server
        params, server_stamp, unblock, v = server.apply_update(
            grads, client_stamp, client)

        # give the (potentially) new parameters back to that client
        # all other clients are by definition still working
        # TODO: should client wait for potential param update?
        dispatcher.update_parameters(params, client, server_stamp, unblock, v)

        # optionally run validation and report stats
        dispatcher.validate(params, server_stamp, unblock)

    save_state(dispatcher, filename)
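# Hedged sketch (not in the original source): minimal stand-in dispatcher/server
# objects showing the interface run() relies on above; the class names and dummy
# return values here are illustrative only.
class StubDispatcher:
    def fetch_update(self, iteration):
        # pretend client 0 reported an (empty) gradient update stamped with the iteration
        grads, client_stamp, client = {}, iteration, 0
        return grads, client_stamp, client

    def update_parameters(self, params, client, server_stamp, unblock, v):
        pass  # a real dispatcher would hand the new parameters back to that client

    def validate(self, params, server_stamp, unblock):
        pass  # a real dispatcher might run validation and report stats here


class StubServer:
    def apply_update(self, grads, client_stamp, client):
        # echo back (unchanged) parameters plus the bookkeeping values run() expects
        params, server_stamp, unblock, v = {}, client_stamp, True, None
        return params, server_stamp, unblock, v

# e.g. run(StubDispatcher(), StubServer(), iterations=3, filename="state.pkl")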
Example 16
def run(dispatcher, server, iterations, filename):
    """
    Run the experiment using provided dispatcher and server
    """

    for iteration in range(iterations):

        # fetch the results of a single client run
        (grads, client_stamp, client) = dispatcher.fetch_update(iteration)

        # give those results to the parameter server
        params, server_stamp, unblock, v = server.apply_update(grads,
                                                               client_stamp,
                                                               client)

        # give the (potentially) new parameters back to that client
        # all other clients are by definition still working
        # TODO: should client wait for potential param update?
        dispatcher.update_parameters(params, client, server_stamp, unblock, v)

        # optionally run validation and report stats
        dispatcher.validate(params, server_stamp, unblock)

    save_state(dispatcher, filename)
def train(config):
    print(
        'number of trained parameters',
        sum(p.numel() for p in config['network'].parameters()
            if p.requires_grad))
    load_this_model = ''
    for epoch in range(1, config['epochs'] + 1):
        dest_path = f'{config["save_dir"]}/weights/{epoch}.pth'
        full_dest_path = Path(str(Path().absolute()) + "/" + dest_path)
        if Path.exists(full_dest_path):
            load_this_model = full_dest_path
            continue
        if load_this_model != '':
            config['network'].load_state_dict(torch.load(load_this_model))
        time1 = time()  # time per epoch
        config['network'].train()  # train mode
        acc, loss = 0, 0
        for idx, (X_batch, y_batch) in enumerate(config['data_train']):
            config['optimizer'].zero_grad()
            X_batch = X_batch.to(config['device']).view(X_batch.shape[0], -1)
            y_batch = y_batch.to(config['device'])
            y_pred, _, _, params = config['network'](X_batch)
            loss = F.nll_loss(y_pred, y_batch)
            if config['lambda_l1'] != 0:
                reg_loss = F.l1_loss(params,
                                     target=torch.zeros_like(params),
                                     reduction='sum')
                loss += reg_loss * config['lambda_l1']
            loss.backward()
            config['optimizer'].step()
            predicted = torch.argmax(y_pred.data, 1)
            acc += (predicted == y_batch).sum().item()
        if epoch % config['save_every'] == 0:
            # train_acc_per_epoch[str(epoch)] = utils.save_state(config, dest_path, acc)
            acc = utils.save_state(config, dest_path, acc)
            utils.save_results(config, acc, None, epoch, loss.item())
        print(f"Epoch {epoch} - "
              f"Accuracy: {round(acc, 3):.3f}% - "
              f"Loss: {loss:.3f} - "
              f"Time(epoch): {timedelta(seconds=time() - time1)} - "
              f"Time(tot): {timedelta(seconds=time() - time0)}")
Example 18
def go(seed, file):
    if tf.__version__ != "2.0.0-alpha0":
        tf.random.set_random_seed(seed)
    else:
        tf.random.set_seed(seed)
    env = gym.make("FetchPickAndPlace-v1")
    env.seed(seed)
    test_set = []
    for i in range(TEST_EPS):
        state = env.reset()
        state, goal = utils.save_state(env)
        test_set.append((state, goal))

    states, actions = get_experience(INITIAL_TRAIN_EPS, env)
    print("Normal states, actions ", len(states), len(actions))

    net = model.BCModel(states[0].shape[0],
                        actions[0].shape[0],
                        BC_HD,
                        BC_HL,
                        BC_LR,
                        set_seed=seed)

    x = np.array(states)
    xm = x.mean()
    xs = x.std()
    x = (x - x.mean()) / x.std()

    a = np.array(actions)
    am = a.mean()
    ast = a.std()
    a = (a - a.mean()) / a.std()

    net.train(x, a, BC_BS, BC_EPS)

    result_t = test(net, test_set, env, xm, xs, am, ast, False)
    print("Normal learning results ", seed, " : ", result_t)
    file.write(
        str("Normal learning results " + str(seed) + " : " + str(result_t)))
def update_search_results(id_submits, name_submits, n_clicks, database_update,
                          bucket_id, bucket_name, previous_search):

    previous_search = load_state(previous_search)
    database_update = load_state(database_update)

    # n_triggers = id_submits if id_submits else 0
    #     + name_submits if name_submits else 0
    #     + n_clicks if n_clicks else 0

    # previous_triggers = previous_search.get("triggers", 0)

    # if n_triggers > previous_triggers or database_update.get("status"):
    search_results = results.search(bucket_id, bucket_name)

    state = save_state({
        'data': search_results,
        # 'n_triggers': n_triggers
    })

    layout = results.generate_result_table(search_results)

    return [state, layout]
Example 20
def learn(env,
          q_func,
          lr=5e-4,
          max_timesteps=100000,
          buffer_size=50000,
          exploration_fraction=0.1,
          exploration_final_eps=0.02,
          train_freq=1,
          batch_size=32,
          print_freq=100,
          checkpoint_freq=10000,
          learning_starts=1000,
          gamma=1.0,
          target_network_update_freq=500,
          prioritized_replay=False,
          prioritized_replay_alpha=0.6,
          prioritized_replay_beta0=0.4,
          prioritized_replay_beta_iters=None,
          prioritized_replay_eps=1e-6,
          param_noise=False,
          callback=None):
    """Train a deepq model.

    Parameters
    -------
    env: gym.Env
        environment to train on
    q_func: (tf.Variable, int, str, bool) -> tf.Variable
        the model that takes the following inputs:
            observation_in: object
                the output of observation placeholder
            num_actions: int
                number of actions
            scope: str
            reuse: bool
                should be passed to outer variable scope
        and returns a tensor of shape (batch_size, num_actions) with values of every action.
    lr: float
        learning rate for adam optimizer
    max_timesteps: int
        number of env steps to optimizer for
    buffer_size: int
        size of the replay buffer
    exploration_fraction: float
        fraction of entire training period over which the exploration rate is annealed
    exploration_final_eps: float
        final value of random action probability
    train_freq: int
        update the model every `train_freq` steps.
    batch_size: int
        size of a batched sampled from replay buffer for training
    print_freq: int
        how often to print out training progress
        set to None to disable printing
    checkpoint_freq: int
        how often to save the model. This is so that the best version is restored
        at the end of the training. If you do not wish to restore the best version at
        the end of the training set this variable to None.
    learning_starts: int
        how many steps of the model to collect transitions for before learning starts
    gamma: float
        discount factor
    target_network_update_freq: int
        update the target network every `target_network_update_freq` steps.
    prioritized_replay: bool
        if True prioritized replay buffer will be used.
    prioritized_replay_alpha: float
        alpha parameter for prioritized replay buffer
    prioritized_replay_beta0: float
        initial value of beta for prioritized replay buffer
    prioritized_replay_beta_iters: int
        number of iterations over which beta will be annealed from initial value
        to 1.0. If set to None equals to max_timesteps.
    prioritized_replay_eps: float
        epsilon to add to the TD errors when updating priorities.
    callback: (locals, globals) -> None
        function called at every steps with state of the algorithm.
        If callback returns true training stops.

    Returns
    -------
    act: ActWrapper
        Wrapper over act function. Adds ability to save it and load it.
        See header of baselines/deepq/categorical.py for details on the act function.
    """
    # Create all the functions necessary to train the model

    sess = tf.Session()
    sess.__enter__()

    # capture the shape outside the closure so that the env object is not serialized
    # by cloudpickle when serializing make_obs_ph
    observation_space_shape = env.observation_space.shape

    def make_obs_ph(name):
        return BatchInput(observation_space_shape, name=name)

    act, train, update_target, debug = deepq.build_train(
        make_obs_ph=make_obs_ph,
        q_func=q_func,
        num_actions=env.action_space.n,
        optimizer=tf.train.AdamOptimizer(learning_rate=lr),
        gamma=gamma,
        grad_norm_clipping=10,
        param_noise=param_noise)

    act_params = {
        'make_obs_ph': make_obs_ph,
        'q_func': q_func,
        'num_actions': env.action_space.n,
    }

    act = ActWrapper(act, act_params)

    # Create the replay buffer
    if prioritized_replay:
        replay_buffer = PrioritizedReplayBuffer(buffer_size,
                                                alpha=prioritized_replay_alpha)
        if prioritized_replay_beta_iters is None:
            prioritized_replay_beta_iters = max_timesteps
        beta_schedule = LinearSchedule(prioritized_replay_beta_iters,
                                       initial_p=prioritized_replay_beta0,
                                       final_p=1.0)
    else:
        replay_buffer = ReplayBuffer(buffer_size)
        beta_schedule = None
    # Create the schedule for exploration starting from 1.
    exploration = LinearSchedule(schedule_timesteps=int(exploration_fraction *
                                                        max_timesteps),
                                 initial_p=1.0,
                                 final_p=exploration_final_eps)

    # Initialize the parameters and copy them to the target network.
    U.initialize()
    update_target()

    episode_rewards = [0.0]
    saved_mean_reward = None
    obs = env.reset()
    reset = True
    with tempfile.TemporaryDirectory() as td:
        model_saved = False
        model_file = os.path.join(td, "model")
        for t in range(max_timesteps):
            if callback is not None:
                if callback(locals(), globals()):
                    break
            # Take action and update exploration to the newest value
            kwargs = {}
            if not param_noise:
                update_eps = exploration.value(t)
                update_param_noise_threshold = 0.
            else:
                update_eps = 0.
                # Compute the threshold such that the KL divergence between perturbed and non-perturbed
                # policy is comparable to eps-greedy exploration with eps = exploration.value(t).
                # See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
                # for detailed explanation.
                update_param_noise_threshold = -np.log(1. - exploration.value(
                    t) + exploration.value(t) / float(env.action_space.n))
                kwargs['reset'] = reset
                kwargs[
                    'update_param_noise_threshold'] = update_param_noise_threshold
                kwargs['update_param_noise_scale'] = True
            action = act(np.array(obs)[None], update_eps=update_eps,
                         **kwargs)[0]
            env_action = action
            reset = False
            new_obs, rew, done, _ = env.step(env_action)
            # Store transition in the replay buffer.
            replay_buffer.add(obs, action, rew, new_obs, float(done))
            obs = new_obs

            episode_rewards[-1] += rew
            if done:
                obs = env.reset()
                episode_rewards.append(0.0)
                reset = True

            if t > learning_starts and t % train_freq == 0:
                # Minimize the error in Bellman's equation on a batch sampled from replay buffer.
                if prioritized_replay:
                    experience = replay_buffer.sample(
                        batch_size, beta=beta_schedule.value(t))
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_idxes) = experience
                else:
                    obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(
                        batch_size)
                    weights, batch_idxes = np.ones_like(rewards), None
                td_errors = train(obses_t, actions, rewards, obses_tp1, dones,
                                  weights)
                if prioritized_replay:
                    new_priorities = np.abs(td_errors) + prioritized_replay_eps
                    replay_buffer.update_priorities(batch_idxes,
                                                    new_priorities)

            if t > learning_starts and t % target_network_update_freq == 0:
                # Update target network periodically.
                update_target()

            mean_100ep_reward = round(np.mean(episode_rewards[-101:-1]), 1)
            num_episodes = len(episode_rewards)
            if done and print_freq is not None and len(
                    episode_rewards) % print_freq == 0:
                logger.record_tabular("steps", t)
                logger.record_tabular("episodes", num_episodes)
                logger.record_tabular("mean 100 episode reward",
                                      mean_100ep_reward)
                logger.record_tabular("% time spent exploring",
                                      int(100 * exploration.value(t)))
                logger.dump_tabular()

            if (checkpoint_freq is not None and t > learning_starts
                    and num_episodes > 100 and t % checkpoint_freq == 0):
                if saved_mean_reward is None or mean_100ep_reward > saved_mean_reward:
                    if print_freq is not None:
                        logger.log(
                            "Saving model due to mean reward increase: {} -> {}"
                            .format(saved_mean_reward, mean_100ep_reward))
                    save_state(model_file)
                    model_saved = True
                    saved_mean_reward = mean_100ep_reward
        if model_saved:
            if print_freq is not None:
                logger.log("Restored model with mean reward: {}".format(
                    saved_mean_reward))
            load_state(model_file)

    return act
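# Hedged usage sketch (not part of the original source): how learn() is typically
# driven, assuming a baselines-style deepq.models module and a Gym environment.
# The environment id, the models.mlp([64]) network, and the hyperparameter values
# below are illustrative choices, not taken from this file.
def _example_learn_usage():
    import gym
    from baselines.deepq import models

    env = gym.make("CartPole-v0")
    q_func = models.mlp([64])          # small fully-connected Q-network
    act = learn(env, q_func,
                lr=1e-3,
                max_timesteps=100000,
                exploration_fraction=0.1,
                exploration_final_eps=0.02,
                print_freq=10)
    act.save("cartpole_model.pkl")     # ActWrapper provides save(), shown in an earlier example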
Example 21
def go(seed):
    tf.random.set_seed(seed)
    env = gym.make("FetchPickAndPlace-v1")
    env.seed(seed)

    test_set = []
    for i in range(TEST_EPS):
        state = env.reset()
        state, goal = utils.save_state(env)
        test_set.append((state, goal))

    states, actions = get_experience(INITIAL_TRAIN_EPS, env)
    print("Normal states, actions ", len(states), len(actions))
    file.write("Normal states, actions " + str(len(states)) + str(len(actions)))
    net = model.BCModelDropout(states[0].shape[0], actions[0].shape[0], BC_HD, BC_HL, BC_LR)

    x = np.array(states)
    xm = x.mean()
    xs = x.std()
    x = (x - x.mean())/x.std()

    a = np.array(actions)
    am = a.mean()
    ast = a.std()
    a = (a - a.mean())/a.std()

    net.train(x, a, BC_BS, BC_EPS)
    obs_dim, act_dim = len(x[0]), len(a[0])
    norm = Normalizer(obs_dim, act_dim).fit(x[:-1], a[:-1], x[1:])

    dyn = NNDynamicsModel(obs_dim, act_dim, 128, norm, 64, 500, 3e-4)
    dyn.fit({"states": x[:-1], "acts" : a[:-1], "next_states" : x[1:]}, plot = False)

    ae_x = AE(31, AE_HD, AE_HL, AE_LR)
    #ae = RandomNetwork(1, AE_HD, AE_HL, AE_LR)

    ae_x.train(x, AE_BS, AE_EPS)
    ae = FutureUnc(net, dyn, ae_x, steps = 3)

    tot_error_trainset = 0
    for el in x:
        error = ae.error(el.reshape((1,-1)))
        tot_error_trainset+=error
    tot_error_trainset/=len(x)

    print("Average error on train set", tot_error_trainset)
    file.write(str("Average error on train set") +str(tot_error_trainset))

    # Estimate the error on full training trajectories by multiplying the
    # per-step average by len(dataset)/num_episodes, i.e. the average
    # episode length.

    if FULL_TRAJ:
        tot_error_train_fulltraj = 0
        avg_ep_length = len(x)/INITIAL_TRAIN_EPS
        tot_error_train_fulltraj = tot_error_trainset*avg_ep_length
    #    tot_error_trainset = tot_error_train_fulltraj

        print("Average full trajectory error on train set", tot_error_train_fulltraj)
        file.write(str("Average full trajectory error on train set") + str(tot_error_train_fulltraj))
    succ, fail, error_avg_s, error_avg_f, succ_list, fail_list = test(
        net, ae, test_set, env, xm, xs, am, ast, fulltraj=FULL_TRAJ, render=RENDER_TEST)
    print("Active learning results ", seed, " : ", succ, fail,
          "avg error on succ trials: ", error_avg_s, "on fail: ", error_avg_f,
          "std on succ:", np.std(succ_list), "on fail:", np.std(fail_list))
    file.write("Active learning results " + str(seed) + " : " + str(succ) + str(fail)
               + str(error_avg_s) + str(error_avg_f) + str(np.std(succ_list)) + str(np.std(fail_list)))

  #  file.write(str("Active learning results " + str(seed) + " : " + str(result_t)))



    succ_tp, succ_fp, succ_tn, succ_fn = predict(net, ae, test_set, env, xm, xs, am, ast, tot_error_trainset)

    #change to consider failures
    succ_tp, succ_fp, succ_tn, succ_fn = succ_tn, succ_fn, succ_tp, succ_fp
    fail, succ = succ, fail

    print("succ tp, fp, tn, fn", succ_tp, succ_fp, succ_tn, succ_fn)
    precision = (succ_tp/(succ_tp+succ_fp + 0.001))
    recall = (succ_tp/(succ_tp+succ_fn))
    print("F1 score", (2*precision*recall/(precision + recall)))
    print("F1 for all positives",  (2*(succ/(succ + fail))*1/((succ/(succ + fail)) + 1)))
    return (2*precision*recall/(precision + recall)), 2*(succ/(succ + fail))*1/((succ/(succ + fail)) + 1)
def main():

    ## config
    global args
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)  # an explicit Loader avoids PyYAML's unsafe-load warning

    for k,v in config.items():    
        if isinstance(v, dict):
            argobj = ArgObj()
            setattr(args, k, argobj)
            for kk,vv in v.items():
                setattr(argobj, kk, vv)
        else:
            setattr(args, k, v)
    args.ngpu = len(args.gpus.split(','))

    ## asserts
    assert args.model.backbone in model_names, "available backbone names: {}".format(model_names)
    num_tasks = len(args.train.data_root)
    assert(num_tasks == len(args.train.loss_weight))
    assert(num_tasks == len(args.train.batch_size))
    assert(num_tasks == len(args.train.data_list))
    #assert(num_tasks == len(args.train.data_meta))
    if args.val.flag:
        assert(num_tasks == len(args.val.batch_size))
        assert(num_tasks == len(args.val.data_root))
        assert(num_tasks == len(args.val.data_list))
        #assert(num_tasks == len(args.val.data_meta))

    ## mkdir
    if not hasattr(args, 'save_path'):
        args.save_path = os.path.dirname(args.config)
    if not os.path.isdir('{}/checkpoints'.format(args.save_path)):
        os.makedirs('{}/checkpoints'.format(args.save_path))
    if not os.path.isdir('{}/logs'.format(args.save_path)):
        os.makedirs('{}/logs'.format(args.save_path))
    if not os.path.isdir('{}/events'.format(args.save_path)):
        os.makedirs('{}/events'.format(args.save_path))

    ## create dataset
    if not (args.extract or args.evaluate): # train + val
        for i in range(num_tasks):
            args.train.batch_size[i] *= args.ngpu

        #train_dataset = [FaceDataset(args, idx, 'train') for idx in range(num_tasks)]
        train_dataset = [FileListLabeledDataset(
            args.train.data_list[i], args.train.data_root[i],
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.Resize(args.model.input_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),]),
            memcached=args.memcached,
            memcached_client=args.memcached_client) for i in range(num_tasks)]
        args.num_classes = [td.num_class for td in train_dataset]
        train_longest_size = max([int(np.ceil(len(td) / float(bs))) for td, bs in zip(train_dataset, args.train.batch_size)])
        train_sampler = [GivenSizeSampler(td, total_size=train_longest_size * bs, rand_seed=args.train.rand_seed) for td, bs in zip(train_dataset, args.train.batch_size)]
        train_loader = [DataLoader(
            train_dataset[k], batch_size=args.train.batch_size[k], shuffle=False,
            num_workers=args.workers, pin_memory=False, sampler=train_sampler[k]) for k in range(num_tasks)]
        assert(all([len(train_loader[k]) == len(train_loader[0]) for k in range(num_tasks)]))

        if args.val.flag:
            for i in range(num_tasks):
                args.val.batch_size[i] *= args.ngpu
    
            #val_dataset = [FaceDataset(args, idx, 'val') for idx in range(num_tasks)]
            val_dataset = [FileListLabeledDataset(
                args.val.data_list[i], args.val.data_root[i],
                transforms.Compose([
                    transforms.Resize(args.model.input_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),]),
                memcached=args.memcached,
                memcached_client=args.memcached_client) for i in range(num_tasks)]
            
            val_longest_size = max([int(np.ceil(len(vd) / float(bs))) for vd, bs in zip(val_dataset, args.val.batch_size)])
            val_sampler = [GivenSizeSampler(vd, total_size=val_longest_size * bs, sequential=True) for vd, bs in zip(val_dataset, args.val.batch_size)]
            val_loader = [DataLoader(
                val_dataset[k], batch_size=args.val.batch_size[k], shuffle=False,
                num_workers=args.workers, pin_memory=False, sampler=val_sampler[k]) for k in range(num_tasks)]
            assert(all([len(val_loader[k]) == len(val_loader[0]) for k in range(num_tasks)]))

    if args.test.flag or args.evaluate: # online or offline evaluate
        args.test.batch_size *= args.ngpu
        test_dataset = []
        for tb in args.test.benchmark:
            if tb == 'megaface':
                test_dataset.append(FileListDataset(args.test.megaface_list,
                    args.test.megaface_root, transforms.Compose([
                    transforms.Resize(args.model.input_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),])))
            else:
                test_dataset.append(BinDataset("{}/{}.bin".format(args.test.test_root, tb),
                    transforms.Compose([
                    transforms.Resize(args.model.input_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                    ])))
        test_sampler = [GivenSizeSampler(td,
            total_size=int(np.ceil(len(td) / float(args.test.batch_size)) * args.test.batch_size),
            sequential=True, silent=True) for td in test_dataset]
        test_loader = [DataLoader(
            td, batch_size=args.test.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=False, sampler=ts)
            for td, ts in zip(test_dataset, test_sampler)]

    if args.extract: # feature extraction
        args.extract_info.batch_size *= args.ngpu
#        extract_dataset = FaceDataset(args, 0, 'extract')
        extract_dataset = FileListDataset(
            args.extract_info.data_list, args.extract_info.data_root,
            transforms.Compose([
                transforms.Resize(args.model.input_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),]),
            memcached=args.memcached,
            memcached_client=args.memcached_client)
        extract_sampler = GivenSizeSampler(
            extract_dataset, total_size=int(np.ceil(len(extract_dataset) / float(args.extract_info.batch_size)) * args.extract_info.batch_size), sequential=True)
        extract_loader = DataLoader(
            extract_dataset, batch_size=args.extract_info.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=False, sampler=extract_sampler)


    ## create model
    log("Creating model on [{}] gpus: {}".format(args.ngpu, args.gpus))
    if args.evaluate or args.extract:
        args.num_classes = None
    model = models.MultiTaskWithLoss(backbone=args.model.backbone, num_classes=args.num_classes, feature_dim=args.model.feature_dim, spatial_size=args.model.input_size, arc_fc=args.model.arc_fc, feat_bn=args.model.feat_bn)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    model = nn.DataParallel(model)
    model.cuda()
    cudnn.benchmark = True

    ## criterion and optimizer
    optimizer = torch.optim.SGD(model.parameters(), args.train.base_lr,
                                momentum=args.train.momentum,
                                weight_decay=args.train.weight_decay)

    ## resume / load model
    start_epoch = 0
    count = [0]
    if args.load_path:
        assert os.path.isfile(args.load_path), "File not exist: {}".format(args.load_path)
        if args.resume:
            checkpoint = load_state(args.load_path, model, optimizer)
            start_epoch = checkpoint['epoch']
            count[0] = checkpoint['count']
        else:
            load_state(args.load_path, model)

    ## offline evaluate
    if args.evaluate:
        for tb, tl, td in zip(args.test.benchmark, test_loader, test_dataset):
            evaluation(tl, model, num=len(td),
                       outfeat_fn="{}_{}.bin".format(args.load_path[:-8], tb),
                       benchmark=tb)
        return

    ## feature extraction
    if args.extract:
        extract(extract_loader, model, num=len(extract_dataset), output_file="{}_{}.bin".format(args.load_path[:-8], args.extract_info.data_name))
        return

    ######################## train #################
    ## lr scheduler
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.train.lr_decay_steps, gamma=args.train.lr_decay_scale, last_epoch=start_epoch-1)

    ## logger
    logging.basicConfig(filename=os.path.join('{}/logs'.format(args.save_path), 'log-{}-{:02d}-{:02d}_{:02d}:{:02d}:{:02d}.txt'.format(
        datetime.today().year, datetime.today().month, datetime.today().day,
        datetime.today().hour, datetime.today().minute, datetime.today().second)),
        level=logging.INFO)
    tb_logger = SummaryWriter('{}/events'.format(args.save_path))

    ## initial validate
    if args.val.flag:
        validate(val_loader, model, start_epoch, args.train.loss_weight, len(train_loader[0]), tb_logger)

    ## initial evaluate
    if args.test.flag and args.test.initial_test:
        log("*************** evaluation epoch [{}] ***************".format(start_epoch))
        for tb, tl, td in zip(args.test.benchmark, test_loader, test_dataset):
            res = evaluation(tl, model, num=len(td),
                             outfeat_fn="{}/checkpoints/ckpt_epoch_{}_{}.bin".format(
                             args.save_path, start_epoch, tb),
                             benchmark=tb)
            tb_logger.add_scalar(tb, res, start_epoch)

    ## training loop
    for epoch in range(start_epoch, args.train.max_epoch):
        lr_scheduler.step()
        for ts in train_sampler:
            ts.set_epoch(epoch)
        # train for one epoch
        train(train_loader, model, optimizer, epoch, args.train.loss_weight, tb_logger, count)
        # save checkpoint
        save_state({
            'epoch': epoch + 1,
            'arch': args.model.backbone,
            'state_dict': model.state_dict(),
            'optimizer' : optimizer.state_dict(),
            'count': count[0]
        }, args.save_path + "/checkpoints/ckpt_epoch", epoch + 1, is_last=(epoch + 1 == args.train.max_epoch))

        # validate
        if args.val.flag:
            validate(val_loader, model, epoch, args.train.loss_weight, len(train_loader[0]), tb_logger, count)
        # online evaluate
        if args.test.flag and ((epoch + 1) % args.test.interval == 0 or epoch + 1 == args.train.max_epoch):
            log("*************** evaluation epoch [{}] ***************".format(epoch + 1))
            for tb, tl, td in zip(args.test.benchmark, test_loader, test_dataset):
                res = evaluation(tl, model, num=len(td),
                                 outfeat_fn="{}/checkpoints/ckpt_epoch_{}_{}.bin".format(
                                 args.save_path, epoch + 1, tb),
                                 benchmark=tb)
                tb_logger.add_scalar(tb, res, epoch + 1)
Example 23
criterion = nn.CrossEntropyLoss()
optimizer = SGD(net.parameters(),
                lr=args.lr,
                momentum=args.mom,
                weight_decay=args.wd)

## Training
for label_batch_id in range(class_batch_num):
    subtrainset = tf.get_subset(class_batch_list[label_batch_id, :], trainset)
    trainloader = DataLoader(subtrainset,
                             batch_size=args.bs,
                             drop_last=True,
                             num_workers=4)
    print("training starts on label batch:{}".format(label_batch_id))
    os.makedirs(
        os.path.join(model_dir, 'checkpoints',
                     'labelbatch{}'.format(label_batch_id)))
    for epoch in range(args.epo):
        lr_schedule(epoch, optimizer)
        for step, (batch_imgs, batch_lbls) in enumerate(trainloader):
            features = net(batch_imgs.cuda())
            loss = criterion(features, batch_lbls.cuda())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            utils.save_state(model_dir, label_batch_id, epoch, step,
                             loss.item())
        utils.save_ckpt(model_dir, net, epoch, label_batch_id)
print("training complete.")
Example 24
def loop(hvac_state, statefile):
    info("Subscribing to topics")
    for val in Topics().calls.values():
        hvac_state.client.subscribe(val)

    info("Preparing main loop")
    previous_state = load_state(statefile)
    hvac_state.aux = previous_state['aux']
    hvac_state.toggle = previous_state['toggle']
    hvac_state.mode = previous_state['mode']
    last_fetch_time = hvac_state.fetch_hvac_state()
    previous_state = {
        'aux': hvac_state.aux,
        'action': hvac_state.action,
        'hvac_code': hvac_state.hvac_code,
        'mode': hvac_state.mode,
        'status': hvac_state.status,
        'toggle': hvac_state.toggle,
    }
    debug(previous_state)
    hvac_state.publish('aux', previous_state['aux'])
    hvac_state.publish('mode', previous_state['mode'])
    hvac_state.publish('toggle', previous_state['toggle'])
    hvac_state.publish('action', previous_state['action'])
    hvac_state.publish('status', previous_state['status'])

    info("Starting main loop")
    while True:
        debug("running mqtt loop")
        hvac_state.client.loop(1)

        # Publish changes
        if previous_state['hvac_code'] == hvac_state.hvac_code:
            if previous_state['aux'] != hvac_state.aux:
                aux = hvac_state.aux
                hvac_state.publish('aux', aux)
                previous_state['aux'] = aux
            if previous_state['mode'] != hvac_state.mode:
                mode = hvac_state.mode
                hvac_state.publish('mode', mode)
                previous_state['mode'] = mode
            if previous_state['toggle'] != hvac_state.toggle:
                toggle = hvac_state.toggle
                hvac_state.publish('toggle', toggle)
                previous_state['toggle'] = toggle
            if previous_state['action'] != hvac_state.action:
                action = hvac_state.action
                hvac_state.publish('action', action)
                previous_state['action'] = action
            if previous_state['status'] != hvac_state.status:
                status = hvac_state.status
                hvac_state.publish('status', status)
                previous_state['status'] = status
            # refresh if needed
            if datetime.now() - last_fetch_time >= timedelta(seconds=2):
                last_fetch_time = hvac_state.fetch_hvac_state()

        else:
            # send new state to hvac
            hc = hvac_state.hvac_code
            info("Sending new state to hvac: %s", hc)
            hvac_state.change_hvac_state(hc)
            last_fetch_time = hvac_state.fetch_hvac_state()
            previous_state['hvac_code'] = hvac_state.hvac_code
            save_state(statefile, previous_state)
Example 25
def mine(checkpoint_num, matches, discovered_summonerIds, discovered_matchIds, g, max_hop, bfs_queue, hop):
    """
    Mine using Breadth First Search algorithm.

    Returns:
        None
    """
    loop_count = 0  # used to track save points

    while (len(bfs_queue)>0) and (hop <= max_hop):

        loop_count += 1

        # Dequeue next summonerId
        next_item_in_queue = bfs_queue.popleft()
        summonerId = next_item_in_queue['summonerId']
        hop = next_item_in_queue['hop']

        # Get match ids for current node
        matchIds = rg.get_matchIds_by_summoner(summonerId)
        num_matches = len(matchIds)

        print ""
        print "SummonerId %d at hop %d has %d matches." % (summonerId, hop, num_matches)

        # Loop through all matches and add all members to the network
        for i, matchId in enumerate(matchIds):
            print "\rMatch %d out of %d [%0.1f%%]" % (i+1, num_matches, (i+1)/float(num_matches)*100),

            if matchId in discovered_matchIds:  # skip loop if match info already retrieved
                continue

            # Get full match data and extract team members
            try:
                match = rg.get_match(matchId)
            except Exception:
                print("")
                print("Error, skipping matchId: %d" % matchId)
                continue

            team_memberIds = rg.get_summonerIds_by_match(match=match, team_with=summonerId)

            # Add team members to graph as a clique, with matchId info on edges
            g = algos.add_clique_with_weights(g, team_memberIds, edge_attr={'matchId': matchId})

            # Add new memberIds to queue
            for memberId in team_memberIds:
                if memberId not in discovered_summonerIds:
                    bfs_queue.append({'summonerId': memberId, 'hop': hop+1})
                    discovered_summonerIds[memberId] = True

            # Add match data to matches dict
            match = utils.list_of_dict_to_dict([match], 'matchId')
            matches.update(match)
            discovered_matchIds[matchId] = True

            # Sleep to stay under the API data rate limit
            time.sleep(TIME_SLEEP)

        if loop_count % CHECKPOINT_INTERVAL == 0:
            # Save data every CHECKPOINT_INTERVAL number of summonerIds
            checkpoint_num += 1
            utils.save_state(checkpoint_num, matches, discovered_summonerIds, discovered_matchIds, g, max_hop, bfs_queue, hop)
            return ""
Example 26
        def callback(locals, globals):
            if that.method != "ddpg":
                if load_policy is not None and locals[iter_name] == 0:
                    # noinspection PyBroadException
                    try:
                        utils.load_state(load_policy)
                        if MPI.COMM_WORLD.Get_rank() == 0:
                            logger.info("Loaded policy network weights from %s." % load_policy)
                            # save TensorFlow summary (contains at least the graph definition)
                    except:
                        logger.error("Failed to load policy network weights from %s." % load_policy)
                if MPI.COMM_WORLD.Get_rank() == 0 and locals[iter_name] == 0:
                    _ = tf.summary.FileWriter(folder, tf.get_default_graph())
            if MPI.COMM_WORLD.Get_rank() == 0 and locals[iter_name] % save_every == 0:
                print('Saving video and checkpoint for policy at iteration %i...' %
                      locals[iter_name])
                ob = env.reset()
                images = []
                rewards = []
                max_reward = 1.  # if any reward > 1, we have to rescale
                lower_part = video_height // 5
                for i in range(episode_length):
                    if that.method == "ddpg":
                        ac, _ = locals['agent'].pi(ob, apply_noise=False, compute_Q=False)
                    elif that.method == "sql":
                        ac, _ = locals['policy'].get_action(ob)
                    elif isinstance(locals['pi'], GaussianMlpPolicy):
                        ac, _, _ = locals['pi'].act(np.concatenate((ob, ob)))
                    else:
                        ac, _ = locals['pi'].act(False, ob)
                    ob, rew, new, _ = env.step(ac)
                    images.append(render_frames(env))
                    if plot_rewards:
                        rewards.append(rew)
                        max_reward = max(rew, max_reward)
                    if new:
                        break

                orange = np.array([255, 163, 0])
                red = np.array([255, 0, 0])
                video = []
                width_factor = 1. / episode_length * video_width
                for i, imgs in enumerate(images):
                    for img in imgs:
                        img[-lower_part, :10] = orange
                        img[-lower_part, -10:] = orange
                        if episode_length < video_width:
                            p_rew_x = 0
                            for j, r in enumerate(rewards[:i]):
                                rew_x = int(j * width_factor)
                                if r < 0:
                                    img[-1:, p_rew_x:rew_x] = red
                                    img[-1:, p_rew_x:rew_x] = red
                                else:
                                    rew_y = int(r / max_reward * lower_part)
                                    img[-rew_y - 1:, p_rew_x:rew_x] = orange
                                p_rew_x = rew_x
                        else:
                            for j, r in enumerate(rewards[:i]):
                                rew_x = int(j * width_factor)
                                if r < 0:
                                    img[-1:, rew_x] = red
                                else:
                                    rew_y = int(r / max_reward * lower_part)
                                    img[-rew_y - 1:, rew_x] = orange
                    video.append(np.hstack(imgs))

                imageio.mimsave(
                    os.path.join(folder, "videos", "%s_%s_iteration_%i.mp4" %
                                 (that.environment, that.method, locals[iter_name])),
                    video,
                    fps=60)
                env.reset()

                if that.method != "ddpg":
                    utils.save_state(os.path.join(that.folder, "checkpoints", "%s_%i" %
                                                 (that.environment, locals[iter_name])))
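The `utils.save_state` / `utils.load_state` pair used above is not shown. A plausible sketch, assuming they wrap a TensorFlow 1.x tf.train.Saver in the same way OpenAI Baselines' tf_util helpers do (directory handling and default-session usage are assumptions):

import os
import tensorflow as tf

def save_state(fname, sess=None):
    # Sketch: save all graph variables to a checkpoint path, creating directories as needed.
    sess = sess or tf.get_default_session()
    dirname = os.path.dirname(fname)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    tf.train.Saver().save(sess, fname)

def load_state(fname, sess=None):
    # Sketch: restore all graph variables from a previously saved checkpoint.
    sess = sess or tf.get_default_session()
    tf.train.Saver().restore(sess, fname)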
Example n. 27
0
                            sampler=args.sampler,
                            batch_size=args.bs,
                            num_aug=args.aug)

criterion = MaximalCodingRateReduction(gam1=args.gam1,
                                       gam2=args.gam2,
                                       eps=args.eps)
optimizer = optim.SGD(net.parameters(),
                      lr=args.lr,
                      momentum=args.mom,
                      weight_decay=args.wd)
scheduler = lr_scheduler.MultiStepLR(optimizer, [30, 60], gamma=0.1)
utils.save_params(model_dir, vars(args))

## Training
for epoch in range(args.epo):
    for step, (batch_imgs, _, batch_idx) in enumerate(trainloader):
        batch_features = net(batch_imgs.cuda())
        loss, loss_empi, loss_theo = criterion(batch_features, batch_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        utils.save_state(model_dir, epoch, step, loss.item(), *loss_empi,
                         *loss_theo)
        if step % 20 == 0:
            utils.save_ckpt(model_dir, net, epoch)
    scheduler.step()
    utils.save_ckpt(model_dir, net, epoch)
print("training complete.")
Example n. 28
0
            tqdm.tqdm.write("global_step: {}\tloss: {}".format(
                global_step, step_loss["loss"]))
            for k, v in step_loss.items():
                writer.add_scalar(k, v, global_step)

            # train state saving, the first sentence in the batch
            if global_step % snap_interval == 0:
                save_state(
                    state_dir,
                    writer,
                    global_step,
                    mel_input=downsampled_mel_specs,
                    mel_output=mel_outputs,
                    lin_input=lin_specs,
                    lin_output=linear_outputs,
                    alignments=alignments,
                    win_length=win_length,
                    hop_length=hop_length,
                    min_level_db=min_level_db,
                    ref_level_db=ref_level_db,
                    power=power,
                    n_iter=n_iter,
                    preemphasis=preemphasis,
                    sample_rate=sample_rate)

            # evaluation
            if global_step % eval_interval == 0:
                sentences = [
                    "Scientists at the CERN laboratory say they have discovered a new particle.",
                    "There's a way to measure the acute emotional intelligence that has never gone out of style.",
                    "President Trump met with other leaders at the Group of 20 conference.",
Example n. 29
0
def go(seed, file):

    tf.random.set_seed(seed)
    env = gym.make("FetchPickAndPlace-v1")
    env.seed(seed)
    np.random.seed(seed)
    test_set = []
    for i in range(TEST_EPS):
        state = env.reset()
        state = robot_reset(env)
        state, goal = utils.save_state(env)
        test_set.append((state, goal))

    states, actions = get_experience(INITIAL_TRAIN_EPS, env, RENDER_TEST)
    print("Normal states, actions ", len(states), len(actions))

    net = model.BCModel(states[0].shape[0], actions[0].shape[0], BC_HD, BC_HL, BC_LR, set_seed = seed)
    x = np.array(states)
    xm = x.mean()
    xs = x.std()
    x = (x - x.mean())/x.std()

    a = np.array(actions)
    am = a.mean()
    ast = a.std()
    a = (a - a.mean())/a.std()

    start = time.time()
    print("TEST")
    print(x.shape)
    net.train(x, a, BC_BS, BC_EPS)
    print("Training took:")
    print(time.time() - start)
    result_t = test(net, test_set, env, xm, xs, am, ast, RENDER_TEST)
    print("Normal learning results ", seed, " : ", result_t)
    file.write(str("Normal learning results " + str(seed) + " : " + str(result_t)))

    ## Active Learning Part ###
    tf.random.set_seed(seed)
    env = gym.make("FetchPickAndPlace-v1")
    env.seed(seed)
    np.random.seed(seed)

    states, actions = states[:math.floor(len(states)*ORG_TRAIN_SPLIT)], actions[:math.floor(len(actions)*ORG_TRAIN_SPLIT)]
    #Train behavior net on half data.
    net_hf = model.BCModel(states[0].shape[0], actions[0].shape[0], BC_HD, BC_HL, BC_LR, set_seed = seed)
    x = np.array(states)
    xm = x.mean()
    xs = x.std()
    x = (x - x.mean())/x.std()

    a = np.array(actions)
    am = a.mean()
    ast = a.std()
    a = (a - a.mean())/a.std()
    net_hf.train(x, a, BC_BS, BC_EPS*2)

    #get_experience(int(INITIAL_TRAIN_EPS*ORG_TRAIN_SPLIT), env)
    act_l_loops = math.ceil(((1.-ORG_TRAIN_SPLIT)*INITIAL_TRAIN_EPS)//ACTIVE_STEPS_RETRAIN)
    if act_l_loops == 0: act_l_loops+=1
    for i in range(act_l_loops):

        x = np.array(states)
        xm = x.mean()
        xs = x.std()
        x = (x - x.mean())/x.std()

        a = np.array(actions)
        am = a.mean()
        ast = a.std()
        a = (a - a.mean())/a.std()

        #if AE_RESTART: ae = DAE(31, AE_HD, AE_HL, AE_LR, set_seed = seed)
        #Reinitialize both everytime and retrain.
        norm = Normalizer(31, 4).fit(x[:-1], a[:-1], x[1:])

        dyn = NNDynamicsModel(31, 4, 128, norm, 64, 500, 3e-4)
        dyn.fit({"states": x[:-1], "acts" : a[:-1], "next_states" : x[1:]}, plot = False)

        ae_x = AE(31, AE_HD, AE_HL, AE_LR)
        #ae = RandomNetwork(1, AE_HD, AE_HL, AE_LR)

        ae_x.train(x, AE_BS, AE_EPS)
        ae = FutureUnc(net, dyn, ae_x, steps = 5)
        net_hf = model.BCModel(states[0].shape[0], actions[0].shape[0], BC_HD, BC_HL, BC_LR, set_seed = seed)

        start = time.time()

        net_hf.train(x, a, BC_BS, BC_EPS*2)
        avg_error = avg_ae_error(ae, x)

        print("Training took:")
        print(time.time() - start)

        for j in range(ACTIVE_STEPS_RETRAIN):
            #new_s, new_a = get_active_exp(env, ACTIVE_ERROR_THR, ae, xm, xs, RENDER_ACT_EXP, TAKE_MAX, MAX_ACT_STEPS)
            new_s, new_a = get_active_exp2(env, avg_error, net_hf, ae, xm, xs, am, ast, RENDER_TEST, TAKE_MAX, MAX_ACT_STEPS)
            #print("len new s ", len(new_s), " len new a ", len(new_a))
            states+=new_s
            actions+=new_a

    x = np.array(states)
    xm = x.mean()
    xs = x.std()
    x = (x - x.mean())/x.std()

    a = np.array(actions)
    am = a.mean()
    ast = a.std()
    a = (a - a.mean())/a.std()

    print("Active states, actions ", len(states), len(actions))

    net = model.BCModel(states[0].shape[0], actions[0].shape[0], BC_HD, BC_HL, BC_LR, set_seed = seed)
    net.train(x, a, BC_BS, BC_EPS)

    result_t = test(net, test_set, env, xm, xs, am, ast, RENDER_TEST)
    print("Active learning results ", seed, " : ", result_t)
    file.write(str("Active learning results " + str(seed) + " : " + str(result_t)))
Example n. 30
0
def run(cfg):
    num_epochs = cfg['num_epochs']
    batch_size = cfg['batch_size']
    epoch = 0

    # Dataset and loader
    mnist_train, mnist_test = datasets.get_shift_MNIST()
    train_loader = data.DataLoader(mnist_train, batch_size=batch_size,
                                   **cfg['dataloader'])
    test_loader = data.DataLoader(mnist_test, batch_size=16,
                                  **cfg['dataloader'])

    # Models
    model = models.CapsNet(route_iters=cfg['route_iters'],
                           with_reconstruction=cfg['loss']['with_reconstruction'])
    best_model = copy.deepcopy(model)

    # CUDA
    if cfg['use_cuda'] and torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            # DataParallel still needs the parameters on the GPU
            model = torch.nn.DataParallel(model).cuda()
        else:
            model = model.cuda()
    else:
        cfg['use_cuda'] = False

    # Optimizer and scheduler
    optim = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, cfg['optim']['exp_decay'], last_epoch=epoch)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, cfg['optim']['milestone'], cfg['optim']['step_decay'], last_epoch=-1)

    # Best metrics
    best_metrics = {'acc': 0.0, 'loss': float('inf')}

    # Load State Dict
    if cfg['checkpoint']['use_checkpoint'] and os.path.exists(os.path.join(cfg['checkpoint']['log_dir'], cfg['checkpoint']['model_path'])):
        print('Load Checkpoint')
        epoch, model, optim, best_model, best_metrics = utils.load_state(os.path.join(cfg['checkpoint']['log_dir'], cfg['checkpoint']['model_path']), model, optim, best_model, use_best=cfg['checkpoint']['use_best'])
        epoch += 1  # resume from the epoch after the one stored in the checkpoint

    # init logger
    logger = SummaryWriter(cfg['checkpoint']['log_dir'], purge_step=epoch)

    # Loss
    criterion = capsule_loss(cfg['loss']['with_reconstruction'],
                             cfg['loss']['rc_loss_weight'])

    model = copy.deepcopy(best_model)

    for epoch in range(epoch, num_epochs):
        train_acc, train_loss = train(model, train_loader, optim, criterion, logger, epoch, cfg['use_cuda'])

        logger.add_scalar('lr', optim.param_groups[0]['lr'], epoch)
        logger.add_scalar('epoch_acc/train', train_acc, epoch)
        logger.add_scalar('epoch_loss/train', train_loss, epoch)
        print(f"Epoch {epoch} train accuracy {train_acc:.4f}")
        print(f"Epoch {epoch} train loss {train_loss:.4f}")
        test_acc, test_loss = test(model, test_loader, criterion, logger, epoch, cfg['use_cuda'])
        print(f"Epoch {epoch} test accuracy {test_acc:.4f}")
        print(f"Epoch {epoch} test loss {test_loss:.4f}")
        logger.add_scalar('epoch_acc/test', test_acc, epoch)
        logger.add_scalar('epoch_loss/test', test_loss, epoch)
        scheduler.step()
        # Save best metrics
        if best_metrics['acc'] < test_acc:
            best_model = copy.deepcopy(model)
            best_metrics['acc'] = test_acc
            best_metrics['loss'] = test_loss
        utils.save_state(os.path.join(cfg['checkpoint']['log_dir'],
                                      cfg['checkpoint']['model_path']),
                         model, optim, epoch, best_model, best_metrics)
    logger.flush()
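A plausible sketch of the `utils.save_state` / `utils.load_state` pair matching the call sites above, assuming a single torch.save dictionary per checkpoint; the exact key names are assumptions.

import torch

def save_state(path, model, optim, epoch, best_model, best_metrics):
    # Sketch: bundle everything needed to resume training into one file.
    torch.save({
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optim_state': optim.state_dict(),
        'best_model_state': best_model.state_dict(),
        'best_metrics': best_metrics,
    }, path)

def load_state(path, model, optim, best_model, use_best=False):
    # Sketch: restore the checkpoint; optionally continue from the best model instead of the latest.
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['best_model_state' if use_best else 'model_state'])
    optim.load_state_dict(ckpt['optim_state'])
    best_model.load_state_dict(ckpt['best_model_state'])
    return ckpt['epoch'], model, optim, best_model, ckpt['best_metrics']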
Example n. 31
0
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = F.softmax(model(data), dim=1)
            res = output if i == 0 else torch.max(res, output)
        pred = res.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    validation_loss /= len(val_loader.dataset)
    accuracy = 100. * correct / len(val_loader.dataset)
    logger.info('\nValidation set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(val_loader.dataset), accuracy))

    return accuracy


start_epoch = state['start_epoch']
for epoch in range(start_epoch, start_epoch + args.epochs):
    train(epoch)
    accuracy = validation()

    is_best = accuracy > state['best_acc']

    state['start_epoch'] = epoch + 1
    state['best_acc'] = accuracy if is_best else state['best_acc']
    state['model_state'] = model.state_dict()
    state['optim_state'] = optimizer.state_dict()

    utils.save_state(state, is_best, exp_dir)
    logger.info('\nModel saved\n')
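Given the `state` dictionary and `is_best` flag used above, a minimal sketch of this `utils.save_state`; the checkpoint filenames are assumptions.

import os
import shutil
import torch

def save_state(state, is_best, exp_dir):
    """Hypothetical sketch: always keep the latest checkpoint and copy it aside when it is the best so far."""
    ckpt_path = os.path.join(exp_dir, 'checkpoint.pth')
    torch.save(state, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(exp_dir, 'model_best.pth'))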
Example n. 32
0
def learn(
        *,
        policy_network,
        classifier_network,
        env,
        max_iters,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        expert_trajs_path='./expert_trajs',
        num_expert_trajs=500,
        data_subsample_freq=20,
        g_step=1,
        d_step=1,
        classifier_entcoeff=1e-3,
        num_particles=5,
        d_stepsize=3e-4,
        max_episodes=0,
        total_timesteps=0,  # time constraint
        callback=None,
        load_path=None,
        save_path=None,
        render=False,
        use_classifier_logsumexp=True,
        use_reward_logsumexp=False,
        use_svgd=True,
        **policy_network_kwargs):
    '''
    Learn a policy function with the TRPO algorithm.
    
    Parameters:
    ----------

    policy_network          neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via a baselines.common.vec_env.VecEnv-type class)

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    entcoeff                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping 

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    total_timesteps         max number of timesteps

    max_episodes            max number of episodes
    
    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **policy_network_kwargs keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    '''

    nworkers = MPI.COMM_WORLD.Get_size()
    if nworkers > 1:
        raise NotImplementedError
    rank = MPI.COMM_WORLD.Get_rank()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.49)
    U.get_session(config=tf.ConfigProto(allow_soft_placement=True,
                                        gpu_options=gpu_options))

    policy = build_policy(env,
                          policy_network,
                          value_network='copy',
                          **policy_network_kwargs)
    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    # var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    # vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    D = build_classifier(env, classifier_network, num_particles,
                         classifier_entcoeff, use_classifier_logsumexp,
                         use_reward_logsumexp)
    grads_list, vars_list = D.get_grads_and_vars()

    if use_svgd:
        optimizer = SVGD(
            grads_list, vars_list,
            lambda: tf.train.AdamOptimizer(learning_rate=d_stepsize))
    else:
        optimizer = Ensemble(
            grads_list, vars_list,
            lambda: tf.train.AdamOptimizer(learning_rate=d_stepsize))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='yellow'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='blue'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()

    if rank == 0:
        saver = tf.train.Saver(var_list=get_variables("pi"), max_to_keep=10000)
        writer = FileWriter(os.path.join(save_path, 'logs'))
        stats = Statistics(
            scalar_keys=["average_return", "average_episode_length"])

    if load_path is not None:
        # pi.load(load_path)
        saver.restore(U.get_session(), load_path)

    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    if load_path is not None:
        seg_gen = traj_segment_generator(pi,
                                         env,
                                         1,
                                         stochastic=False,
                                         render=render)
    else:
        seg_gen = traj_segment_generator(pi,
                                         env,
                                         timesteps_per_batch,
                                         stochastic=True,
                                         render=render)
    seg_gen_e = expert_traj_segment_generator(env, expert_trajs_path,
                                              data_subsample_freq,
                                              timesteps_per_batch,
                                              num_expert_trajs)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        if iters_so_far % 500 == 0 and save_path is not None and load_path is None:
            fname = os.path.join(save_path, 'checkpoints', 'checkpoint')
            save_state(fname, saver, iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()

        if load_path is not None:
            iters_so_far += 1
            logger.record_tabular("EpRew", int(np.mean(seg["ep_true_rets"])))
            logger.record_tabular("EpLen", int(np.mean(seg["ep_lens"])))
            logger.dump_tabular()
            continue

        seg["rew"] = D.get_reward(seg["ob"], seg["ac"])

        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, ep_lens, atarg, tdlamret = seg["ob"], seg["ac"], seg[
            "ep_lens"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "rms"):
            pi.rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=1000):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        with timed("sample expert trajectories"):
            ob_a, ac_a, ep_lens_a = ob, ac, ep_lens
            seg_e = seg_gen_e.__next__()
            ob_e, ac_e, ep_lens_e = seg_e["ob"], seg_e["ac"], seg_e["ep_lens"]

        if hasattr(D, "rms"):
            obs = np.concatenate([ob_a, ob_e], axis=0)
            if isinstance(ac_space, spaces.Box):
                acs = np.concatenate([ac_a, ac_e], axis=0)
                D.rms.update(np.concatenate([obs, acs], axis=1))
            elif isinstance(ac_space, spaces.Discrete):
                D.rms.update(obs)
            else:
                raise NotImplementedError

        with timed("SVGD"):
            sess = tf.get_default_session()
            feed_dict = {
                D.Xs['a']: ob_a,
                D.As['a']: ac_a,
                D.Ls['a']: ep_lens_a,
                D.Xs['e']: ob_e,
                D.As['e']: ac_e,
                D.Ls['e']: ep_lens_e
            }
            for _ in range(d_step):
                sess.run(optimizer.update_op, feed_dict=feed_dict)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
            stats.add_all_summary(
                writer,
                [np.mean(rewbuffer), np.mean(lenbuffer)], iters_so_far)
            rewbuffer.clear()
            lenbuffer.clear()

    return pi
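For orientation, a sketch of how `learn` might be invoked. The environment id, network names, and path values below are placeholders for illustration, not values from the original project, and the surrounding MPI/TensorFlow setup is assumed to be handled by the calling script.

import gym

# Hypothetical invocation; argument values are illustrative only.
env = gym.make("Hopper-v2")
pi = learn(
    policy_network="mlp",
    classifier_network="mlp",
    env=env,
    max_iters=1000,
    timesteps_per_batch=1024,
    expert_trajs_path="./expert_trajs",
    num_expert_trajs=500,
    save_path="./output",
)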