Example #1
def main():
    args = parse_arguments()

    if args.getmousepos:
        utils.getjoinposn()
    elif args.updatepass is not None:
        utils.update_pass(args.updatepass)
    elif args.append is not None:
        utils.append(args.append)
    elif args.changepass is not None:
        utils.change_pass(args.changepass)
    else:
        config = configparser.ConfigParser()
        config.read("data.ini")
        joinposn = config["VALUES"]["join"].split(", ")
        try:
            joinposn = [int(x) for x in joinposn]
        except ValueError:
            # fmt: off
            print("Oops, it looks like you haven't"
                  "set the positon of join button.")
            # fmt: on

        subject = utils.get_subject() if args.subject is None else args.subject
        zoom_id, zoom_pass = utils.get_credentials(subject)
        utils.auto_type(zoom_id, zoom_pass, joinposn)
Example #2
def vote():
    if request.method == "POST":
        
        vote = request.form["vote"]
        
        url = "http://gdata.youtube.com/feeds/api/videos?q=%s&max-results=1&v=2&alt=jsonc" % urllib.quote_plus(vote)
        result = simplejson.load(urllib.urlopen(url))
        video_id = result['data']['items'][0]['id']
        
        utils.append(utils.get_path(RADIO_ROOT, 'to_process_votes'), video_id)
        
    return render_template("vote.html")
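Here `utils.append` queues the video id by appending it as a new line in a data file. The helper itself is not shown; a minimal sketch under that assumption (behavior inferred from the call site) is:

def append(filename, line):
    # Hypothetical helper: append one newline-terminated record to a text file.
    with open(filename, "a") as f:
        f.write(str(line) + "\n")

Example #11 below calls `append(file_increases, ...)` in the same file-append sense.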
Example #3
def update_response(
    input_features: Optional[np.ndarray],
    normals: Optional[np.ndarray],
    preferences: Optional[np.ndarray],
    phi_A: np.ndarray,
    phi_B: np.ndarray,
    preference: int,
    outdir: Path,
):
    input_features = append(input_features, np.stack([phi_A, phi_B]))
    normals = append(normals, phi_A - phi_B)
    preferences = append(preferences, preference)
    np.save(outdir / "input_features.npy", input_features)
    np.save(outdir / "normals.npy", normals)
    np.save(outdir / "preferences.npy", preferences)
    return input_features, normals, preferences
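Examples #3, #13, #14, and #16 rely on an `append` helper that tolerates a None accumulator, unlike `np.append`. Its source is not included here; a minimal sketch consistent with the call sites (including the `flat=True` uses in Example #16) might be:

import numpy as np
from typing import Optional

def append(arr: Optional[np.ndarray], item, flat: bool = False) -> np.ndarray:
    # Sketch only: treat None as an empty accumulator.
    item = np.asarray(item)
    if not flat:
        item = item[None]  # add `item` as a single new leading entry
    if arr is None:
        return item
    return np.concatenate([arr, item], axis=0)  # flat=True concatenates whole batches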
Example #4
    def create_posting_files(self, posting_dict, letter_word_mapping):
        """
        The function will get the mapping of every word to its first char (a: atom, assertive..)
        and will create the posting file for all the words of a certain character.
        param posting_dict: Dictionary of all the terms and their info to save in the posting files.
        param letter_word_mapping: Dictionary that map each word to its first char. For example:
        {'a':['atom','arg'], 'b':['bar']}
        """
        for char in letter_word_mapping:
            word_data_dict = {}
            if self.postings_data.get(char) is None:
                name = "SPECIALS"
                char_path = "SPECIALS"
            else:
                name = self.postings_data[char]['name']
                char_path = char
            for word in letter_word_mapping[char]:
                word_data_dict[word] = posting_dict[word]

            if word_data_dict:
                utils.append(word_data_dict, f"{self.postings_data[char_path]['path']}\\{name}")
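The `utils.append` above merges a dictionary of terms into a posting file on disk. The serialization format is not shown; one hedged reading, assuming JSON posting files, is:

import json
import os

def append(data, path):
    # Assumption: merge `data` into whatever the posting file already holds.
    existing = {}
    if os.path.exists(path):
        with open(path) as f:
            existing = json.load(f)
    existing.update(data)
    with open(path, "w") as f:
        json.dump(existing, f)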
Example #5
    def forward(self, x, backward_on_y=False):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = F.relu(self.conv4(out))
        out = out.reshape(-1, 1024)
        out = self.fc(out)

        ys_logits, ys_idx, ys, zs_mean, zs_logvar, zs = [], [], [], [], [], []
        ys_logits2 = []
        out2 = out
        for switch in self.fc_switched:
            out, y_logits, y_idx, y, z_mean, z_logvar, z = switch(out, backward_on_y=backward_on_y)
            U.append((ys_logits, ys_idx, ys, zs_mean, zs_logvar, zs),
                     (y_logits, y_idx, y, z_mean, z_logvar, z))

            if self.training:
                out2, y_logits2, _, _, _, _, _ = switch(out2, backward_on_y=backward_on_y)
                ys_logits2.append(y_logits2)
        out = F.relu(out)

        z2_mean = self.fc_mean(out)
        z2_logvar = self.fc_logvar(out)
        return z2_mean, z2_logvar, ys_logits, ys_logits2, ys_idx, ys, zs_mean, zs_logvar, zs
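`U.append` above takes a tuple of accumulator lists and a tuple of values and pushes each value onto the matching list. A sketch matching that call site:

def append(lists, values):
    # Sketch of U.append as used above: one value per accumulator list.
    for lst, value in zip(lists, values):
        lst.append(value)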
Example #6
    def forward(self, y, large_z, context):  # train time
        y, lengths = append(truncate(y, 'eos'), 'sos')
        if self.word_drop > 0.:
            y = word_drop(y, self.word_drop)
        embedded = self.embedding(y)  # (B, l, 300)
        embedded = torch.cat(
            [embedded, context.repeat(1, embedded.size(1), 1)], dim=-1)
        packed = pack_padded_sequence(embedded, lengths, batch_first=True)
        init_hidden = self._transform_hidden(large_z)
        packed_output, _ = self.lstm(packed, init_hidden)
        total_length = embedded.size(1)
        output, _ = pad_packed_sequence(packed_output,
                                        batch_first=True,
                                        total_length=total_length)
        recon_logits = self.out(output)
        return recon_logits  # (B, L, vocab_size)
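Examples #6-#8 share a `truncate(..., 'eos')` / `append(..., 'sos')` pair that strips the trailing <eos> position from a padded batch and prepends <sos> for teacher forcing. The helpers are not shown; a sketch assuming `(tokens, lengths)` batches and a hypothetical SOS_IDX constant:

import torch

SOS_IDX = 1  # assumed vocabulary index

def truncate(batch, token):
    # Drop the final position of each padded sequence (its <eos>).
    tokens, lengths = batch
    return tokens[:, :-1], lengths - 1

def append(batch, token):
    # Prepend <sos> to every sequence; after truncate, net length is unchanged.
    tokens, lengths = batch
    sos = tokens.new_full((tokens.size(0), 1), SOS_IDX)
    return torch.cat([sos, tokens], dim=1), lengths + 1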
Example #7
    def forward(self, orig, para, z):  # train time
        orig, orig_lengths = orig  # (B, l), (B,)
        orig = self.embedding(orig)  # (B, l, 300)
        orig_packed = pack_padded_sequence(orig,
                                           orig_lengths,
                                           batch_first=True)
        _, orig_hidden = self.lstm_orig(orig_packed)
        para, _ = append(truncate(para, 'eos'), 'sos')
        if self.word_drop > 0.:
            para = word_drop(para, self.word_drop)  # from Bowman's paper
        para = self.embedding(para)
        L = para.size(1)
        para_z = torch.cat([para, z.repeat(1, L, 1)],
                           dim=-1)  # (B, L, 1100+300)
        para_output, _ = self.lstm_para(para_z, orig_hidden)  # no packing
        logits = self.linear(para_output)
        return logits  # (B, L, vocab_size)
Example #8
    def forward(self, z, l, x=None):
        """
        z: (B, 500)
        l: (B,)
        x: tuple of (B, L+1), (B,)
        """
        B = l.size(0)
        l_embed = self.attr_emb(l) # (B, 200)
        hidden = torch.cat([z, l_embed], dim=-1).unsqueeze(0) # (1, B, 700)

        if x is not None: # loss computation with teacher forcing
            x, lengths = append(truncate(x, 'eos'), 'sos')
            x_embed = self.emb(x) # (B, L+1, 300)
            packed_in = pack_padded_sequence(x_embed, lengths, batch_first=True)
            packed_out, _ = self.gru(packed_in, hidden)
            total_length = x.size(1)
            # (B, L, 700)
            hx, lengths = pad_packed_sequence(packed_out, batch_first=True,
                                              total_length=total_length)
            output = self.out(hx)
            return (hx, lengths), (output, lengths) # (B, L+1, 700), (B,)
                                                    # (B, L+1, vocab), (B,)
        else: # sample y
            y = []
            hy = []
            input_ = l.new_full((B, 1), SOS_IDX)
            for t in range(MAXLEN):
                input_ = self.emb(input_) # (B, 1, 300)
                # output (B, 1, 700), hidden (1, B, 700)
                output, hidden = self.gru(input_, hidden)
                input_ = self._hard_sampling(self.out(output))
                hy.append(output)
                y.append(input_)
            input_ = l.new_full((B,1), EOS_IDX) # feed <eos> as last input,
            output, _ = self.gru(self.emb(input_), hidden)
            hy.append(output)
            y.append(input_) # append <eos> as last token

            hy = torch.cat(hy, dim=1)
            y = torch.cat(y, dim=1)
            hy, y, lengths = self._tighten(hy, y)
            #lengths = y.new_full((B,), MAXLEN+1)
            return (hy, lengths), (y, lengths) # (B, MAXLEN+1, 700), (B, MAXLEN+1), (B, )
Example #9
    def forward_all_cells(self):
        """ Move all agents in the map one time step forward """
        agents_durations = self.durations[cp.arange(0, self.durations.shape[0]), self.current_state_ids].flatten()
        print(f'DEBUG: agents_durations.shape: {agents_durations.shape}, self.durations.shape: {self.durations.shape}, self.current_state_ids.shape: {self.current_state_ids.shape}')
        to_transit = (self.current_state_durations == agents_durations)
        self.current_state_durations += 1
        to_transit = self.agent_ids[to_transit]
        self.transit_states(to_transit)
        # Contamination at home by end of the period
        self.contaminate(self.agent_ids, self.home_cell_ids)
        # Update r and associated variables
        r = self.n_infected_period / self.n_diseased_period if self.n_diseased_period > 0 else 0
        r = cp.array([r])
        if self.verbose > 1:
            print(f'period {self.current_period}: r={r}')
        self.r_factors = append(self.r_factors, r)
        self.n_diseased_period = self.get_n_diseased()
        self.n_infected_period = 0
        # Move one period forward
        self.current_period += 1
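In Examples #9 and #10, `append` grows CuPy history arrays, again treating None as empty. A sketch under that assumption:

import cupy as cp

def append(arr, values):
    # Sketch: concatenate onto a possibly empty (None) accumulator.
    if arr is None:
        return values
    return cp.concatenate([arr, values])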
Example #10
    def contaminate(self, selected_agents, selected_cells):
        """ both arguments have same length. If an agent with sensitivity > 0 is in the same cell 
        than an agent with contagiousity > 0: possibility of contagion """
        t_start = time()
        i = 0
        t0 = time()
        selected_unsafeties = self.unsafeties[selected_cells]
        selected_agents = selected_agents.astype(cp.uint32)
        selected_states = self.current_state_ids[selected_agents]
        selected_contagiousities = self.unique_contagiousities[selected_states]
        selected_sensitivities = self.unique_sensitivities[selected_states]
        print(f'ttt first part contaminate: {time() - t0}')
        # Find cells where max contagiousity == 0 (no contagion can happen there)
        t0 = time()
        cont_sens = cp.multiply(selected_contagiousities, selected_sensitivities)
        print(f'ttt group max sensitivities: {time() - t0}')
        # Combine them
        if cp.max(cont_sens) == 0:
            return
        t0 = time()
        mask_zero = (cont_sens > 0)
        selected_agents = selected_agents[mask_zero]
        selected_contagiousities = selected_contagiousities[mask_zero]
        selected_sensitivities = selected_sensitivities[mask_zero]
        selected_cells = selected_cells[mask_zero]
        selected_unsafeties = selected_unsafeties[mask_zero]
        print(f'ttt mask zero all: {time() - t0}')
        
        # Compute proportion (contagious agent) / (non contagious agent) by cell
        t0 = time()
        _, n_contagious_by_cell = cp.unique(selected_cells[selected_contagiousities > 0], return_counts=True)
        _, n_non_contagious_by_cell = cp.unique(selected_cells[selected_contagiousities == 0], return_counts=True)
        print(f'ttt non contagious: {time() - t0}')
        i += 1
        t0 = time()
        p_contagious = cp.divide(n_contagious_by_cell, n_non_contagious_by_cell)

        n_selected_agents = selected_agents.shape[0]
        print(f'ttt p_contagious: {time() - t0}')
  
        if self.verbose > 1:
            print(f'{n_selected_agents} selected agents after removing cells with max sensitivity or max contagiousity==0')
        if n_selected_agents == 0:
            return
        # Find for each cell which agent has the max contagiousity inside (it will be the contaminating agent)
        t0 = time()
        max_contagiousities, mask_max_contagiousities = group_max(data=selected_contagiousities, groups=selected_cells) 
        print(f'ttt max contagious: {time() - t0}')
        t0 = time()
        infecting_agents = selected_agents[mask_max_contagiousities]
        selected_contagiousities = selected_contagiousities[mask_max_contagiousities]
        print(f'ttt mask max contagious: {time() - t0}')
        # Select agents that can be potentially infected ("pinfected") and corresponding variables
        t0 = time()
        pinfected_mask = (selected_sensitivities > 0)
        pinfected_agents = selected_agents[pinfected_mask]
        selected_sensitivities = selected_sensitivities[pinfected_mask]
        selected_unsafeties = selected_unsafeties[pinfected_mask]
        selected_cells = selected_cells[pinfected_mask]
        print(f'ttt p_infected_mask: {time() - t0}')

        # Group `selected_cells` and expand `infecting_agents` and `selected_contagiousities` accordingly.
        # There is exactly one infecting agent per pinfected cell, so #`counts` == #`infecting_agents`
        t0 = time()
        _, inverse = cp.unique(selected_cells, return_inverse=True)
        print(f'ttt inverse select cell: {time() - t0}')
        # TODO: ACHTUNG: count repeat replace by inverse here
        t0 = time()
        infecting_agents = infecting_agents[inverse]
        selected_contagiousities = selected_contagiousities[inverse]
        p_contagious = p_contagious[inverse]
        print(f'ttt p_contagious inverse: {time() - t0}')
        # Compute contagions
        t0 = time()
        res = cp.multiply(selected_contagiousities, selected_sensitivities)
        res = cp.multiply(res, selected_unsafeties)
        print(f'ttt cp.multiply: {time() - t0}')
        # Modify contamination probabilities according to `p_contagious`
        t0 = time()
        mask_p = (p_contagious < 1)
        res[mask_p] = cp.multiply(res[mask_p], p_contagious[mask_p])
        res[~mask_p] = 1 - cp.divide(1 - res[~mask_p], p_contagious[~mask_p])
        print(f'ttt res mask p: {time() - t0}')

        t0 = time()
        draw = cp.random.uniform(size=infecting_agents.shape[0])
        draw = (draw < res)
        infecting_agents = infecting_agents[draw]
        infected_agents = pinfected_agents[draw]
        n_infected_agents = infected_agents.shape[0]
        print(f'ttt n_infected draw: {time() - t0}')
        if self.verbose > 1:
            print(f'Infecting and infected agents should be all different, are they? {((infecting_agents == infected_agents).sum() == 0)}')
            print(f'Number of infected agents: {n_infected_agents}')
        t0 = time()
        self.current_state_ids[infected_agents] = self.least_state_ids[infected_agents]
        self.current_state_durations[infected_agents] = 0
        self.n_infected_period += n_infected_agents
        self.infecting_agents = append(self.infecting_agents, infecting_agents)
        self.infected_agents = append(self.infected_agents, infected_agents)
        self.infected_periods = append(self.infected_periods, cp.multiply(cp.ones(n_infected_agents), self.current_period))
        print(f'ttt final: {time() - t0}')
        print(f'contaminate computed in {time() - t_start}')
Example #11
        continue
    current.val_loss_delta = current.val_loss - parent.val_loss
    if current.val_loss_delta > 0:
        increases.append(current)
    if current.val_loss > node_worst.val_loss: node_worst = current
    if current.val_loss < node_best.val_loss: node_best = current
    total += 1

if total == 0: fail('No matching Nodes found!')

fraction = 100.0 * len(increases) / total
print('increases/total = %i / %i (%.2f%%)' %
      (len(increases), total, fraction))

file_increases = "increases-%s.data" % args.token
append(file_increases, "%i %5.1f" % (args.stage, fraction))

print('worst val_loss: ' + str(node_worst))
print('best  val_loss: ' + str(node_best))

print('DELTAS:')

increases.sort(key=Node.get_val_loss_delta)
stopped_early = 0
for i in increases:
    # print('%f %-14s %r' % (i.val_loss_delta, i.id, i.stopped_early))
    if i.stopped_early: stopped_early += 1


def print_delta(prefix, node):
    print(prefix, str(node), 'delta: %f' % node.val_loss_delta)
Example #12
def append_validator_if_set(validators: List[callable], is_set: bool,
                            create_val: callable, compare: callable, limit,
                            template: str) -> List[callable]:
    if not is_set:
        return validators
    return append(validators, create_val(compare, limit, template % str(limit)))
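For illustration, a hypothetical call that registers a "value must be <= limit" validator only when the limit is configured (make_validator and all arguments below are invented). Note the idiom relies on an `append` that returns the list, unlike `list.append`:

import operator

def make_validator(compare, limit, message):
    # Hypothetical factory with the create_val signature assumed above.
    def validate(value):
        if not compare(value, limit):
            raise ValueError(message)
    return validate

validators = append_validator_if_set([], is_set=True,
                                     create_val=make_validator,
                                     compare=operator.le, limit=100,
                                     template="value must be <= %s")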
Example #13
def simulated(
    outdir: Path,
    criterion: Literal["information", "volume", "random"],
    termination_threshold: float,
    n_reward_samples: int,
    query_type: Literal["strict", "weak"] = "strict",
    equiv_size: Optional[float] = None,
    true_reward_path: Optional[Path] = None,
    continuous: bool = False,
    overwrite: bool = False,
    replications: Optional[str] = None,
):
    """ Generates a test by eliciting from a human simulated by a ground truth reward. """
    if replications is not None:
        replication_indices = parse_replications(replications)
        if true_reward_path is not None:
            reward_dir, reward_name = make_reward_path(true_reward_path)
            Parallel(n_jobs=-2)(
                delayed(simulated)(
                    outdir=Path(outdir) / str(i),
                    criterion=criterion,
                    termination_threshold=termination_threshold,
                    n_reward_samples=n_reward_samples,
                    query_type=query_type,
                    equiv_size=equiv_size,
                    true_reward_path=reward_dir / str(i) / reward_name,
                    continuous=continuous,
                    overwrite=overwrite,
                )
                for i in replication_indices
            )
        else:
            Parallel(n_jobs=-2)(
                delayed(simulated)(
                    outdir=Path(outdir) / str(i),
                    criterion=criterion,
                    termination_threshold=termination_threshold,
                    n_reward_samples=n_reward_samples,
                    query_type=query_type,
                    equiv_size=equiv_size,
                    continuous=continuous,
                    overwrite=overwrite,
                )
                for i in replication_indices
            )
        exit()

    criterion, query_type, outdir = setup(criterion, query_type, outdir, delta=equiv_size)

    env = Driver()
    d = env.num_of_features

    if true_reward_path is not None:
        logging.info(f"Loading true reward from {true_reward_path}")
        true_reward = np.load(true_reward_path)
    else:
        logging.info("Randomly generating true reward")
        true_reward = np.random.normal(size=(4,))
        true_reward = true_reward / np.linalg.norm(true_reward)
        np.save(outdir / "true_reward.npy", true_reward)

    pickle.dump(
        {
            "criterion": criterion,
            "reward_iterations": n_reward_samples,
            "stop_thresh": termination_threshold,
            "query_type": query_type,
            "equiv_size": equiv_size,
            "continuous": continuous,
        },
        open(outdir / "flags.pkl", "wb"),
    )

    normals = load(outdir / "normals.npy", overwrite=overwrite)
    preferences = load(outdir / "preferences.npy", overwrite=overwrite)
    inputs = load(outdir / "inputs.npy", overwrite=overwrite)
    input_features = load(outdir / "input_features.npy", overwrite=overwrite)

    # If there is already data, feed it to the w_sampler to get the right posterior.
    w_sampler = Sampler(d)
    if inputs is not None and input_features is not None and preferences is not None:
        for (a_phi, b_phi), preference in zip(input_features, preferences):
            w_sampler.feed(a_phi, b_phi, [preference])

    score = np.inf
    try:
        while score >= termination_threshold:
            w_samples, delta_samples = w_sampler.sample_given_delta(
                sample_count=n_reward_samples, query_type=query_type, delta=equiv_size
            )

            input_A, input_B, score = run_algo(criterion, env, w_samples, delta_samples, continuous)
            logging.info(f"Score={score}")

            if score > termination_threshold:
                inputs = update_inputs(
                    a_inputs=input_A, b_inputs=input_B, inputs=inputs, outdir=outdir
                )
                phi_A, phi_B, preference = get_simulated_feedback(
                    simulation=env,
                    input_A=input_A,
                    input_B=input_B,
                    query_type=query_type,
                    true_reward=true_reward,
                    delta=equiv_size,
                )
                input_features = append(input_features, np.stack([phi_A, phi_B]))
                normals = append(normals, phi_A - phi_B)
                preferences = append(preferences, preference)
                np.save(outdir / "input_features.npy", input_features)
                np.save(outdir / "normals.npy", normals)
                np.save(outdir / "preferences.npy", preferences)

                w_sampler.feed(phi_A, phi_B, [preference])
    except KeyboardInterrupt:
        # Pass through to finally
        logging.warning("\nSaving results, please do not exit again.")
    finally:
        save_reward(query_type, w_sampler, n_reward_samples, outdir, true_delta=equiv_size)
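Examples #13, #14, and #16 resume from `load`, which evidently returns None when the file is missing or when overwrite is requested, so elicitation can start fresh. A sketch under that assumption (Example #16 passes the directory and file name separately; the idea is the same):

import numpy as np
from pathlib import Path
from typing import Optional

def load(path: Path, overwrite: bool = False) -> Optional[np.ndarray]:
    # Sketch only: None signals "no prior data".
    if overwrite or not path.exists():
        return None
    return np.load(path)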
Example #14
def human(
    criterion: str,
    query_type: str,
    epsilon: float,
    n_reward_samples: int,
    equiv_size: float,
    outdir: Path = Path("questions"),
    continuous: bool = False,
    overwrite: bool = False,
):
    """ Generates a test by eliciting preferences from a human. """
    criterion, query_type, outdir = setup(criterion, query_type, outdir, delta=equiv_size)

    simulation_object = Driver()
    d = simulation_object.num_of_features

    pickle.dump(
        {
            "criterion": criterion,
            "query_type": query_type,
            "epsilon": epsilon,
            "reward_iterations": n_reward_samples,
            "delta": equiv_size,
            "continuous": continuous,
        },
        open(outdir / "flags.pkl", "wb"),
    )

    normals = load(outdir / "normals.npy", overwrite=overwrite)
    preferences = load(outdir / "preferences.npy", overwrite=overwrite)
    inputs = load(outdir / "inputs.npy", overwrite=overwrite)
    input_features = load(outdir / "input_features.npy", overwrite=overwrite)

    w_sampler = Sampler(d)
    if inputs is not None and input_features is not None and preferences is not None:
        for (a_phi, b_phi), preference in zip(input_features, preferences):
            w_sampler.feed(a_phi, b_phi, [preference])

    score = np.inf
    try:
        while score >= epsilon:
            w_samples, delta_samples = w_sampler.sample_given_delta(
                n_reward_samples, query_type, equiv_size
            )

            input_A, input_B, score = run_algo(
                criterion, simulation_object, w_samples, delta_samples, continuous
            )

            if score > epsilon:
                inputs = update_inputs(
                    a_inputs=input_A, b_inputs=input_B, inputs=inputs, outdir=outdir
                )
                phi_A, phi_B, preference = get_feedback(
                    simulation_object, input_A, input_B, query_type
                )
                input_features = append(input_features, np.stack([phi_A, phi_B]))
                normals = append(normals, phi_A - phi_B)
                preferences = append(preferences, preference)
                np.save(outdir / "input_features.npy", input_features)
                np.save(outdir / "normals.npy", normals)
                np.save(outdir / "preferences.npy", preferences)

                w_sampler.feed(phi_A, phi_B, [preference])
    except KeyboardInterrupt:
        # Pass through to finally
        logging.warning("\nSaving results, please do not exit again.")
    finally:
        save_reward(query_type, w_sampler, n_reward_samples, outdir, true_delta=equiv_size)
Example #15
def main(cfg):
    # setting up output directories, and writing to stdout
    make_dirs(cfg.stdout_dir, replace=False)
    if cfg.train:
        run_type = 'train'
    else:
        if 'weight' in cfg.prune_type.lower():
            run_type = 'weight-prune'
        else:
            run_type = 'unit-prune'
    sys.stdout = open(
        '{}/stdout_{}_{}.txt'.format(cfg.stdout_dir, cfg.model_name, run_type),
        'w')
    print(cfg)
    print('\n')
    sys.stdout.flush()

    # if train mode, replace the previous plot and ckpt directories; if in prune mode, use existing directories
    if cfg.plot:
        make_dirs(os.path.join(cfg.plot_dir, cfg.model_name),
                  replace=cfg.train)
    if cfg.save_model:
        make_dirs(os.path.join(cfg.model_dir, cfg.model_name),
                  replace=cfg.train)

    # set random seed
    if cfg.random_seed != 0:
        random_seed = cfg.random_seed
    else:
        random_seed = random.randint(1, 100000)
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # set device as cuda or cpu
    if cfg.use_gpu and torch.cuda.is_available():
        # reproducibility using cuda
        torch.cuda.manual_seed(random_seed)
        cudnn.deterministic = True
        cudnn.benchmark = False
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
        if cfg.use_gpu:
            print('gpu option was set to <True>, but no cuda device was found')
            print('\n')

    # datasets and dataloaders
    # normalizing training and validation images to [0, 1] suffices for the purposes of our research objective
    # in training, <drop_last> minibatch in an epoch set to <True> for simplicity in tracking training performance
    dataset_train = MNIST(root='./data/mnist',
                          train=True,
                          download=True,
                          transform=transforms.Compose([transforms.ToTensor()
                                                        ]),
                          target_transform=None)
    dataloader_train = DataLoader(dataset=dataset_train,
                                  batch_size=cfg.batch_size,
                                  shuffle=cfg.shuffle,
                                  num_workers=cfg.num_workers,
                                  pin_memory=True,
                                  drop_last=True)

    dataset_val = MNIST(root='./data/mnist',
                        train=False,
                        download=True,
                        transform=transforms.Compose([transforms.ToTensor()]),
                        target_transform=None)
    dataloader_val = DataLoader(dataset=dataset_val,
                                batch_size=100,
                                shuffle=False,
                                num_workers=cfg.num_workers,
                                pin_memory=True,
                                drop_last=False)

    # automatically compute number of classes
    targets = np.asarray(dataset_train.targets)
    c = np.unique(targets).shape[0]

    # define model
    # weights initialized using Kaiming uniform (He initialization)
    # number of units per hidden layer is passed in as an argument
    net = Net(np.prod(cfg.img_size), c, cfg.units).to(device)

    criterion = nn.CrossEntropyLoss()

    if cfg.train:
        # training mode

        if cfg.use_sgd:
            optimizer = optim.SGD(params=net.parameters(),
                                  lr=cfg.lr,
                                  momentum=cfg.momentum,
                                  nesterov=cfg.use_nesterov)
        else:
            optimizer = optim.Adam(params=net.parameters(),
                                   lr=cfg.lr,
                                   betas=(cfg.beta1, cfg.beta2))

        # tracking training and validation stats over epochs
        epochs = []
        train_loss_epochs, val_loss_epochs = [], []
        train_acc_epochs, val_acc_epochs = [], []

        # best model is defined as model with best performing validation loss
        best_loss = float('inf')
        for epoch in range(cfg.epochs):
            # tracking training and validation stats over a given epoch
            train_loss_epoch, val_loss_epoch = [], []
            train_acc_epoch, val_acc_epoch = [], []

            # training set
            for i, (x, y) in enumerate(dataloader_train):
                x, y = x.to(device), y.to(device)

                optimizer.zero_grad()
                logits = net(x)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()

                acc = calculate_acc(logits, y)

                append((train_loss_epoch, loss.item()),
                       (train_acc_epoch, acc.item()))

            # validation set
            with torch.no_grad():
                for i, (x, y) in enumerate(dataloader_val):
                    x, y = x.to(device), y.to(device)

                    logits = net(x)
                    loss = criterion(logits, y)

                    acc = calculate_acc(logits, y)

                    append((val_loss_epoch, loss.item()),
                           (val_acc_epoch, acc.item()))

            train_loss_epoch, val_loss_epoch = get_average(
                train_loss_epoch), get_average(val_loss_epoch)
            train_acc_epoch, val_acc_epoch = get_average(
                train_acc_epoch), get_average(val_acc_epoch)

            print('train_epoch{:0=3d}_loss{:.4f}_acc{:.4f}'.format(
                epoch + 1, train_loss_epoch, train_acc_epoch))
            print('valid_epoch{:0=3d}_loss{:.4f}_acc{:.4f}'.format(
                epoch + 1, val_loss_epoch, val_acc_epoch))
            print('\n')
            sys.stdout.flush()

            if cfg.plot:
                append((epochs, epoch + 1),
                       (train_loss_epochs, train_loss_epoch),
                       (val_loss_epochs, val_loss_epoch),
                       (train_acc_epochs, train_acc_epoch),
                       (val_acc_epochs, val_acc_epoch))

                plot_line(epochs, train_loss_epochs, val_loss_epochs,
                          'Epoch Number', 'Loss', cfg)
                plot_line(epochs, train_acc_epochs, val_acc_epochs,
                          'Epoch Number', 'Accuracy', cfg)

            if val_loss_epoch < best_loss:
                best_loss = val_loss_epoch
                print('New best model at epoch {:0=3d} with val_loss {:.4f}'.
                      format(epoch + 1, best_loss))
                print('\n')
                if cfg.save_model:
                    # save model when validation loss improves
                    save_name = '{}_net_epoch{:0=3d}_val_loss{:.4f}'.format(
                        cfg.model_name, epoch + 1, best_loss)
                    torch.save(
                        net.state_dict(),
                        os.path.join(cfg.model_dir, cfg.model_name,
                                     '{}.pth'.format(save_name)))
                    with open(
                            os.path.join(cfg.model_dir, cfg.model_name,
                                         '{}.txt'.format(cfg.model_name)),
                            'w') as file:
                        file.write('{}.pth'.format(save_name))

    else:
        # pruning mode

        # checks on arguments passed in
        for k in cfg.sparsity:
            assert 0 <= k <= 1
        if cfg.use_sparse_mul:
            assert cfg.to_sparse

        # load model
        with open(
                os.path.join(cfg.model_dir, cfg.model_name,
                             '{}.txt'.format(cfg.model_name)), 'r') as file:
            load_name = file.readline()
        net.load_state_dict(
            torch.load(
                os.path.join(cfg.model_dir, cfg.model_name,
                             '{}'.format(load_name))))
        net.eval()

        # select pruning approach to use
        if 'weight' in cfg.prune_type.lower():
            prune = weight_prune
        else:
            prune = unit_prune

        sparsities = []
        val_loss_sparse, val_acc_sparse = [], []
        time_sparsities = []
        for k in cfg.sparsity:
            val_loss_k, val_acc_k = [], []
            time_k = []

            # copy network so that the sparsity changes are not additive for each k
            net_sparse = copy.deepcopy(net)

            pruned_weights = []
            # prune model, except for the last layer
            for (i, p) in enumerate(net_sparse.parameters()):
                if i < len(cfg.units):
                    original_weights = copy.deepcopy(p.data)
                    if cfg.plot:
                        # plot magnitude of original weights (for comparison to post-pruned weights)
                        plot_hist([
                            torch.abs(
                                original_weights.flatten()).cpu().numpy()
                        ], ['b'], cfg.prune_type, i + 1, k,
                                  'Non-Pruned Weight Magnitudes', 'Counts',
                                  cfg)
                    prune(p.data, k)
                    if cfg.plot:
                        # plot original magnitudes of pruned weights, and magnitudes of remaining weights, separately
                        pruned_weights_non_zero = torch.abs(
                            original_weights.flatten()[p.data.flatten() != 0])
                        pruned_weights_zeroed = torch.abs(
                            original_weights.flatten()[p.data.flatten() == 0])
                        plot_hist([
                            pruned_weights_non_zero.cpu().numpy(),
                            pruned_weights_zeroed.cpu().numpy()
                        ], ['g', 'r'], cfg.prune_type, i + 1, k,
                                  'Weight Magnitudes', 'Counts', cfg)
                        plot_hist([pruned_weights_non_zero.cpu().numpy()],
                                  ['k'], cfg.prune_type, i + 1, k,
                                  'Surviving Weight Magnitudes', 'Counts', cfg)
                if cfg.to_sparse and i < len(cfg.units):
                    pruned_weights.append(p.data.to_sparse())
                else:
                    pruned_weights.append(p.data)

            with torch.no_grad():
                for i, (x, y) in enumerate(dataloader_val):
                    x, y = x.to(device), y.to(device)

                    start = time.time()
                    logits = forward(x, pruned_weights, cfg.use_sparse_mul)
                    end = time.time()
                    loss = criterion(logits, y)

                    acc = calculate_acc(logits, y)

                    append((val_loss_k, loss.item()), (val_acc_k, acc.item()),
                           (time_k, end - start))

            val_loss_k, val_acc_k, time_k = get_average(
                val_loss_k), get_average(val_acc_k), get_average(time_k)

            print('valid_{}_k{:.2f}_loss{:.4f}_acc{:.4f}'.format(
                run_type, k, val_loss_k, val_acc_k))
            print('valid_{}_k{:.2f}_time/minibatch{:.6f}'.format(
                run_type, k, time_k))
            print('\n')
            sys.stdout.flush()

            if cfg.plot:
                append((sparsities, k), (val_loss_sparse, val_loss_k),
                       (val_acc_sparse, val_acc_k), (time_sparsities, time_k))

                plot_line(sparsities, [], val_loss_sparse,
                          'Sparsity {} Prune'.format(cfg.prune_type), 'Loss',
                          cfg)
                plot_line(sparsities, [], val_acc_sparse,
                          'Sparsity {} Prune'.format(cfg.prune_type),
                          'Accuracy', cfg)
                plot_line(sparsities, [], time_sparsities,
                          'Sparsity {} Prune'.format(cfg.prune_type), 'Time',
                          cfg)

            if cfg.save_model:
                torch.save(
                    net_sparse.state_dict(),
                    os.path.join(
                        cfg.model_dir, cfg.model_name,
                        '{}_sparse_net_{}_val_loss{:.4f}.pth'.format(
                            cfg.model_name, run_type, val_loss_k)))
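Example #15 uses yet another `append`: a variadic helper that pushes several values onto several lists in one call, keeping the per-minibatch bookkeeping to a single statement. A minimal sketch consistent with its call sites:

def append(*pairs):
    # Sketch: each argument is a (list, value) pair.
    for lst, value in pairs:
        lst.append(value)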
Example #16
def collect(
    outdir: Path,
    n_rewards: int,
    test_reward_path: Optional[Path] = None,
    std: Optional[float] = None,
    mean_reward_path: Optional[Path] = None,
    normals_paths: Optional[List[Path]] = None,
    preferences_paths: Optional[List[Path]] = None,
    use_random: bool = False,
    use_plausible: bool = False,
    skip_human: bool = False,
    overwrite: bool = False,
) -> None:
    """Collects ground truth labels for the optimal trajectories of some reward functions.

    Args:
        outdir (Path): Directory to write output to
        n_rewards (int): Number of rewards to generate or process
        test_reward_path (Optional[Path], optional): Path to numpy array of reward weights to test. Defaults to None.
        std (Optional[float], optional): Standard deviation of normal distribution to draw test reward weights from. Defaults to None.
        mean_reward_path (Optional[Path], optional): Path to numpy array specifying mean reward weights to sample around. Defaults to None.
        overwrite (bool, optional): Overwrite output? Defaults to False.

    Raises:
        ValueError: Raised if neither test_reward_path nor both std and mean_reward_path are specified. The test rewards need to come from somewhere.
    """
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    out_rewards = load(outdir, "test_rewards.npy", overwrite=overwrite)
    new_rewards_index = out_rewards.shape[0] if out_rewards is not None else 0
    num_new_rewards = n_rewards - new_rewards_index

    env = Driver()

    if num_new_rewards > 0:
        if test_reward_path is not None:
            rewards = np.load(
                test_reward_path)[new_rewards_index:num_new_rewards]
        elif mean_reward_path is not None and std is not None:
            mean_reward = np.load(mean_reward_path)
            rewards = default_rng().normal(loc=mean_reward,
                                           scale=std,
                                           size=(num_new_rewards,
                                                 *mean_reward.shape))
        elif normals_paths is not None and preferences_paths is not None and std is not None:
            # NOTE(joschnei): This turned out not to work, because the random baseline is poisoning the well
            normals = None
            for normals_path, preferences_path in zip(normals_paths,
                                                      preferences_paths):
                single_normals = np.load(normals_path)
                single_preferences = np.load(preferences_path)
                single_normals = (single_normals.T * single_preferences).T
                normals = append(normals, single_normals, flat=True)
            # TODO(joschnei): These can all be loaded in from flags.pkl, but I'm too lazy for that.
            mean_reward = make_mode_reward(
                query_type="strict",
                true_delta=1.1,
                w_sampler=Sampler(env.num_of_features),
                n_reward_samples=100,
            )
            assert np.all(np.isfinite(mean_reward))
            rewards = default_rng().normal(loc=mean_reward,
                                           scale=std,
                                           size=(num_new_rewards,
                                                 *mean_reward.shape))
            assert np.all(np.isfinite(rewards))
        elif use_random:
            rewards = default_rng().normal(loc=0,
                                           scale=1,
                                           size=(num_new_rewards,
                                                 env.num_of_features))
            rewards = rewards / np.linalg.norm(rewards)
        elif use_plausible:
            # Generate uniform rewards with plausible weights i.e. ones with the right sign
            rewards = default_rng().normal(loc=0,
                                           scale=1,
                                           size=(num_new_rewards,
                                                 env.num_of_features))
            rewards = rewards / np.linalg.norm(rewards)

            # See models.py for reward feature details.
            rewards[:, 0] = np.abs(rewards[:, 0])
            rewards[:, 1] = -np.abs(rewards[:, 1])
            rewards[:, 2] = np.abs(rewards[:, 2])
            rewards[:, 3] = -np.abs(rewards[:, 3])
        else:
            raise ValueError(
                "You must either supply a path to the test rewards, or a mean reward and "
                "std from which to sample the test rewards.")
        out_rewards = append(out_rewards, rewards, flat=True)
    else:
        assert out_rewards is not None

    assert np.all(np.isfinite(out_rewards))
    np.save(open(outdir / "test_rewards.npy", "wb"), out_rewards)

    paths = load(outdir, "optimal_paths.npy", overwrite=overwrite)
    new_paths_index = paths.shape[0] if paths is not None else 0
    num_new_paths = n_rewards - new_paths_index

    if num_new_paths > 0:
        new_paths = np.array(
            Parallel(n_jobs=-2)(delayed(make_opt_traj)(reward)
                                for reward in out_rewards[new_paths_index:]))
        paths = append(paths, new_paths, flat=True)
    else:
        assert paths is not None
    np.save(open(outdir / "optimal_paths.npy", "wb"), np.array(paths))

    gt_alignment = load(outdir, "alignment.npy", overwrite=overwrite)
    new_gt_index = gt_alignment.size if gt_alignment is not None else 0

    if skip_human:
        exit()

    for path in paths[new_gt_index:]:
        env.set_ctrl(path)
        env.watch(1)

        alignment = input("Aligned (y/n):").lower()
        while alignment not in ["y", "n"]:
            alignment = input("Aligned (y/n):").lower()
        gt_alignment = append(gt_alignment, alignment == "y")

    np.save(open(outdir / "alignment.npy", "wb"), gt_alignment)