Пример #1
0
def get_next_gate_tree_by_log_loss(unused_gate_trees, data_input, params, model=None):
    """Select the candidate gate tree giving the lowest training log loss.

    Every candidate in ``unused_gate_trees`` is scored the same way:

    1. Build a fresh ``DepthOneModel`` for the candidate.
       * If ``model`` is given, the candidate model starts from
         ``model.get_gate_tree()`` and a deep copy of ``model``'s weights.
         The candidate is then appended with ``add_node``, so ``model``
         itself is never trained.
       * Otherwise, the candidate model contains only that single gate tree.
    2. Train it with ``run_train_model``.
    3. Read the ``'log_loss'`` entry of its output on
       ``(data_input.x_tr, data_input.y_tr)``.

    Candidates whose loss is NaN are skipped when choosing the minimum.

    Args:
        unused_gate_trees: List of candidate gate trees. The winning entry is
            removed from this list *in place*.
        data_input: Object exposing ``x_tr`` and ``y_tr`` training inputs.
        params: Dict with ``'model_params'`` (passed to ``DepthOneModel``) and
            ``'train_params'`` (passed to ``run_train_model``).
        model: Optional current ``DepthOneModel`` to extend with one more node.

    Returns:
        Tuple ``(best_gate, unused_gate_trees)``. ``best_gate`` is the
        selected gate tree. ``unused_gate_trees`` is the same list object,
        with ``best_gate`` removed.

    Raises:
        ValueError: If there are no candidates or every loss is NaN.
    """
    losses = []
    for gate_tree in unused_gate_trees:
        if model is not None:
            # Copy the weights so training a candidate never mutates ``model``.
            dummy_model_state = deepcopy(model.state_dict())
            candidate = DepthOneModel(model.get_gate_tree(), params['model_params'])
            candidate.load_state_dict(dummy_model_state)
            candidate.add_node(gate_tree)
        else:
            candidate = DepthOneModel([gate_tree], params['model_params'])

        run_train_model(candidate, params['train_params'], data_input)
        output = candidate(data_input.x_tr, data_input.y_tr)
        losses.append(output['log_loss'].cpu().detach().numpy())

    # nanargmin ignores NaN losses but, unlike argmin on a NaN-filtered copy,
    # returns an index into the *full* array, so it lines up with
    # ``unused_gate_trees``.
    best_gate_idx = int(np.nanargmin(np.array(losses)))
    best_gate = unused_gate_trees[best_gate_idx]
    del unused_gate_trees[best_gate_idx]
    return best_gate, unused_gate_trees
Пример #2
0
def get_next_best_gate(remaining_gates, data_input, params, model):
    """Select the remaining gate whose addition to ``model`` gives the lowest loss.

    For each candidate gate:

    1. Build a fresh ``DepthOneModel`` from ``rectangularize_gates(model)``.
    2. Load a deep copy of ``model``'s weights into it and append the
       candidate with ``add_node``, so ``model`` itself is never trained.
    3. Train it with ``run_train_model``.
    4. Read the ``'loss'`` entry of its output on
       ``(data_input.x_tr, data_input.y_tr)``.

    Candidates whose loss is NaN are skipped when choosing the minimum.

    Args:
        remaining_gates: Sequence of candidate gates. It is not modified.
        data_input: Object exposing ``x_tr`` and ``y_tr`` training inputs.
        params: Dict with ``'model_params'`` (passed to ``DepthOneModel``) and
            ``'train_params'`` (passed to ``run_train_model``).
        model: Current model to extend with one more node.

    Returns:
        Tuple ``(best_gate, remaining_gates, best_tracker)``:
        ``best_gate`` is the selected gate; ``remaining_gates`` is a new list
        without it; ``best_tracker`` is the value ``run_train_model`` returned
        for the winning candidate (presumably a performance tracker; confirm
        against ``run_train_model``).

    Raises:
        ValueError: If there are no candidates or every loss is NaN.
    """
    losses = []
    trackers = []
    for gate in remaining_gates:
        # Copy the weights so training a candidate never mutates ``model``.
        dummy_model_state = deepcopy(model.state_dict())
        init_gates = rectangularize_gates(model)
        dummy_model = DepthOneModel(init_gates, params['model_params'])
        dummy_model.load_state_dict(dummy_model_state)
        dummy_model.add_node(gate)

        trackers.append(
            run_train_model(dummy_model, params['train_params'], data_input)
        )
        output = dummy_model(data_input.x_tr, data_input.y_tr)
        losses.append(output['loss'].cpu().detach().numpy())

    # nanargmin ignores NaN losses but returns an index into the *full* array,
    # so it stays aligned with both ``remaining_gates`` and ``trackers``.
    best_gate_idx = int(np.nanargmin(np.array(losses)))

    best_gate = remaining_gates[best_gate_idx]
    remaining_gates = [
        gate for idx, gate in enumerate(remaining_gates)
        if idx != best_gate_idx
    ]
    best_tracker = trackers[best_gate_idx]
    return best_gate, remaining_gates, best_tracker